diff --git a/train.py b/train.py
index 17e816c06ede..e2cd5ec85c09 100644
--- a/train.py
+++ b/train.py
@@ -138,6 +138,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     # Batch size
     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
         batch_size = check_train_batch_size(model, imgsz)
+        loggers.on_params_update({"batch_size": batch_size})
 
     # Optimizer
     nbs = 64  # nominal batch size
diff --git a/utils/callbacks.py b/utils/callbacks.py
index c9d936ef082d..13d82ebc2e41 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -32,7 +32,7 @@ def __init__(self):
             'on_fit_epoch_end': [],  # fit = train + val
             'on_model_save': [],
             'on_train_end': [],
-
+            'on_params_update': [],
             'teardown': [],
         }
 
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
index 2a68d9785071..7a1df2a45ea7 100644
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -157,3 +157,9 @@ def on_train_end(self, last, best, plots, epoch, results):
             else:
                 self.wandb.finish_run()
                 self.wandb = WandbLogger(self.opt)
+
+    def on_params_update(self, params):
+        # Update hyperparams or configs of the experiment
+        # params: A dict containing {param: value} pairs
+        if self.wandb:
+            self.wandb.wandb_run.config.update(params, allow_val_change=True)
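
For context, here is a minimal standalone sketch of the flow this patch enables: after `check_train_batch_size` estimates a batch size, the new `on_params_update` hook pushes the updated value into the W&B run config. `MiniLoggers` is a hypothetical stand-in for the repo's `Loggers` class, and the snippet assumes `wandb` is installed; offline mode avoids any network dependency.

```python
import wandb

class MiniLoggers:
    """Hypothetical stand-in mirroring the Loggers.on_params_update pattern."""

    def __init__(self, run):
        self.wandb = run is not None
        self.wandb_run = run

    def on_params_update(self, params):
        # Forward {param: value} pairs into the experiment config;
        # allow_val_change permits overwriting values set at init time.
        if self.wandb:
            self.wandb_run.config.update(params, allow_val_change=True)

run = wandb.init(project="demo", mode="offline")  # offline: no network needed
loggers = MiniLoggers(run)
loggers.on_params_update({"batch_size": 16})  # e.g. result of autobatch estimation
run.finish()
```

The `allow_val_change=True` flag matters here because `batch_size` may already be present in the run config (set from the CLI value of -1) before autobatch replaces it with the estimated value.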