diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..81e6a94 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable diff --git a/.gitignore b/.gitignore index a5aee56..df4d01a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # environment -venv/ \ No newline at end of file +venv/ +.venv \ No newline at end of file diff --git a/pysquirrel/__init__.py b/pysquirrel/__init__.py index ea8cc0a..e87b28c 100644 --- a/pysquirrel/__init__.py +++ b/pysquirrel/__init__.py @@ -3,4 +3,4 @@ from . import core # create database -nuts = core.AllRegions() \ No newline at end of file +nuts = core.AllRegions() diff --git a/pysquirrel/core.py b/pysquirrel/core.py index 45f6dc5..8125fd7 100644 --- a/pysquirrel/core.py +++ b/pysquirrel/core.py @@ -15,6 +15,7 @@ MIN_DATA_ROW = 2 MAX_DATA_COL = 4 + # utility function def flatten(l): for i in l: @@ -23,14 +24,17 @@ def flatten(l): else: yield i + class Level(IntEnum): LEVEL_1 = 1 LEVEL_2 = 2 LEVEL_3 = 3 + @dataclass(frozen=True) class Region: """Territorial region base class.""" + country_code: str code: str label: str @@ -55,7 +59,7 @@ def check_country_code(cls, v: str): @classmethod def check_code(cls, v: str): """ - Checks if region code follows standard format of a two capital letters + Checks if region code follows standard format of a two capital letters country code followed by an alphanumeric code, between one to three elements. Placeholder region are marked with 'Z' in place of digits. """ @@ -63,7 +67,7 @@ def check_code(cls, v: str): return v else: raise ValueError() - + @field_validator("parent_code") @classmethod def check_parent_code(cls, v: str, info: ValidationInfo): @@ -75,22 +79,29 @@ def check_parent_code(cls, v: str, info: ValidationInfo): return v else: raise ValueError() - + @model_validator(mode="after") def check_code_consistency(self): """Checks if code, country code, and level are all in conformity.""" - if self.code.startswith(self.country_code) and \ - len(self.code) == len(self.country_code) + self.level: + if ( + self.code.startswith(self.country_code) + and len(self.code) == len(self.country_code) + self.level + ): return self + class NUTSRegion(Region): """NUTS-specific implementation of the Region base class.""" + pass + class SRRegion(Region): """SR-specific implementation of the Region base class.""" + pass + class AllRegions: """Database that contains list of all territorial region.""" @@ -105,8 +116,12 @@ def _load(self) -> None: """ Reads data from NUTS spreadsheet into Database and builds search index. """ - nuts2024_hash = "3df559906175180d58a2a283985fb632b799b4cbe034e92515295064a9f2c01e" - pooch.retrieve(FILE_URL, known_hash=nuts2024_hash, fname=FILENAME, path=BASE_PATH) + nuts2024_hash = ( + "3df559906175180d58a2a283985fb632b799b4cbe034e92515295064a9f2c01e" + ) + pooch.retrieve( + FILE_URL, known_hash=nuts2024_hash, fname=FILENAME, path=BASE_PATH + ) spreadsheet = openpyxl.load_workbook( BASE_PATH / FILENAME, read_only=True, data_only=True ) @@ -115,8 +130,11 @@ def _load(self) -> None: for sheet_name, cls in sheet_class.items(): sheet = spreadsheet[sheet_name] for row in sheet.iter_rows(min_row=MIN_DATA_ROW, max_col=MAX_DATA_COL): - if all(cell.value for cell in row): - region = {field.name: cell.value for (field, cell) in zip(fields(cls), row)} + if all(cell.value for cell in row): + region = { + field.name: cell.value + for (field, cell) in zip(fields(cls), row) + } self.data.append(cls(**region)) def _set_index(self) -> None: @@ -136,25 +154,31 @@ def _set_index(self) -> None: key[value] = [region] def _search( - self, param: str, value: str | int, - ) -> list[NUTSRegion | SRRegion]: + self, + param: str, + value: str | int, + ) -> list[NUTSRegion | SRRegion]: """ Searches database index for one value of a parameter and returns a set of all matching result(s). :param param: field to be searched - :param value: value(s) to be searched in the field - + :param value: value(s) to be searched in the field + """ - results = set(flatten([self.search_index[param][key] - for key in self.search_index[param] - if key == value])) - + results = set( + flatten( + [ + self.search_index[param][key] + for key in self.search_index[param] + if key == value + ] + ) + ) + return results - - def get( - self, **params - ) -> list[NUTSRegion | SRRegion, None]: + + def get(self, **params) -> list[NUTSRegion | SRRegion, None]: """ Searches NUTS 2024 classification database. Supports multiple fields/values search. @@ -172,4 +196,4 @@ def get( results.append(self._search(param, value)) else: raise TypeError("only one value per keyword argument allowed.") - return list(set.intersection(*results)) \ No newline at end of file + return list(set.intersection(*results)) diff --git a/setup.py b/setup.py index 692edef..485843b 100644 --- a/setup.py +++ b/setup.py @@ -8,4 +8,4 @@ author="IIASA", url="https://github.com/iiasa/pysquirrel", license="MIT License", - ) \ No newline at end of file +)