Skip to content

Commit

Permalink
Clean up some code (#381)
Browse files Browse the repository at this point in the history
* Some cleanup

* Black
  • Loading branch information
sdatkinson authored Feb 11, 2024
1 parent 58c6f74 commit a5ad4b7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 41 deletions.
38 changes: 0 additions & 38 deletions nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,41 +188,3 @@ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
d["loudness"] = self._metadata_loudness()
d["gain"] = self._metadata_gain()
return d


class ParametricBaseNet(_Base):
"""
Parametric inputs
"""

def forward(
self,
params: torch.Tensor,
x: torch.Tensor,
pad_start: Optional[bool] = None,
**kwargs
):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
if scalar:
x = x[None]
params = params[None]
if pad_start:
x = torch.cat(
(torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
)
y = self._forward(params, x, **kwargs)
if scalar:
y = y[0]
return y

@abc.abstractmethod
def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
The true forward method.
:param params: (N,D)
:param x: (N,L1)
:return: (N,L1-RF+1)
"""
pass
4 changes: 2 additions & 2 deletions nam/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

"""
Implements the base PyTorch Lightning model.
This is meant to combine an acutal model (subclassed from `._base.BaseNet` or
`._base.ParametricBaseNet`) along with loss function boilerplate.
This is meant to combine an actual model (subclassed from `._base.BaseNet`)
along with loss function boilerplate.
For the base *PyTorch* model containing the actual architecture, see `._base`.
"""
Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
from distutils.util import convert_path
from setuptools import setup, find_packages


def get_additional_requirements():
# Issue 294
try:
import transformers

# This may not be unnecessarily straict a requirement, but I'd rather
# fix this promptly than leave a chance that it wouldn't be fixed
# fix this promptly than leave a chance that it wouldn't be fixed
# properly.
return ["transformers>=4"]
except ModuleNotFoundError:
return []


main_ns = {}
ver_path = convert_path("nam/_version.py")
with open(ver_path) as ver_file:
Expand Down

0 comments on commit a5ad4b7

Please sign in to comment.