jpayne@68
|
1 """
|
jpayne@68
|
2 General helpers required for `tqdm.std`.
|
jpayne@68
|
3 """
|
jpayne@68
|
4 import os
|
jpayne@68
|
5 import re
|
jpayne@68
|
6 import sys
|
jpayne@68
|
7 from functools import partial, partialmethod, wraps
|
jpayne@68
|
8 from inspect import signature
|
jpayne@68
|
9 # TODO consider using wcswidth third-party package for 0-width characters
|
jpayne@68
|
10 from unicodedata import east_asian_width
|
jpayne@68
|
11 from warnings import warn
|
jpayne@68
|
12 from weakref import proxy
|
jpayne@68
|
13
|
jpayne@68
|
# Aliases for the Python 2 names `range`/`unichr`/`unicode`/`basestring`;
# on Python 3 they are simply the builtins.
_range = range
_unich = chr
_unicode = str
_basestring = str

# Platform detection, based on the prefix of `sys.platform`.
CUR_OS = sys.platform
IS_WIN = CUR_OS.startswith(('win32', 'cygwin'))
IS_NIX = CUR_OS.startswith(('aix', 'linux', 'darwin', 'freebsd'))

# Matches ANSI CSI escape sequences such as "\x1b[31m" or "\x1b[0;1m".
RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")
|
jpayne@68
|
19
|
jpayne@68
|
# colorama is only used on Windows; everywhere else (or when it is not
# installed) the module-level name `colorama` is None.
colorama = None
if IS_WIN:
    try:
        import colorama
    except ImportError:
        colorama = None
    else:
        try:
            # strip=False: presumably asks colorama to keep ANSI codes in the
            # output rather than removing them (see colorama docs)
            colorama.init(strip=False)
        except TypeError:
            # NOTE(review): presumably for colorama versions whose `init()`
            # does not accept `strip`
            colorama.init()
|
jpayne@68
|
32
|
jpayne@68
|
33
|
jpayne@68
|
def _envwrap_str2bool(value):
    """
    Parse an environment-variable string as a boolean.

    "1", "true", "yes", "on", "y", "t" give True and "", "0", "false", "no",
    "off", "n", "f" give False. Matching is case-insensitive and ignores
    surrounding whitespace. Any other value falls back to `bool(value)`, i.e.
    True for any non-empty string.
    """
    token = value.strip().lower()
    if token in ('1', 'true', 'yes', 'on', 'y', 't'):
        return True
    if token in ('', '0', 'false', 'no', 'off', 'n', 'f'):
        return False
    return bool(value)


def _envwrap_cast(typ, value):
    """Convert the env string `value` using `typ`, special-casing `bool`."""
    # plain `bool(value)` is True for every non-empty string, including "0"
    return _envwrap_str2bool(value) if typ is bool else typ(value)


def envwrap(prefix, types=None, is_method=False):
    """
    Override parameter defaults via `os.environ[prefix + param_name]`.
    Maps UPPER_CASE env vars to lower_case param names.
    camelCase isn't supported (because Windows ignores case).

    Precedence (highest first):

    - call (`foo(a=3)`)
    - environ (`FOO_A=2`)
    - signature (`def foo(a=1)`)

    Parameters
    ----------
    prefix  : str
        Env var prefix, e.g. "FOO_"
    types  : dict, optional
        Fallback mappings `{'param_name': type, ...}` if types cannot be
        inferred from function signature.
        Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
    is_method  : bool, optional
        Whether to use `functools.partialmethod`. If False (default), use
        `functools.partial`.

    Notes
    -----
    - The environment is read once, when `envwrap(...)` is called, not on
      every call of the wrapped function.
    - Conversion order: the type annotation (each member of a `Union` is
      tried in turn), else the type of a non-None default, else
      `types[param_name]`, else the raw `str` is kept.
    - `bool` conversion treats "0"/"false"/"no"/"off" (case-insensitive) as
      False instead of applying `bool()` to the non-empty string.
    - Env vars that do not name a parameter which can be passed by keyword
      (this excludes `*args`, `**kwargs` and positional-only parameters)
      are ignored.

    Examples
    --------
    ```
    $ cat foo.py
    from tqdm.utils import envwrap
    @envwrap("FOO_")
    def test(a=1, b=2, c=3):
        print(f"received: a={a}, b={b}, c={c}")

    $ FOO_A=42 FOO_C=1337 python -c 'import foo; foo.test(c=99)'
    received: a=42, b=2, c=99
    ```
    """
    if types is None:
        types = {}
    i = len(prefix)
    env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
    part = partialmethod if is_method else partial

    def wrap(func):
        params = signature(func).parameters
        # ignore unknown env vars, as well as params that `partial` cannot
        # bind by keyword (`*args`, `**kwargs`, positional-only)
        overrides = {}
        for k, v in env_overrides.items():
            param = params.get(k)
            if param is not None and param.kind in (
                    param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):
                overrides[k] = v
        # infer overrides' `type`s
        for k in overrides:
            param = params[k]
            if param.annotation is not param.empty:  # typehints
                # e.g. `Optional[int]` -> try `int`, then `NoneType`
                for typ in getattr(param.annotation, '__args__', (param.annotation,)):
                    try:
                        overrides[k] = _envwrap_cast(typ, overrides[k])
                    except Exception:
                        pass
                    else:
                        break
            elif param.default is not param.empty and param.default is not None:
                # type of default value; `param.empty` (no default) must not be
                # used here, as `type(param.empty)` is `type`
                overrides[k] = _envwrap_cast(type(param.default), overrides[k])
            else:
                try:  # `types` fallback
                    overrides[k] = _envwrap_cast(types[k], overrides[k])
                except KeyError:  # keep unconverted (`str`)
                    pass
        return part(func, **overrides)
    return wrap
|
jpayne@68
|
100
|
jpayne@68
|
101
|
jpayne@68
|
class FormatReplace(object):
    """
    Stand-in whose formatting always yields a fixed string, whatever the spec.

    >>> a = FormatReplace('something')
    >>> f"{a:5d}"
    'something'
    """  # NOQA: P102
    def __init__(self, replace=''):
        # text returned for every format request
        self.replace = replace
        # how many times `__format__` has been invoked
        self.format_called = 0

    def __format__(self, _):
        # the format spec is deliberately ignored
        self.format_called = self.format_called + 1
        return self.replace
|
jpayne@68
|
115
|
jpayne@68
|
116
|
jpayne@68
|
class Comparable(object):
    """
    Mixin deriving all six rich comparisons from `self._comparable`.

    Assumes child has self._comparable attr/@property. Only `__lt__` and
    `__eq__` read `_comparable` directly; the remaining operators are built
    from `<` and `==` on the instances themselves. Because `__eq__` is defined
    here without `__hash__`, instances are unhashable unless a subclass
    provides `__hash__`.
    """
    def __eq__(self, other):
        return self._comparable == other._comparable

    def __ne__(self, other):
        return not (self == other)

    def __lt__(self, other):
        return self._comparable < other._comparable

    def __le__(self, other):
        return self < other or self == other

    def __ge__(self, other):
        return not (self < other)

    def __gt__(self, other):
        return not (self <= other)
|
jpayne@68
|
136
|
jpayne@68
|
137
|
jpayne@68
|
class ObjectWrapper(object):
    """
    Thin proxy forwarding attribute reads and writes to `self._wrapped`.

    Attributes that should live on the wrapper itself (rather than on the
    wrapped object) are accessed with `wrapper_getattr`/`wrapper_setattr`.
    """
    def __getattr__(self, name):
        # Only reached when normal lookup on the wrapper fails. Fetch
        # `_wrapped` without re-entering `__getattr__`, so a half-initialised
        # instance (e.g. during `copy.copy`) raises AttributeError instead of
        # recursing until RecursionError.
        wrapped = object.__getattribute__(self, '_wrapped')
        return getattr(wrapped, name)

    def __setattr__(self, name, value):
        return setattr(self._wrapped, name, value)

    def wrapper_getattr(self, name):
        """Actual `self.getattr` rather than self._wrapped.getattr"""
        try:
            # `object` defines `__getattribute__` (not `__getattr__`): this
            # finds attributes stored on the wrapper or its class
            return object.__getattribute__(self, name)
        except AttributeError:
            # not on the wrapper: regular lookup, which ends up delegating to
            # the wrapped object via `__getattr__`
            return getattr(self, name)

    def wrapper_setattr(self, name, value):
        """Actual `self.setattr` rather than self._wrapped.setattr"""
        return object.__setattr__(self, name, value)

    def __init__(self, wrapped):
        """
        Thin wrapper around a given object
        """
        self.wrapper_setattr('_wrapped', wrapped)
|
jpayne@68
|
161
|
jpayne@68
|
162
|
jpayne@68
|
class SimpleTextIOWrapper(ObjectWrapper):
    """
    Wrapper that encodes text before handing it to the wrapped `.write()`.

    Only `.write()` differs from the wrapped object: the given `str` is
    encoded with the wrapper's own `encoding` and the resulting bytes are
    passed on. All other attribute access is forwarded to the wrapped object.
    """
    # pylint: disable=too-few-public-methods
    def __init__(self, wrapped, encoding):
        super().__init__(wrapped)
        # kept on the wrapper itself, not forwarded to `wrapped`
        self.wrapper_setattr('encoding', encoding)

    def write(self, s):
        """
        Encode `s` and pass to the wrapped object's `.write()` method.
        """
        sink = self._wrapped.write
        return sink(s.encode(self.wrapper_getattr('encoding')))

    def __eq__(self, other):
        # compare the underlying objects, unwrapping `other` if it is a wrapper
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
|
jpayne@68
|
181
|
jpayne@68
|
182
|
jpayne@68
|
class DisableOnWriteError(ObjectWrapper):
    """
    Disable the given `tqdm_instance` upon `write()` or `flush()` errors.

    The wrapped stream's `write`/`flush` (when present) are replaced by
    versions that swallow specific errors and set
    `tqdm_instance.miniters = inf` instead.
    """
    @staticmethod
    def disable_on_exception(tqdm_instance, func):
        """
        Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`.

        A `ValueError` whose message contains "closed" is handled the same
        way; every other exception propagates. On a swallowed error the
        wrapper returns None. `tqdm_instance` is only held through a
        `weakref.proxy`, so the wrapper does not keep it alive.
        """
        target = proxy(tqdm_instance)

        def silence_target():
            try:
                target.miniters = float('inf')
            except ReferenceError:  # the proxied instance no longer exists
                pass

        def inner(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except OSError as e:
                if e.errno != 5:  # only errno 5 (EIO on Linux) is swallowed
                    raise
                silence_target()
            except ValueError as e:
                # presumably "I/O operation on closed file"-style errors
                if 'closed' not in str(e):
                    raise
                silence_target()
        return inner

    def __init__(self, wrapped, tqdm_instance):
        super().__init__(wrapped)
        for name in ('write', 'flush'):
            if hasattr(wrapped, name):
                guarded = self.disable_on_exception(tqdm_instance, getattr(wrapped, name))
                self.wrapper_setattr(name, guarded)

    def __eq__(self, other):
        # compare the underlying objects, unwrapping `other` if it is a wrapper
        mine = self._wrapped
        theirs = getattr(other, '_wrapped', other)
        return mine == theirs
|
jpayne@68
|
224
|
jpayne@68
|
225
|
jpayne@68
|
class CallbackIOWrapper(ObjectWrapper):
    def __init__(self, callback, stream, method="read"):
        """
        Wrap a given `file`-like object's `read()` or `write()` to report
        lengths to the given `callback`

        Parameters
        ----------
        callback  : callable
            Called with `len(data)` after each wrapped call.
        stream  : file-like
            Object whose `method` is wrapped; everything else is forwarded.
        method  : str, optional
            "read" (default) or "write"; any other value raises `KeyError`
            (after `getattr(stream, method)` has been attempted).
        """
        super().__init__(stream)
        wrapped_method = getattr(stream, method)

        if method == "write":
            @wraps(wrapped_method)
            def write(data, *args, **kwargs):
                # report the size of the data passed in, once it is written
                result = wrapped_method(data, *args, **kwargs)
                callback(len(data))
                return result
            replacement = write
        elif method == "read":
            @wraps(wrapped_method)
            def read(*args, **kwargs):
                # report the size of the data actually returned
                chunk = wrapped_method(*args, **kwargs)
                callback(len(chunk))
                return chunk
            replacement = read
        else:
            raise KeyError("Can only wrap read/write methods")
        self.wrapper_setattr(method, replacement)
|
jpayne@68
|
250
|
jpayne@68
|
251
|
jpayne@68
|
252 def _is_utf(encoding):
|
jpayne@68
|
253 try:
|
jpayne@68
|
254 u'\u2588\u2589'.encode(encoding)
|
jpayne@68
|
255 except UnicodeEncodeError:
|
jpayne@68
|
256 return False
|
jpayne@68
|
257 except Exception:
|
jpayne@68
|
258 try:
|
jpayne@68
|
259 return encoding.lower().startswith('utf-') or ('U8' == encoding)
|
jpayne@68
|
260 except Exception:
|
jpayne@68
|
261 return False
|
jpayne@68
|
262 else:
|
jpayne@68
|
263 return True
|
jpayne@68
|
264
|
jpayne@68
|
265
|
jpayne@68
|
266 def _supports_unicode(fp):
|
jpayne@68
|
267 try:
|
jpayne@68
|
268 return _is_utf(fp.encoding)
|
jpayne@68
|
269 except AttributeError:
|
jpayne@68
|
270 return False
|
jpayne@68
|
271
|
jpayne@68
|
272
|
jpayne@68
|
273 def _is_ascii(s):
|
jpayne@68
|
274 if isinstance(s, str):
|
jpayne@68
|
275 for c in s:
|
jpayne@68
|
276 if ord(c) > 255:
|
jpayne@68
|
277 return False
|
jpayne@68
|
278 return True
|
jpayne@68
|
279 return _supports_unicode(s)
|
jpayne@68
|
280
|
jpayne@68
|
281
|
jpayne@68
|
def _screen_shape_wrapper():  # pragma: no cover
    """
    Return a function which returns console dimensions (width, height).
    Supported: linux, osx, windows, cygwin.

    Windows and cygwin use `_screen_shape_windows`. Detected *nix platforms,
    and any other POSIX platform, use `_screen_shape_linux`, which itself
    returns `(None, None)` when `fcntl`/`termios` are unavailable. Returns
    None when no implementation applies to the current platform.
    """
    if IS_WIN:
        return _screen_shape_windows
    if IS_NIX or os.name == 'posix':
        # the TIOCGWINSZ ioctl is not Linux-specific (e.g. the BSDs have it)
        return _screen_shape_linux
    return None
|
jpayne@68
|
295
|
jpayne@68
|
296
|
jpayne@68
|
297 def _screen_shape_windows(fp): # pragma: no cover
|
jpayne@68
|
298 try:
|
jpayne@68
|
299 import struct
|
jpayne@68
|
300 from ctypes import create_string_buffer, windll
|
jpayne@68
|
301 from sys import stdin, stdout
|
jpayne@68
|
302
|
jpayne@68
|
303 io_handle = -12 # assume stderr
|
jpayne@68
|
304 if fp == stdin:
|
jpayne@68
|
305 io_handle = -10
|
jpayne@68
|
306 elif fp == stdout:
|
jpayne@68
|
307 io_handle = -11
|
jpayne@68
|
308
|
jpayne@68
|
309 h = windll.kernel32.GetStdHandle(io_handle)
|
jpayne@68
|
310 csbi = create_string_buffer(22)
|
jpayne@68
|
311 res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
|
jpayne@68
|
312 if res:
|
jpayne@68
|
313 (_bufx, _bufy, _curx, _cury, _wattr, left, top, right, bottom,
|
jpayne@68
|
314 _maxx, _maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
|
jpayne@68
|
315 return right - left, bottom - top # +1
|
jpayne@68
|
316 except Exception: # nosec
|
jpayne@68
|
317 pass
|
jpayne@68
|
318 return None, None
|
jpayne@68
|
319
|
jpayne@68
|
320
|
jpayne@68
|
321 def _screen_shape_tput(*_): # pragma: no cover
|
jpayne@68
|
322 """cygwin xterm (windows)"""
|
jpayne@68
|
323 try:
|
jpayne@68
|
324 import shlex
|
jpayne@68
|
325 from subprocess import check_call # nosec
|
jpayne@68
|
326 return [int(check_call(shlex.split('tput ' + i))) - 1
|
jpayne@68
|
327 for i in ('cols', 'lines')]
|
jpayne@68
|
328 except Exception: # nosec
|
jpayne@68
|
329 pass
|
jpayne@68
|
330 return None, None
|
jpayne@68
|
331
|
jpayne@68
|
332
|
jpayne@68
|
333 def _screen_shape_linux(fp): # pragma: no cover
|
jpayne@68
|
334
|
jpayne@68
|
335 try:
|
jpayne@68
|
336 from array import array
|
jpayne@68
|
337 from fcntl import ioctl
|
jpayne@68
|
338 from termios import TIOCGWINSZ
|
jpayne@68
|
339 except ImportError:
|
jpayne@68
|
340 return None, None
|
jpayne@68
|
341 else:
|
jpayne@68
|
342 try:
|
jpayne@68
|
343 rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
|
jpayne@68
|
344 return cols, rows
|
jpayne@68
|
345 except Exception:
|
jpayne@68
|
346 try:
|
jpayne@68
|
347 return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
|
jpayne@68
|
348 except (KeyError, ValueError):
|
jpayne@68
|
349 return None, None
|
jpayne@68
|
350
|
jpayne@68
|
351
|
jpayne@68
|
def _environ_cols_wrapper():  # pragma: no cover
    """
    Return a function which returns console width.
    Supported: linux, osx, windows, cygwin.

    Deprecated: emits a `DeprecationWarning` and returns a wrapper taking
    element 0 of `_screen_shape_wrapper()`'s result, or None when no
    shape function is available.
    """
    warn("Use `_screen_shape_wrapper()(file)[0]` instead of"
         " `_environ_cols_wrapper()(file)`", DeprecationWarning, stacklevel=2)
    shape_fn = _screen_shape_wrapper()
    if shape_fn:
        @wraps(shape_fn)
        def inner(fp):
            return shape_fn(fp)[0]

        return inner
    return None
|
jpayne@68
|
368
|
jpayne@68
|
369
|
jpayne@68
|
370 def _term_move_up(): # pragma: no cover
|
jpayne@68
|
371 return '' if (os.name == 'nt') and (colorama is None) else '\x1b[A'
|
jpayne@68
|
372
|
jpayne@68
|
373
|
jpayne@68
|
374 def _text_width(s):
|
jpayne@68
|
375 return sum(2 if east_asian_width(ch) in 'FW' else 1 for ch in str(s))
|
jpayne@68
|
376
|
jpayne@68
|
377
|
jpayne@68
|
def disp_len(data):
    """
    Returns the real on-screen length of a string which may contain
    ANSI control codes and wide chars.

    ANSI escape sequences (as matched by `RE_ANSI`) count as zero width;
    the remaining text is measured with `_text_width`.
    """
    visible = RE_ANSI.sub('', data)
    return _text_width(visible)
|
jpayne@68
|
384
|
jpayne@68
|
385
|
jpayne@68
|
def disp_trim(data, length):
    """
    Trim a string which may contain ANSI control characters.

    Parameters
    ----------
    data  : str
    length  : int
        Maximum on-screen width, as measured by `disp_len`.

    Returns
    -------
    str
        `data` shortened to at most `length` display columns. If the input
        contained ANSI sequences and at least one survives the trim, an ANSI
        reset (ESC "[0m") is appended unless the result already ends with one.
    """
    ansi_present = bool(RE_ANSI.search(data))
    if not ansi_present and len(data) == disp_len(data):
        # plain single-width text: character count equals display width.
        # (ANSI text is excluded: wide chars could otherwise balance out the
        # zero-width escapes and the slice would cut through an escape.)
        return data[:length]

    # carefully delete one char at a time; stop once `data` is empty so that
    # a negative `length` cannot loop forever
    while data and disp_len(data) > length:
        data = data[:-1]
    if ansi_present and RE_ANSI.search(data):
        # assume ANSI reset is required
        return data if data.endswith("\033[0m") else data + "\033[0m"
    return data
|