Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions Lib/contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import abc
import os
import sys
import threading
import _collections_abc
from collections import deque
from functools import wraps
Expand Down Expand Up @@ -390,22 +391,76 @@ async def __aexit__(self, *exc_info):
await self.thing.aclose()


class _PerThreadStream:
def __init__(self, default_stream):
self.default_stream = default_stream
# each stack entry is (stream, thread_id). thread_id is None if
# per_thread=False.
self._stack = []
self._lock = threading.Lock()

@property
def _current_stream(self):
thread_id = threading.get_ident()
# look for the most recent redirect which was either:
# * per_thread=False
# * per_thread=True, and in our thread
#
# If none match, fall back to the default stream.
with self._lock:
for stream, entry_thread_id in reversed(self._stack):
if entry_thread_id is None or entry_thread_id == thread_id:
return stream
return self.default_stream

def add_entry(self, entry):
with self._lock:
self._stack.append(entry)

def remove_entry(self, entry):
# remove by identity, not equality, in case two streams compare equal
with self._lock:
for i, e in enumerate(self._stack):
if e is entry:
del self._stack[i]
return

def __getattr__(self, name):
return getattr(self._current_stream, name)


class _RedirectStream(AbstractContextManager):

_stream = None
_lock = None
_stream_ref = None

def __init__(self, new_target):
def __init__(self, new_target, *, per_thread=False):
self._new_target = new_target
# We use a list of old targets to make this CM re-entrant
self._old_targets = []
self._per_thread = per_thread
self._entries = [] # stack for reentrant usage

def __enter__(self):
self._old_targets.append(getattr(sys, self._stream))
setattr(sys, self._stream, self._new_target)
with self._lock:
if self._stream_ref is None:
type(self)._stream_ref = _PerThreadStream(getattr(sys, self._stream))
setattr(sys, self._stream, self._stream_ref)
entry = (
self._new_target,
threading.get_ident() if self._per_thread else None,
)
self._entries.append(entry)
self._stream_ref.add_entry(entry)

return self._new_target

def __exit__(self, exctype, excinst, exctb):
setattr(sys, self._stream, self._old_targets.pop())
with self._lock:
entry = self._entries.pop()
self._stream_ref.remove_entry(entry)
if len(self._stream_ref._stack) == 0:
setattr(sys, self._stream, self._stream_ref.default_stream)
type(self)._stream_ref = None


class redirect_stdout(_RedirectStream):
Expand All @@ -422,12 +477,16 @@ class redirect_stdout(_RedirectStream):
"""

_stream = "stdout"
_lock = threading.Lock()
_stream_ref = None


class redirect_stderr(_RedirectStream):
"""Context manager for temporarily redirecting stderr to another file."""

_stream = "stderr"
_lock = threading.Lock()
_stream_ref = None


class suppress(AbstractContextManager):
Expand Down
Loading
Loading