from functools import partial

from dask.callbacks import Callback

from .auto import tqdm as tqdm_auto

__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['TqdmCallback']


class TqdmCallback(Callback):
    """Dask callback for task progress."""
    def __init__(self, start=None, pretask=None, tqdm_class=tqdm_auto,
                 **tqdm_kwargs):
        """
        Parameters
        ----------
        start : optional
            Forwarded unchanged to `dask.callbacks.Callback`.
        pretask : optional
            Forwarded unchanged to `dask.callbacks.Callback`.
        tqdm_class : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs : optional
            Any other arguments used for all bars.
        """
        super().__init__(start=start, pretask=pretask)
        if tqdm_kwargs:
            # Bind the extra kwargs once so every bar created later
            # (one per computation) receives them.
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class

    def _start_state(self, _, state):
        """Create the progress bar sized to the total number of tasks.

        The total is the combined length of the scheduler state's
        'ready', 'waiting', 'running' and 'finished' collections.
        """
        self.pbar = self.tqdm_class(total=sum(
            len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))

    def _posttask(self, *_, **__):
        """Advance the bar by one (dask's post-task hook)."""
        self.pbar.update()

    def _finish(self, *_, **__):
        """Close the progress bar when the computation finishes."""
        self.pbar.close()

    def display(self):
        """Displays in the current cell in Notebooks.

        Does nothing if no bar has been created yet, or if the bar has no
        `container` attribute (i.e. it is not a notebook bar).
        """
        # The bar is stored as `self.pbar` (see `_start_state`). It only
        # exists once a computation has started, so look it up defensively.
        pbar = getattr(self, 'pbar', None)
        container = getattr(pbar, 'container', None)
        if container is None:
            return
        from .notebook import display
        display(container)