from copy import copy
from functools import partial

from .auto import tqdm as tqdm_auto

try:
    import keras
except (ImportError, AttributeError) as e:
    try:
        from tensorflow import keras
    except ImportError:
        raise e
__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['TqdmCallback']


class TqdmCallback(keras.callbacks.Callback):
    """Keras callback for epoch and batch progress.

    An outer bar advances once per epoch. Depending on `verbose`, an inner
    bar also tracks batches within each epoch. Entries of the Keras `logs`
    dict are shown as the bar postfix.
    """

    @staticmethod
    def bar2callback(bar, pop=None, delta=(lambda logs: 1)):
        """Build a Keras hook that advances `bar`.

        Parameters
        ----------
        bar : tqdm
            Bar to update.
        pop : iterable of str, optional
            Keys removed (from a shallow copy of `logs`) before display.
        delta : callable, optional
            Maps `logs` to the increment passed to `bar.update`
            [default: always 1].

        Returns
        -------
        callable
            ``callback(index, logs=None)``, suitable as an ``on_*_end`` hook.
        """
        def callback(_, logs=None):
            n = delta(logs)
            if logs:
                if pop:
                    # Work on a copy so the dict Keras passed in is untouched.
                    logs = copy(logs)
                    for key in pop:
                        logs.pop(key, 0)
                bar.set_postfix(logs, refresh=False)
            bar.update(n)

        return callback

    @staticmethod
    def _batch_delta(logs):
        """Return the batch-bar increment: ``logs['size']`` if present, else 1."""
        # `logs` may be None (the hook's default), so guard before `.get`.
        return (logs or {}).get('size', 1)

    def _batch_callback(self):
        """Return an ``on_batch_end`` hook bound to the current `batch_bar`."""
        return self.bar2callback(
            self.batch_bar, pop=['batch', 'size'], delta=self._batch_delta)

    def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
                 tqdm_class=tqdm_auto, **tqdm_kwargs):
        """
        Parameters
        ----------
        epochs : int, optional
            Expected number of epochs. Overridden at train begin by the
            Keras-supplied ``params['epochs']`` (or ``'nb_epoch'``) if present.
        data_size : int, optional
            Number of training pairs.
        batch_size : int, optional
            Number of training pairs per batch.
        verbose : int
            0: epoch, 1: batch (transient), 2: batch. [default: 1].
            The batch total is taken from the Keras-supplied ``params``
            (``samples``, ``nb_sample`` or ``steps``) when available.
            Otherwise it falls back to ``ceil(data_size / batch_size)`` when
            both `data_size` and `batch_size` are given, and is left unknown
            if not. Any other value raises `KeyError` at epoch begin.
        tqdm_class : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs : optional
            Any other arguments used for all bars.
        """
        if tqdm_kwargs:
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class
        self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
        # Instance attribute shadows the Callback method, so Keras calls the bar hook.
        self.on_epoch_end = self.bar2callback(self.epoch_bar)
        if data_size and batch_size:
            # Ceiling division: a final partial batch still counts as one.
            self.batches = batches = (data_size + batch_size - 1) // batch_size
        else:
            self.batches = batches = None
        self.verbose = verbose
        if verbose == 1:
            # A single transient batch bar, reset at the start of every epoch.
            self.batch_bar = tqdm_class(total=batches, unit='batch', leave=False)
            self.on_batch_end = self._batch_callback()

    def on_train_begin(self, *_, **__):
        """Adopt the epoch count from Keras `params` if it provides one."""
        params = self.params.get
        auto_total = params('epochs', params('nb_epoch', None))
        if auto_total is not None and auto_total != self.epoch_bar.total:
            self.epoch_bar.reset(total=auto_total)

    def on_epoch_begin(self, epoch, *_, **__):
        """Sync the epoch bar to `epoch` and prepare the batch bar (if any).

        Raises
        ------
        KeyError
            If `verbose` is truthy but neither 1 nor 2.
        """
        if self.epoch_bar.n < epoch:
            # Fast-forward when Keras starts past the bar's count
            # (presumably a resumed run, e.g. via `initial_epoch`).
            ebar = self.epoch_bar
            ebar.n = ebar.last_print_n = ebar.initial = epoch
        if self.verbose:
            params = self.params.get
            total = params('samples', params(
                'nb_sample', params('steps', None))) or self.batches
            if self.verbose == 2:
                # A fresh persistent bar per epoch; close any previous one.
                if hasattr(self, 'batch_bar'):
                    self.batch_bar.close()
                self.batch_bar = self.tqdm_class(
                    total=total, unit='batch', leave=True,
                    unit_scale=1 / (params('batch_size', 1) or 1))
                self.on_batch_end = self._batch_callback()
            elif self.verbose == 1:
                self.batch_bar.unit_scale = 1 / (params('batch_size', 1) or 1)
                self.batch_bar.reset(total=total)
            else:
                raise KeyError('Unknown verbosity')

    def on_train_end(self, *_, **__):
        """Close the batch bar (if one was created) and the epoch bar."""
        if hasattr(self, 'batch_bar'):
            self.batch_bar.close()
        self.epoch_bar.close()

    def display(self):
        """Displays in the current cell in Notebooks."""
        container = getattr(self.epoch_bar, 'container', None)
        if container is None:
            return
        from .notebook import display
        display(container)
        batch_bar = getattr(self, 'batch_bar', None)
        batch_container = getattr(batch_bar, 'container', None)
        if batch_container is not None:
            display(batch_container)

    # Opt in to Keras' per-batch hooks for train/test/predict loops.
    @staticmethod
    def _implements_train_batch_hooks():
        return True

    @staticmethod
    def _implements_test_batch_hooks():
        return True

    @staticmethod
    def _implements_predict_batch_hooks():
        return True