annotate CSP2/CSP2_env/env-d9b9114564458d9d-741b3de822f2aaca6c6caa4325c4afce/lib/python3.8/site-packages/tqdm/keras.py @ 68:5028fdace37b

planemo upload commit 2e9511a184a1ca667c7be0c6321a36dc4e3d116d
author jpayne
date Tue, 18 Mar 2025 16:23:26 -0400
parents
children
rev   line source
jpayne@68 1 from copy import copy
jpayne@68 2 from functools import partial
jpayne@68 3
jpayne@68 4 from .auto import tqdm as tqdm_auto
jpayne@68 5
jpayne@68 6 try:
jpayne@68 7 import keras
jpayne@68 8 except (ImportError, AttributeError) as e:
jpayne@68 9 try:
jpayne@68 10 from tensorflow import keras
jpayne@68 11 except ImportError:
jpayne@68 12 raise e
# Public API of this module: only the Keras progress callback is exported.
__all__ = ['TqdmCallback']
# Contributor attribution, keyed by hosting site.
__author__ = {"github.com/": ["casperdcl"]}
jpayne@68 15
jpayne@68 16
class TqdmCallback(keras.callbacks.Callback):
    """Keras callback for epoch and batch progress."""

    @staticmethod
    def bar2callback(bar, pop=None, delta=(lambda logs: 1)):
        """
        Wrap `bar` in a Keras-style hook ``callback(index, logs=None)``.

        Parameters
        ----------
        bar  : tqdm
            Bar advanced on every call.
        pop  : iterable of str, optional
            Keys dropped from a copy of `logs` before it is shown as the
            bar's postfix; the caller's `logs` dict is never mutated.
        delta  : callable, optional
            Called with `logs` (which may be ``None``) to obtain the
            increment passed to ``bar.update`` [default: always 1].

        Returns
        -------
        callable
            The hook; its first (index) argument is ignored.
        """
        def callback(_, logs=None):
            n = delta(logs)
            if logs:
                if pop:
                    logs = copy(logs)  # leave the dict Keras passed in intact
                    for key in pop:
                        logs.pop(key, None)
                bar.set_postfix(logs, refresh=False)
            bar.update(n)

        return callback

    @staticmethod
    def _logs_size(logs):
        """Return ``logs['size']`` if present, else 1 (also when `logs` is None)."""
        # guard: hooks may be invoked with the default `logs=None`
        return (logs or {}).get('size', 1)

    def _make_batch_callback(self):
        """Build the per-batch hook that advances `self.batch_bar`."""
        return self.bar2callback(
            self.batch_bar, pop=['batch', 'size'], delta=self._logs_size)

    def _first_param(self, *keys):
        """Return the first non-``None`` value of `keys` in `self.params`, else None."""
        # `dict.get(key, default)` only falls back when the key is *missing*;
        # a key present with value None must also fall through to the next one.
        for key in keys:
            value = self.params.get(key)
            if value is not None:
                return value
        return None

    def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
                 tqdm_class=tqdm_auto, **tqdm_kwargs):
        """
        Parameters
        ----------
        epochs  : int, optional
            Initial total of the epoch bar. Replaced at train start by the
            ``epochs`` (or ``nb_epoch``) value Keras reports, if different.
        data_size  : int, optional
            Number of training pairs.
        batch_size  : int, optional
            Number of training pairs per batch.
        verbose  : int
            0: epoch, 1: batch (transient), 2: batch. [default: 1].
            For 1 and 2 the per-epoch batch total is taken from the
            ``samples``, ``nb_sample`` or ``steps`` value Keras reports,
            falling back to ``ceil(data_size / batch_size)`` when both
            `data_size` and `batch_size` are given. Any other truthy value
            raises `KeyError` at the start of each epoch.
        tqdm_class  : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs  : optional
            Any other arguments used for all bars.
        """
        # initialise the Keras base-class attributes before setting our own
        super().__init__()
        if tqdm_kwargs:
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class
        self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
        self.on_epoch_end = self.bar2callback(self.epoch_bar)
        if data_size and batch_size:
            # ceiling division: a trailing partial batch still counts
            self.batches = batches = (data_size + batch_size - 1) // batch_size
        else:
            self.batches = batches = None
        self.verbose = verbose
        if verbose == 1:
            # a single transient batch bar, reset at every epoch
            self.batch_bar = tqdm_class(total=batches, unit='batch', leave=False)
            self.on_batch_end = self._make_batch_callback()

    def on_train_begin(self, *_, **__):
        """Resize the epoch bar to the epoch count Keras reports, if any."""
        auto_total = self._first_param('epochs', 'nb_epoch')
        if auto_total is not None and auto_total != self.epoch_bar.total:
            self.epoch_bar.reset(total=auto_total)

    def on_epoch_begin(self, epoch, *_, **__):
        """
        Sync the epoch bar with `epoch` and (re)start the batch bar.

        Raises
        ------
        KeyError
            If `self.verbose` is truthy but neither 1 nor 2.
        """
        if self.epoch_bar.n < epoch:
            # jump forward without counting skipped epochs as progress
            # (presumably training resumed at a later `initial_epoch`)
            ebar = self.epoch_bar
            ebar.n = ebar.last_print_n = ebar.initial = epoch
        if not self.verbose:
            return
        total = self._first_param('samples', 'nb_sample', 'steps') or self.batches
        unit_scale = 1 / (self.params.get('batch_size', 1) or 1)
        if self.verbose == 2:
            # a fresh bar per epoch that stays on screen afterwards
            if hasattr(self, 'batch_bar'):
                self.batch_bar.close()
            self.batch_bar = self.tqdm_class(
                total=total, unit='batch', leave=True, unit_scale=unit_scale)
            self.on_batch_end = self._make_batch_callback()
        elif self.verbose == 1:
            # reuse the transient bar created in __init__
            self.batch_bar.unit_scale = unit_scale
            self.batch_bar.reset(total=total)
        else:
            raise KeyError('Unknown verbosity')

    def on_train_end(self, *_, **__):
        """Close the batch bar (if one was created) and the epoch bar."""
        if hasattr(self, 'batch_bar'):
            self.batch_bar.close()
        self.epoch_bar.close()

    def display(self):
        """Displays in the current cell in Notebooks."""
        container = getattr(self.epoch_bar, 'container', None)
        if container is None:
            return  # not a notebook bar: nothing to display
        from .notebook import display
        display(container)
        batch_bar = getattr(self, 'batch_bar', None)
        if batch_bar is not None:
            display(batch_bar.container)

    # Always report batch hooks as implemented; presumably Keras queries
    # these to decide whether to dispatch batch-level callbacks.
    @staticmethod
    def _implements_train_batch_hooks():
        return True

    @staticmethod
    def _implements_test_batch_hooks():
        return True

    @staticmethod
    def _implements_predict_batch_hooks():
        return True