[BREAKING] Remove cab-fitting checkbox from GUI trainer (#462)
* Remove cab-fitting option from GUI trainer.

It's always on.
Rename fit_cab to fit_mrstft.
Remove it from the Settings metadata since it's always on.

* colab.run: rename fit_cab kwarg to fit_mrstft, on by default
sdatkinson authored Sep 16, 2024
1 parent 6a610ec commit 8fecd53
Showing 4 changed files with 21 additions and 25 deletions.
nam/train/colab.py (4 changes: 2 additions & 2 deletions)
@@ -84,7 +84,7 @@ def run(
     seed: Optional[int] = 0,
     user_metadata: Optional[UserMetadata] = None,
     ignore_checks: bool = False,
-    fit_cab: bool = False,
+    fit_mrstft: bool = True,
 ):
     """
     :param epochs: How many epochs we'll train for.
@@ -115,7 +115,7 @@
         seed=seed,
         local=False,
         ignore_checks=ignore_checks,
-        fit_cab=fit_cab,
+        fit_mrstft=fit_mrstft,
     )
     model = train_output.model
     training_metadata = train_output.metadata
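
For Colab users this is a breaking change to run()'s keyword arguments. A minimal migration sketch, assuming the training files are already uploaded to the session (the epochs value is illustrative):

    from nam.train.colab import run

    # Before this commit, MRSTFT fitting was opt-in:
    #   run(epochs=100, fit_cab=True)
    # That call now fails (the fit_cab keyword no longer exists).

    # After this commit, the loss is on by default under the new name:
    run(epochs=100)

    # It can still be disabled explicitly:
    run(epochs=100, fit_mrstft=False)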
nam/train/core.py (14 changes: 6 additions & 8 deletions)
@@ -971,7 +971,7 @@ def _get_configs(
     lr: float,
     lr_decay: float,
     batch_size: int,
-    fit_cab: bool,
+    fit_mrstft: bool,
 ):

     data_config = _get_data_config(
@@ -1012,7 +1012,7 @@
         "optimizer": {"lr": 0.01},
         "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
     }
-    if fit_cab:
+    if fit_mrstft:
         model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
         model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF

@@ -1295,7 +1295,7 @@ def train(
     modelname: str = "model",
     ignore_checks: bool = False,
     local: bool = False,
-    fit_cab: bool = False,
+    fit_mrstft: bool = True,
     threshold_esr: Optional[bool] = None,
     user_metadata: Optional[UserMetadata] = None,
     fast_dev_run: Union[bool, int] = False,
Expand Down Expand Up @@ -1351,9 +1351,7 @@ def parse_user_latency(
return TrainOutput(
model=None,
metadata=metadata.TrainingMetadata(
settings=metadata.Settings(
fit_cab=fit_cab, ignore_checks=ignore_checks
),
settings=metadata.Settings(ignore_checks=ignore_checks),
data=metadata.Data(
latency=latency_analysis, checks=data_check_output
),
Expand All @@ -1373,7 +1371,7 @@ def parse_user_latency(
lr,
lr_decay,
batch_size,
fit_cab,
fit_mrstft,
)
assert (
"fast_dev_run" not in learning_config
@@ -1399,7 +1397,7 @@
     model.net.sample_rate = sample_rate

     # Put together the metadata that's needed in checkpoints:
-    settings_metadata = metadata.Settings(fit_cab=fit_cab, ignore_checks=ignore_checks)
+    settings_metadata = metadata.Settings(ignore_checks=ignore_checks)
     data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output)

     trainer = pl.Trainer(
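
For context on what the flag enables: a pre-emphasized multi-resolution short-time Fourier transform (MRSTFT) loss compares spectral magnitudes at several FFT sizes after a first-order high-pass (pre-emphasis) filter, which weights the high frequencies that a plain waveform-domain ESR loss tends to underfit. The sketch below shows the idea in PyTorch; it is not the project's actual loss code, and the FFT sizes and pre-emphasis coefficient are illustrative stand-ins for _CAB_MRSTFT_PRE_EMPH_WEIGHT and _CAB_MRSTFT_PRE_EMPH_COEF:

    import torch

    def pre_emphasis(x: torch.Tensor, coef: float = 0.85) -> torch.Tensor:
        # First-order high-pass: y[n] = x[n] - coef * x[n-1].
        # Emphasizes high frequencies before the spectral comparison.
        return x[..., 1:] - coef * x[..., :-1]

    def mrstft_loss(pred: torch.Tensor, target: torch.Tensor,
                    fft_sizes=(512, 1024, 2048)) -> torch.Tensor:
        pred, target = pre_emphasis(pred), pre_emphasis(target)
        loss = torch.tensor(0.0)
        for n_fft in fft_sizes:
            window = torch.hann_window(n_fft)
            kwargs = dict(hop_length=n_fft // 4, window=window, return_complex=True)
            p = torch.stft(pred, n_fft, **kwargs).abs()
            t = torch.stft(target, n_fft, **kwargs).abs()
            loss = loss + (p - t).abs().mean()  # Magnitude error at this resolution
        return loss / len(fft_sizes)

    # Usage: both signals are 1D waveforms at the same sample rate.
    x, y = torch.randn(48_000), torch.randn(48_000)
    print(mrstft_loss(x, y))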
nam/train/gui/__init__.py (16 changes: 13 additions & 3 deletions)
@@ -259,7 +259,6 @@ class _CheckboxKeys(Enum):
     Keys for checkboxes
     """

-    FIT_CAB = "fit_cab"
     SILENT_TRAINING = "silent_training"
     SAVE_PLOT = "save_plot"

@@ -484,6 +483,18 @@ def __init__(self):

         self._check_button_states()

+    def get_mrstft_fit(self) -> bool:
+        """
+        Use a pre-emphasized multi-resolution short-time Fourier transform loss during
+        training.
+        This improves agreement in the high frequencies, usually with a minimal loss in
+        ESR.
+        """
+        # Leave this as a public method to anticipate an extension to make it
+        # changeable.
+        return True
+
     def _check_button_states(self):
         """
         Determine if any buttons should be disabled
@@ -525,7 +536,6 @@ def make_checkbox(
         self._widgets[key] = check_button  # For tracking in set-all-widgets ops

         self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
-        make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
         make_checkbox(
             _CheckboxKeys.SILENT_TRAINING,
             "Silent run (suggested for batch training)",
@@ -616,7 +626,7 @@ def _train2(self, ignore_checks=False):
             modelname=basename,
             ignore_checks=ignore_checks,
             local=True,
-            fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+            fit_mrstft=self.get_mrstft_fit(),
             threshold_esr=threshold_esr,
             user_metadata=user_metadata,
         )
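
The new get_mrstft_fit() hard-codes True but is deliberately kept public so a later change can make the value user-configurable again. A hypothetical sketch of such an extension (the _GUI base-class name and the stored attribute are assumptions for illustration, not code from this commit):

    class ConfigurableGUI(_GUI):  # _GUI: assumed name of the trainer window class
        def __init__(self):
            super().__init__()
            self._fit_mrstft = True  # Could be wired to a menu item or config file.

        def get_mrstft_fit(self) -> bool:
            # Override the hard-coded True with the stored preference.
            return self._fit_mrstft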
nam/train/metadata.py (12 changes: 0 additions & 12 deletions)
@@ -13,17 +13,6 @@

 from pydantic import BaseModel

-__all__ = [
-    "Data",
-    "DataChecks",
-    "Latency",
-    "LatencyCalibration",
-    "LatencyCalibrationWarnings",
-    "Settings",
-    "TrainingMetadata",
-    "TRAINING_KEY",
-]
-
 # The key under which the metadata are saved in the .nam:
 TRAINING_KEY = "training"

@@ -33,7 +22,6 @@ class Settings(BaseModel):
     """
     User-provided settings
     """

-    fit_cab: bool
     ignore_checks: bool
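
Because Settings is a pydantic model, dropping the field also affects how training metadata written by older versions is parsed. A quick compatibility sketch, assuming pydantic's default behavior of ignoring unknown fields:

    from pydantic import BaseModel

    class Settings(BaseModel):
        ignore_checks: bool

    # Metadata saved by an older version still carries the removed field:
    old_settings = {"fit_cab": True, "ignore_checks": False}

    # With pydantic's defaults, unknown fields are ignored, so this parses
    # and the stale fit_cab entry is silently dropped.
    settings = Settings(**old_settings)
    print(settings)  # ignore_checks=False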


