Skip to content

Commit

Permalink
feat: Allow setting and deleting parameters within container
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed Oct 17, 2024
1 parent e71e541 commit 19f354d
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 15 deletions.
58 changes: 43 additions & 15 deletions src/_griffe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(
"""Initialize the parameter.
Parameters:
name: The parameter name.
name: The parameter name, without leading stars (`*` or `**`).
annotation: The parameter annotation, if any.
kind: The parameter kind.
default: The parameter default, if any.
Expand Down Expand Up @@ -266,31 +266,61 @@ def __init__(self, *parameters: Parameter) -> None:
Parameters:
*parameters: The initial parameters to add to the container.
"""
self._parameters_list: list[Parameter] = []
self._parameters_dict: dict[str, Parameter] = {}
for parameter in parameters:
self.add(parameter)
self._params: list[Parameter] = list(parameters)

def __repr__(self) -> str:
return f"Parameters({', '.join(repr(param) for param in self._parameters_list)})"
return f"Parameters({', '.join(repr(param) for param in self._params)})"

def __getitem__(self, name_or_index: int | str) -> Parameter:
"""Get a parameter by index or name."""
if isinstance(name_or_index, int):
return self._parameters_list[name_or_index]
return self._parameters_dict[name_or_index.lstrip("*")]
return self._params[name_or_index]
name = name_or_index.lstrip("*")
try:
return next(param for param in self._params if param.name == name)
except StopIteration as error:
raise KeyError(f"parameter {name_or_index} not found") from error

def __setitem__(self, name_or_index: int | str, parameter: Parameter) -> None:
"""Set a parameter by index or name."""
if isinstance(name_or_index, int):
self._params[name_or_index] = parameter
else:
name = name_or_index.lstrip("*")
try:
index = next(idx for idx, param in enumerate(self._params) if param.name == name)
except StopIteration:
self._params.append(parameter)
else:
self._params[index] = parameter

def __delitem__(self, name_or_index: int | str) -> None:
"""Delete a parameter by index or name."""
if isinstance(name_or_index, int):
del self._params[name_or_index]
else:
name = name_or_index.lstrip("*")
try:
index = next(idx for idx, param in enumerate(self._params) if param.name == name)
except StopIteration as error:
raise KeyError(f"parameter {name_or_index} not found") from error
del self._params[index]

def __len__(self):
"""The number of parameters."""
return len(self._parameters_list)
return len(self._params)

def __iter__(self):
"""Iterate over the parameters, in order."""
return iter(self._parameters_list)
return iter(self._params)

def __contains__(self, param_name: str):
"""Whether a parameter with the given name is present."""
return param_name.lstrip("*") in self._parameters_dict
try:
next(param for param in self._params if param.name == param_name.lstrip("*"))
except StopIteration:
return False
return True

def add(self, parameter: Parameter) -> None:
"""Add a parameter to the container.
Expand All @@ -301,11 +331,9 @@ def add(self, parameter: Parameter) -> None:
Raises:
ValueError: When a parameter with the same name is already present.
"""
if parameter.name not in self._parameters_dict:
self._parameters_dict[parameter.name] = parameter
self._parameters_list.append(parameter)
else:
if parameter.name in self:
raise ValueError(f"parameter {parameter.name} already present")
self._params.append(parameter)


class Object(ObjectAliasMixin):
Expand Down
33 changes: 33 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
GriffeLoader,
Module,
NameResolutionError,
Parameter,
Parameters,
module_vtree,
temporary_inspected_module,
temporary_pypackage,
Expand Down Expand Up @@ -493,3 +495,34 @@ def method(self):
assert module["Class.method"].resolve("imported") == "imported"
assert module["Class.method"].resolve("class_attribute") == "module.Class.class_attribute"
assert module["Class.method"].resolve("instance_attribute") == "module.Class.instance_attribute"


def test_set_parameters() -> None:
"""We can set parameters."""
parameters = Parameters()
# Does not exist yet.
parameters["x"] = Parameter(name="x")
assert "x" in parameters
# Already exists, by name.
parameters["x"] = Parameter(name="x")
assert "x" in parameters
assert len(parameters) == 1
# Already exists, by index.
parameters[0] = Parameter(name="y")
assert "y" in parameters
assert len(parameters) == 1


def test_delete_parameters() -> None:
"""We can delete parameters."""
parameters = Parameters()
# By name.
parameters["x"] = Parameter(name="x")
del parameters["x"]
assert "x" not in parameters
assert len(parameters) == 0
# By index.
parameters["x"] = Parameter(name="x")
del parameters[0]
assert "x" not in parameters
assert len(parameters) == 0

0 comments on commit 19f354d

Please sign in to comment.