diff --git a/nam/models/_base.py b/nam/models/_base.py index 5eb6e174..763353a8 100644 --- a/nam/models/_base.py +++ b/nam/models/_base.py @@ -188,41 +188,3 @@ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: d["loudness"] = self._metadata_loudness() d["gain"] = self._metadata_gain() return d - - -class ParametricBaseNet(_Base): - """ - Parametric inputs - """ - - def forward( - self, - params: torch.Tensor, - x: torch.Tensor, - pad_start: Optional[bool] = None, - **kwargs - ): - pad_start = self.pad_start_default if pad_start is None else pad_start - scalar = x.ndim == 1 - if scalar: - x = x[None] - params = params[None] - if pad_start: - x = torch.cat( - (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 - ) - y = self._forward(params, x, **kwargs) - if scalar: - y = y[0] - return y - - @abc.abstractmethod - def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """ - The true forward method. - - :param params: (N,D) - :param x: (N,L1) - :return: (N,L1-RF+1) - """ - pass diff --git a/nam/models/base.py b/nam/models/base.py index a448e4f6..7bb1f83d 100644 --- a/nam/models/base.py +++ b/nam/models/base.py @@ -4,8 +4,8 @@ """ Implements the base PyTorch Lightning model. -This is meant to combine an acutal model (subclassed from `._base.BaseNet` or -`._base.ParametricBaseNet`) along with loss function boilerplate. +This is meant to combine an actual model (subclassed from `._base.BaseNet`) +along with loss function boilerplate. For the base *PyTorch* model containing the actual architecture, see `._base`. """ diff --git a/setup.py b/setup.py index a0b0c08b..d4dffe49 100644 --- a/setup.py +++ b/setup.py @@ -5,17 +5,20 @@ from distutils.util import convert_path from setuptools import setup, find_packages + def get_additional_requirements(): # Issue 294 try: import transformers + # This may not be unnecessarily straict a requirement, but I'd rather - # fix this promptly than leave a chance that it wouldn't be fixed + # fix this promptly than leave a chance that it wouldn't be fixed # properly. return ["transformers>=4"] except ModuleNotFoundError: return [] + main_ns = {} ver_path = convert_path("nam/_version.py") with open(ver_path) as ver_file: