From 93c2f4df85b18f179971b4cd9d70c02e602e1d1d Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Sun, 11 Feb 2024 08:12:57 -0800 Subject: [PATCH 1/2] Some cleanup --- nam/models/_base.py | 38 -------------------------------------- nam/models/base.py | 4 ++-- 2 files changed, 2 insertions(+), 40 deletions(-) diff --git a/nam/models/_base.py b/nam/models/_base.py index 5eb6e174..763353a8 100644 --- a/nam/models/_base.py +++ b/nam/models/_base.py @@ -188,41 +188,3 @@ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: d["loudness"] = self._metadata_loudness() d["gain"] = self._metadata_gain() return d - - -class ParametricBaseNet(_Base): - """ - Parametric inputs - """ - - def forward( - self, - params: torch.Tensor, - x: torch.Tensor, - pad_start: Optional[bool] = None, - **kwargs - ): - pad_start = self.pad_start_default if pad_start is None else pad_start - scalar = x.ndim == 1 - if scalar: - x = x[None] - params = params[None] - if pad_start: - x = torch.cat( - (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 - ) - y = self._forward(params, x, **kwargs) - if scalar: - y = y[0] - return y - - @abc.abstractmethod - def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - The true forward method. - - :param params: (N,D) - :param x: (N,L1) - :return: (N,L1-RF+1) - """ - pass diff --git a/nam/models/base.py b/nam/models/base.py index a448e4f6..7bb1f83d 100644 --- a/nam/models/base.py +++ b/nam/models/base.py @@ -4,8 +4,8 @@ """ Implements the base PyTorch Lightning model. -This is meant to combine an acutal model (subclassed from `._base.BaseNet` or -`._base.ParametricBaseNet`) along with loss function boilerplate. +This is meant to combine an actual model (subclassed from `._base.BaseNet`) +along with loss function boilerplate. For the base *PyTorch* model containing the actual architecture, see `._base`. """ From f2e6b4bac262f9c882d2ea625255641ab25f5b89 Mon Sep 17 00:00:00 2001 From: Steven Atkinson Date: Sun, 11 Feb 2024 08:14:20 -0800 Subject: [PATCH 2/2] Black --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a0b0c08b..d4dffe49 100644 --- a/setup.py +++ b/setup.py @@ -5,17 +5,20 @@ from distutils.util import convert_path from setuptools import setup, find_packages + def get_additional_requirements(): # Issue 294 try: import transformers + # This may not be unnecessarily straict a requirement, but I'd rather - # fix this promptly than leave a chance that it wouldn't be fixed + # fix this promptly than leave a chance that it wouldn't be fixed # properly. return ["transformers>=4"] except ModuleNotFoundError: return [] + main_ns = {} ver_path = convert_path("nam/_version.py") with open(ver_path) as ver_file: