diff --git a/nam/train/core.py b/nam/train/core.py
index 60182a30..84b2db9b 100644
--- a/nam/train/core.py
+++ b/nam/train/core.py
@@ -1079,6 +1079,36 @@ def _nasty_checks_modal():
     modal.mainloop()
 
 
+class _ValidationStopping(pl.callbacks.EarlyStopping):
+    """
+    Callback to indicate to stop training if the validation metric is good enough,
+    without the other conditions that EarlyStopping usually forces like patience.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.patience = np.inf
+
+
+def _get_callbacks(threshold_esr: Optional[float]):
+    callbacks = [
+        pl.callbacks.model_checkpoint.ModelCheckpoint(
+            filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
+            save_top_k=3,
+            monitor="val_loss",
+            every_n_epochs=1,
+        ),
+        pl.callbacks.model_checkpoint.ModelCheckpoint(
+            filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
+        ),
+    ]
+    if threshold_esr is not None:
+        callbacks.append(
+            _ValidationStopping(monitor="ESR", stopping_threshold=threshold_esr)
+        )
+    return callbacks
+
+
 def train(
     input_path: str,
     output_path: str,
@@ -1099,7 +1129,11 @@ def train(
     ignore_checks: bool = False,
     local: bool = False,
     fit_cab: bool = False,
+    threshold_esr: Optional[float] = None,
 ) -> Optional[Model]:
+    """
+    :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
+    """
     if seed is not None:
         torch.manual_seed(seed)
 
@@ -1164,17 +1198,7 @@ def train(
     )
 
     trainer = pl.Trainer(
-        callbacks=[
-            pl.callbacks.model_checkpoint.ModelCheckpoint(
-                filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
-                save_top_k=3,
-                monitor="val_loss",
-                every_n_epochs=1,
-            ),
-            pl.callbacks.model_checkpoint.ModelCheckpoint(
-                filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
-            ),
-        ],
+        callbacks=_get_callbacks(threshold_esr),
         default_root_dir=train_path,
         **learning_config["trainer"],
     )
diff --git a/nam/train/gui.py b/nam/train/gui.py
index eb28403a..44376301 100644
--- a/nam/train/gui.py
+++ b/nam/train/gui.py
@@ -65,6 +65,7 @@ def _ensure_graceful_shutdowns():
 
 _DEFAULT_DELAY = None
 _DEFAULT_IGNORE_CHECKS = False
+_DEFAULT_THRESHOLD_ESR = None
 
 _ADVANCED_OPTIONS_LEFT_WIDTH = 12
 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -73,10 +74,21 @@ def _ensure_graceful_shutdowns():
 
 @dataclass
 class _AdvancedOptions(object):
+    """
+    :param architecture: Which architecture to use.
+    :param num_epochs: How many epochs to train for.
+    :param latency: Latency between the input and output audio, in samples.
+        None means we don't know and it has to be calibrated.
+    :param ignore_checks: Keep going even if a check says that something is wrong.
+    :param threshold_esr: Stop training if the ESR gets better than this. If None, don't
+        stop.
+    """
+
     architecture: core.Architecture
     num_epochs: int
-    delay: Optional[int]
+    latency: Optional[int]
     ignore_checks: bool
+    threshold_esr: Optional[float]
 
 
 class _PathType(Enum):
@@ -268,6 +280,7 @@ def __init__(self):
             _DEFAULT_NUM_EPOCHS,
             _DEFAULT_DELAY,
             _DEFAULT_IGNORE_CHECKS,
+            _DEFAULT_THRESHOLD_ESR,
         )
 
         # Window to edit them:
@@ -378,8 +391,9 @@ def _train(self):
         # Advanced options:
         num_epochs = self.advanced_options.num_epochs
         architecture = self.advanced_options.architecture
-        delay = self.advanced_options.delay
+        delay = self.advanced_options.latency
         file_list = self._path_button_output.val
+        threshold_esr = self.advanced_options.threshold_esr
 
         # Advanced-er options
         # If you're poking around looking for these, then maybe it's time to learn to
@@ -413,6 +427,7 @@ def _train(self):
                 ].variable.get(),
                 local=True,
                 fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+                threshold_esr=threshold_esr,
             )
             if trained_model is None:
                 print("Model training failed! Skip exporting...")
@@ -443,14 +458,18 @@ def _non_negative_int(val):
     return val
 
 
-def _int_or_null(val):
+def _type_or_null(T, val):
     val = val.rstrip()
     if val == "null":
         return val
-    return int(val)
+    return T(val)
+
+
+_int_or_null = partial(_type_or_null, int)
+_float_or_null = partial(_type_or_null, float)
 
 
-def _int_or_null_inv(val):
+def _type_or_null_inv(val):
     return "null" if val is None else str(val)
 
 
@@ -602,16 +621,26 @@ def __init__(self, parent: _GUI):
         )
 
         # Delay: text box
-        self._frame_delay = tk.Frame(self._root)
-        self._frame_delay.pack()
+        self._frame_latency = tk.Frame(self._root)
+        self._frame_latency.pack()
 
-        self._delay = _LabeledText(
-            self._frame_delay,
-            "Delay",
-            default=_int_or_null_inv(self._parent.advanced_options.delay),
+        self._latency = _LabeledText(
+            self._frame_latency,
+            "Reamp latency",
+            default=_type_or_null_inv(self._parent.advanced_options.latency),
             type=_int_or_null,
         )
 
+        # Threshold ESR
+        self._frame_threshold_esr = tk.Frame(self._root)
+        self._frame_threshold_esr.pack()
+        self._threshold_esr = _LabeledText(
+            self._frame_threshold_esr,
+            "Threshold ESR",
+            default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
+            type=_float_or_null,
+        )
+
         # "Ok": apply and destory
         self._frame_ok = tk.Frame(self._root)
         self._frame_ok.pack()
@@ -636,10 +665,17 @@ def _apply_and_destroy(self):
         epochs = self._epochs.get()
         if epochs is not None:
             self._parent.advanced_options.num_epochs = epochs
-        delay = self._delay.get()
+        latency = self._latency.get()
         # Value None is returned as "null" to disambiguate from non-set.
-        if delay is not None:
-            self._parent.advanced_options.delay = None if delay == "null" else delay
+        if latency is not None:
+            self._parent.advanced_options.latency = (
+                None if latency == "null" else latency
+            )
+        threshold_esr = self._threshold_esr.get()
+        if threshold_esr is not None:
+            self._parent.advanced_options.threshold_esr = (
+                None if threshold_esr == "null" else threshold_esr
+            )
         self._root.destroy()
 
 