jpayne@68
|
1 """
|
jpayne@68
|
2 Asynchronous progressbar decorator for iterators.
|
jpayne@68
|
3 Includes a default `range` iterator printing to `stderr`.
|
jpayne@68
|
4
|
jpayne@68
|
5 Usage:
|
jpayne@68
|
6 >>> from tqdm.asyncio import trange, tqdm
|
jpayne@68
|
7 >>> async for i in trange(10):
|
jpayne@68
|
8 ... ...
|
jpayne@68
|
9 """
|
jpayne@68
|
10 import asyncio
|
jpayne@68
|
11 from sys import version_info
|
jpayne@68
|
12
|
jpayne@68
|
13 from .std import tqdm as std_tqdm
|
jpayne@68
|
14
|
jpayne@68
|
__author__ = {"github.com/": ["casperdcl"]}
# Public API of this module; `tqdm` and `trange` are the conventional
# aliases for `tqdm_asyncio` and `tarange`, bound at the bottom of the file.
__all__ = [
    'tqdm_asyncio',
    'tarange',
    'tqdm',
    'trange',
]
|
jpayne@68
|
17
|
jpayne@68
|
18
|
jpayne@68
|
class tqdm_asyncio(std_tqdm):
    """
    Asynchronous-friendly version of tqdm.

    Can be driven with ``async for`` over synchronous iterators/iterables as
    well as asynchronous iterators/iterables, and provides progress-tracking
    wrappers around `asyncio.as_completed` and `asyncio.gather`.
    """
    def __init__(self, iterable=None, *args, **kwargs):
        """
        Accepts the same arguments as the parent tqdm class, then selects
        how `__anext__` pulls items from `iterable`.
        """
        super().__init__(iterable, *args, **kwargs)
        # True when `iterable_next()` returns an awaitable (async protocol).
        self.iterable_awaitable = False
        if iterable is not None:
            if hasattr(iterable, "__anext__"):
                # Already an async iterator: pull items from it directly.
                self.iterable_next = iterable.__anext__
                self.iterable_awaitable = True
            elif hasattr(iterable, "__next__"):
                # Already a sync iterator: pull items from it directly.
                self.iterable_next = iterable.__next__
            else:
                try:
                    self.iterable_iterator = iter(iterable)
                except TypeError:
                    # Not synchronously iterable. If it is an async iterable
                    # (defines only `__aiter__`), use the async protocol;
                    # otherwise keep the original error.
                    if not hasattr(iterable, "__aiter__"):
                        raise
                    self.iterable_iterator = iterable.__aiter__()
                    self.iterable_next = self.iterable_iterator.__anext__
                    self.iterable_awaitable = True
                else:
                    self.iterable_next = self.iterable_iterator.__next__

    def __aiter__(self):
        return self

    async def __anext__(self):
        """
        Return the next item, advancing the bar by one.

        Raises StopAsyncIteration when the underlying source is exhausted.
        The bar is closed whenever any exception leaves this method.
        """
        try:
            if self.iterable_awaitable:
                res = await self.iterable_next()
            else:
                res = self.iterable_next()
            self.update()
            return res
        except StopIteration:
            # Sync source exhausted: translate to the async protocol. `from None`
            # keeps the internal StopIteration out of any traceback.
            self.close()
            raise StopAsyncIteration from None
        except BaseException:
            # Covers StopAsyncIteration from async sources, cancellation and
            # real errors alike: close the bar, then propagate unchanged.
            self.close()
            raise

    def send(self, *args, **kwargs):
        """
        Forward `send` to the wrapped iterable (which must define `send`).
        """
        return self.iterable.send(*args, **kwargs)

    @classmethod
    def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.as_completed`.

        Yields what `asyncio.as_completed` yields, advancing a progress bar as
        each item is consumed.

        Parameters
        ----------
        fs  : iterable of awaitables. Iterables without `len()` (e.g.
            generators) are materialised into a list when `total` is None.
        loop  : event loop, forwarded only on Python < 3.10.
        timeout  : forwarded to `asyncio.as_completed`.
        total  : bar total. Defaults to the number of items in `fs`.
        tqdm_kwargs  : passed to the progress bar constructor.
        """
        if total is None:
            try:
                total = len(fs)
            except TypeError:
                # A single future/coroutine is a caller error, not something
                # to iterate over, so keep the original TypeError for it.
                if hasattr(fs, "__await__"):
                    raise
                # e.g. a generator: materialise so it can be counted and
                # still handed to asyncio afterwards.
                fs = list(fs)
                total = len(fs)
        kwargs = {}
        if version_info[:2] < (3, 10):
            # `loop` was removed from asyncio.as_completed in Python 3.10.
            kwargs['loop'] = loop
        yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
                       total=total, **tqdm_kwargs)

    @classmethod
    async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
        """
        Wrapper for `asyncio.gather`.

        Awaits every awaitable in `fs`, showing progress as they complete.
        Returns their results in the same order as `fs`. There is no
        `return_exceptions` option: the first exception awaited propagates.
        """
        async def wrap_awaitable(i, f):
            # Tag each result with its input position so order can be restored.
            return i, await f

        ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
        res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
                                                 total=total, **tqdm_kwargs)]
        # Indices are unique, so sorting never compares the results themselves.
        return [r for _, r in sorted(res)]
|
jpayne@68
|
82
|
jpayne@68
|
83
|
jpayne@68
|
def tarange(*args, **kwargs):
    """
    A shortcut for `tqdm.asyncio.tqdm(range(*args), **kwargs)`.

    Positional arguments build the `range`. Keyword arguments configure
    the progress bar.
    """
    numbers = range(*args)
    return tqdm_asyncio(numbers, **kwargs)
|
jpayne@68
|
89
|
jpayne@68
|
90
|
jpayne@68
|
# Aliases: the conventional tqdm names bound to this module's async variants.
tqdm, trange = tqdm_asyncio, tarange
|