From 19f354da6a331a12d80a61bd3005cdcc30a3c42c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Mazzucotelli?= Date: Fri, 18 Oct 2024 00:42:45 +0200 Subject: [PATCH] feat: Allow setting and deleting parameters within container --- src/_griffe/models.py | 58 ++++++++++++++++++++++++++++++++----------- tests/test_models.py | 33 ++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 15 deletions(-) diff --git a/src/_griffe/models.py b/src/_griffe/models.py index 83794214..29bb6f07 100644 --- a/src/_griffe/models.py +++ b/src/_griffe/models.py @@ -182,7 +182,7 @@ def __init__( """Initialize the parameter. Parameters: - name: The parameter name. + name: The parameter name, without leading stars (`*` or `**`). annotation: The parameter annotation, if any. kind: The parameter kind. default: The parameter default, if any. @@ -266,31 +266,61 @@ def __init__(self, *parameters: Parameter) -> None: Parameters: *parameters: The initial parameters to add to the container. """ - self._parameters_list: list[Parameter] = [] - self._parameters_dict: dict[str, Parameter] = {} - for parameter in parameters: - self.add(parameter) + self._params: list[Parameter] = list(parameters) def __repr__(self) -> str: - return f"Parameters({', '.join(repr(param) for param in self._parameters_list)})" + return f"Parameters({', '.join(repr(param) for param in self._params)})" def __getitem__(self, name_or_index: int | str) -> Parameter: """Get a parameter by index or name.""" if isinstance(name_or_index, int): - return self._parameters_list[name_or_index] - return self._parameters_dict[name_or_index.lstrip("*")] + return self._params[name_or_index] + name = name_or_index.lstrip("*") + try: + return next(param for param in self._params if param.name == name) + except StopIteration as error: + raise KeyError(f"parameter {name_or_index} not found") from error + + def __setitem__(self, name_or_index: int | str, parameter: Parameter) -> None: + """Set a parameter by index or name.""" + if isinstance(name_or_index, int): + self._params[name_or_index] = parameter + else: + name = name_or_index.lstrip("*") + try: + index = next(idx for idx, param in enumerate(self._params) if param.name == name) + except StopIteration: + self._params.append(parameter) + else: + self._params[index] = parameter + + def __delitem__(self, name_or_index: int | str) -> None: + """Delete a parameter by index or name.""" + if isinstance(name_or_index, int): + del self._params[name_or_index] + else: + name = name_or_index.lstrip("*") + try: + index = next(idx for idx, param in enumerate(self._params) if param.name == name) + except StopIteration as error: + raise KeyError(f"parameter {name_or_index} not found") from error + del self._params[index] def __len__(self): """The number of parameters.""" - return len(self._parameters_list) + return len(self._params) def __iter__(self): """Iterate over the parameters, in order.""" - return iter(self._parameters_list) + return iter(self._params) def __contains__(self, param_name: str): """Whether a parameter with the given name is present.""" - return param_name.lstrip("*") in self._parameters_dict + try: + next(param for param in self._params if param.name == param_name.lstrip("*")) + except StopIteration: + return False + return True def add(self, parameter: Parameter) -> None: """Add a parameter to the container. @@ -301,11 +331,9 @@ def add(self, parameter: Parameter) -> None: Raises: ValueError: When a parameter with the same name is already present. """ - if parameter.name not in self._parameters_dict: - self._parameters_dict[parameter.name] = parameter - self._parameters_list.append(parameter) - else: + if parameter.name in self: raise ValueError(f"parameter {parameter.name} already present") + self._params.append(parameter) class Object(ObjectAliasMixin): diff --git a/tests/test_models.py b/tests/test_models.py index f40c567f..e86eb260 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,8 @@ GriffeLoader, Module, NameResolutionError, + Parameter, + Parameters, module_vtree, temporary_inspected_module, temporary_pypackage, @@ -493,3 +495,34 @@ def method(self): assert module["Class.method"].resolve("imported") == "imported" assert module["Class.method"].resolve("class_attribute") == "module.Class.class_attribute" assert module["Class.method"].resolve("instance_attribute") == "module.Class.instance_attribute" + + +def test_set_parameters() -> None: + """We can set parameters.""" + parameters = Parameters() + # Does not exist yet. + parameters["x"] = Parameter(name="x") + assert "x" in parameters + # Already exists, by name. + parameters["x"] = Parameter(name="x") + assert "x" in parameters + assert len(parameters) == 1 + # Already exists, by index. + parameters[0] = Parameter(name="y") + assert "y" in parameters + assert len(parameters) == 1 + + +def test_delete_parameters() -> None: + """We can delete parameters.""" + parameters = Parameters() + # By name. + parameters["x"] = Parameter(name="x") + del parameters["x"] + assert "x" not in parameters + assert len(parameters) == 0 + # By index. + parameters["x"] = Parameter(name="x") + del parameters[0] + assert "x" not in parameters + assert len(parameters) == 0