Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GUI] Advanced option for loss-based early stopping #401

Merged
merged 3 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,36 @@ def _nasty_checks_modal():
modal.mainloop()


class _ValidationStopping(pl.callbacks.EarlyStopping):
    """
    Early-stopping variant that halts training once the monitored validation
    metric crosses the configured threshold, without the patience-based
    plateau detection that `EarlyStopping` normally applies.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Infinite patience: never stop merely because the metric plateaued;
        # only the `stopping_threshold` condition can end training.
        self.patience = np.inf


def _get_callbacks(threshold_esr: Optional[float]):
    """
    Assemble the PyTorch Lightning callbacks used during training.

    Always includes two checkpoint callbacks: the top-3 checkpoints by
    validation loss, and the most recent checkpoint. If `threshold_esr` is
    provided, a `_ValidationStopping` callback is added as well.

    :param threshold_esr: Stop training once the ESR metric reaches this
        value. If `None`, no ESR-based stopping callback is added.
    :return: List of callbacks for `pl.Trainer`.
    """
    checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
        save_top_k=3,
        monitor="val_loss",
        every_n_epochs=1,
    )
    checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
        filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
    )
    callbacks = [checkpoint_best, checkpoint_last]
    if threshold_esr is not None:
        callbacks.append(
            _ValidationStopping(monitor="ESR", stopping_threshold=threshold_esr)
        )
    return callbacks


def train(
input_path: str,
output_path: str,
Expand All @@ -1099,7 +1129,11 @@ def train(
ignore_checks: bool = False,
local: bool = False,
fit_cab: bool = False,
    threshold_esr: Optional[float] = None,
) -> Optional[Model]:
"""
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
"""
if seed is not None:
torch.manual_seed(seed)

Expand Down Expand Up @@ -1164,17 +1198,7 @@ def train(
)

trainer = pl.Trainer(
callbacks=[
pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
save_top_k=3,
monitor="val_loss",
every_n_epochs=1,
),
pl.callbacks.model_checkpoint.ModelCheckpoint(
filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
),
],
callbacks=_get_callbacks(threshold_esr),
default_root_dir=train_path,
**learning_config["trainer"],
)
Expand Down
64 changes: 50 additions & 14 deletions nam/train/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def _ensure_graceful_shutdowns():

_DEFAULT_DELAY = None
_DEFAULT_IGNORE_CHECKS = False
_DEFAULT_THRESHOLD_ESR = None

_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
Expand All @@ -73,10 +74,21 @@ def _ensure_graceful_shutdowns():

@dataclass
class _AdvancedOptions(object):
    """
    Advanced training options configured through the GUI.

    :param architecture: Which architecture to use.
    :param num_epochs: How many epochs to train for.
    :param latency: Latency between the input and output audio, in samples.
        None means we don't know and it has to be calibrated.
    :param ignore_checks: Keep going even if a check says that something is wrong.
    :param threshold_esr: Stop training if the ESR gets better than this. If None,
        don't stop.
    """

    architecture: core.Architecture
    num_epochs: int
    latency: Optional[int]
    ignore_checks: bool
    threshold_esr: Optional[float]


class _PathType(Enum):
Expand Down Expand Up @@ -268,6 +280,7 @@ def __init__(self):
_DEFAULT_NUM_EPOCHS,
_DEFAULT_DELAY,
_DEFAULT_IGNORE_CHECKS,
_DEFAULT_THRESHOLD_ESR,
)
# Window to edit them:

Expand Down Expand Up @@ -378,8 +391,9 @@ def _train(self):
# Advanced options:
num_epochs = self.advanced_options.num_epochs
architecture = self.advanced_options.architecture
delay = self.advanced_options.delay
delay = self.advanced_options.latency
file_list = self._path_button_output.val
threshold_esr = self.advanced_options.threshold_esr

# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
Expand Down Expand Up @@ -413,6 +427,7 @@ def _train(self):
].variable.get(),
local=True,
fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
threshold_esr=threshold_esr,
)
if trained_model is None:
print("Model training failed! Skip exporting...")
Expand Down Expand Up @@ -443,14 +458,18 @@ def _non_negative_int(val):
return val


def _int_or_null(val):
def _type_or_null(T, val):
val = val.rstrip()
if val == "null":
return val
return int(val)
return T(val)


# Concrete parsers for the advanced-options text boxes: convert the entry to
# an int/float, or pass through the literal string "null" (which callers map
# back to None).
_int_or_null = partial(_type_or_null, int)
_float_or_null = partial(_type_or_null, float)


def _int_or_null_inv(val):
def _type_or_null_inv(val):
return "null" if val is None else str(val)


Expand Down Expand Up @@ -602,16 +621,26 @@ def __init__(self, parent: _GUI):
)

# Delay: text box
self._frame_delay = tk.Frame(self._root)
self._frame_delay.pack()
self._frame_latency = tk.Frame(self._root)
self._frame_latency.pack()

self._delay = _LabeledText(
self._frame_delay,
"Delay",
default=_int_or_null_inv(self._parent.advanced_options.delay),
self._latency = _LabeledText(
self._frame_latency,
"Reamp latency",
default=_type_or_null_inv(self._parent.advanced_options.latency),
type=_int_or_null,
)

# Threshold ESR
self._frame_threshold_esr = tk.Frame(self._root)
self._frame_threshold_esr.pack()
self._threshold_esr = _LabeledText(
self._frame_threshold_esr,
"Threshold ESR",
default=_type_or_null_inv(self._parent.advanced_options.threshold_esr),
type=_float_or_null,
)

        # "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
Expand All @@ -636,10 +665,17 @@ def _apply_and_destroy(self):
epochs = self._epochs.get()
if epochs is not None:
self._parent.advanced_options.num_epochs = epochs
delay = self._delay.get()
latency = self._latency.get()
# Value None is returned as "null" to disambiguate from non-set.
if delay is not None:
self._parent.advanced_options.delay = None if delay == "null" else delay
if latency is not None:
self._parent.advanced_options.latency = (
None if latency == "null" else latency
)
threshold_esr = self._threshold_esr.get()
if threshold_esr is not None:
self._parent.advanced_options.threshold_esr = (
None if threshold_esr == "null" else threshold_esr
)
self._root.destroy()


Expand Down
Loading