Skip to content

Commit

Permalink
Black code formatting and GH action (#7, #8)
Browse files Browse the repository at this point in the history
  • Loading branch information
dc-almeida committed Aug 16, 2024
1 parent 7afac3b commit 007cf39
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 25 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/black.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
name: Lint

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: psf/black@stable
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# environment
venv/
venv/
.venv
2 changes: 1 addition & 1 deletion pysquirrel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from . import core

# create database
nuts = core.AllRegions()
nuts = core.AllRegions()
68 changes: 46 additions & 22 deletions pysquirrel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MIN_DATA_ROW = 2
MAX_DATA_COL = 4


# utility function
def flatten(l):
for i in l:
Expand All @@ -23,14 +24,17 @@ def flatten(l):
else:
yield i


class Level(IntEnum):
LEVEL_1 = 1
LEVEL_2 = 2
LEVEL_3 = 3


@dataclass(frozen=True)
class Region:
"""Territorial region base class."""

country_code: str
code: str
label: str
Expand All @@ -55,15 +59,15 @@ def check_country_code(cls, v: str):
@classmethod
def check_code(cls, v: str):
"""
Checks if region code follows standard format of a two capital letters
Checks if region code follows standard format of a two capital letters
country code followed by an alphanumeric code, between one to three elements.
Placeholder region are marked with 'Z' in place of digits.
"""
if v[:2].isalpha() and v[:2].isupper() and v[2:].isalnum():
return v
else:
raise ValueError()

@field_validator("parent_code")
@classmethod
def check_parent_code(cls, v: str, info: ValidationInfo):
Expand All @@ -75,22 +79,29 @@ def check_parent_code(cls, v: str, info: ValidationInfo):
return v
else:
raise ValueError()

@model_validator(mode="after")
def check_code_consistency(self):
"""Checks if code, country code, and level are all in conformity."""
if self.code.startswith(self.country_code) and \
len(self.code) == len(self.country_code) + self.level:
if (
self.code.startswith(self.country_code)
and len(self.code) == len(self.country_code) + self.level
):
return self


class NUTSRegion(Region):
"""NUTS-specific implementation of the Region base class."""

pass


class SRRegion(Region):
"""SR-specific implementation of the Region base class."""

pass


class AllRegions:
"""Database that contains list of all territorial region."""

Expand All @@ -105,8 +116,12 @@ def _load(self) -> None:
"""
Reads data from NUTS spreadsheet into Database and builds search index.
"""
nuts2024_hash = "3df559906175180d58a2a283985fb632b799b4cbe034e92515295064a9f2c01e"
pooch.retrieve(FILE_URL, known_hash=nuts2024_hash, fname=FILENAME, path=BASE_PATH)
nuts2024_hash = (
"3df559906175180d58a2a283985fb632b799b4cbe034e92515295064a9f2c01e"
)
pooch.retrieve(
FILE_URL, known_hash=nuts2024_hash, fname=FILENAME, path=BASE_PATH
)
spreadsheet = openpyxl.load_workbook(
BASE_PATH / FILENAME, read_only=True, data_only=True
)
Expand All @@ -115,8 +130,11 @@ def _load(self) -> None:
for sheet_name, cls in sheet_class.items():
sheet = spreadsheet[sheet_name]
for row in sheet.iter_rows(min_row=MIN_DATA_ROW, max_col=MAX_DATA_COL):
if all(cell.value for cell in row):
region = {field.name: cell.value for (field, cell) in zip(fields(cls), row)}
if all(cell.value for cell in row):
region = {
field.name: cell.value
for (field, cell) in zip(fields(cls), row)
}
self.data.append(cls(**region))

def _set_index(self) -> None:
Expand All @@ -136,25 +154,31 @@ def _set_index(self) -> None:
key[value] = [region]

def _search(
self, param: str, value: str | int,
) -> list[NUTSRegion | SRRegion]:
self,
param: str,
value: str | int,
) -> list[NUTSRegion | SRRegion]:
"""
Searches database index for one value of a parameter
and returns a set of all matching result(s).
:param param: field to be searched
:param value: value(s) to be searched in the field
:param value: value(s) to be searched in the field
"""
results = set(flatten([self.search_index[param][key]
for key in self.search_index[param]
if key == value]))

results = set(
flatten(
[
self.search_index[param][key]
for key in self.search_index[param]
if key == value
]
)
)

return results

def get(
self, **params
) -> list[NUTSRegion | SRRegion, None]:

def get(self, **params) -> list[NUTSRegion | SRRegion, None]:
"""
Searches NUTS 2024 classification database. Supports multiple fields/values
search.
Expand All @@ -172,4 +196,4 @@ def get(
results.append(self._search(param, value))
else:
raise TypeError("only one value per keyword argument allowed.")
return list(set.intersection(*results))
return list(set.intersection(*results))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
author="IIASA",
url="https://github.com/iiasa/pysquirrel",
license="MIT License",
)
)

0 comments on commit 007cf39

Please sign in to comment.