Skip to content

Commit

Permalink
Merge pull request #20 from kiwix/backend-scheduler
Browse files Browse the repository at this point in the history
set up scheduler to create tasks for idle workers
  • Loading branch information
elfkuzco authored Jun 21, 2024
2 parents 460ac2d + 9a8cea8 commit 9074157
Show file tree
Hide file tree
Showing 31 changed files with 618 additions and 168 deletions.
7 changes: 5 additions & 2 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pycountry==24.6.1",
"cryptography==42.0.8",
"PyJWT==2.8.0",
"paramiko==3.4.0",
]
license = {text = "GPL-3.0-or-later"}
classifiers = [
Expand All @@ -37,6 +38,7 @@ Homepage = "https://github.com/kiwix/mirrors-qa"

[project.scripts]
update-mirrors = "mirrors_qa_backend.entrypoint:main"
mirrors-qa-scheduler = "mirrors_qa_backend.scheduler:main"

[project.optional-dependencies]
scripts = [
Expand All @@ -53,7 +55,6 @@ test = [
"pytest==8.0.0",
"coverage==7.4.1",
"Faker==25.8.0",
"paramiko==3.4.0",
"httpx==0.27.0",
]
dev = [
Expand Down Expand Up @@ -189,6 +190,8 @@ ignore = [
"S603",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Ignore warnings on missing timezone info
"DTZ005", "DTZ001", "DTZ006",
]
unfixable = [
# Don't touch unused imports
Expand All @@ -215,7 +218,7 @@ testpaths = ["tests"]
pythonpath = [".", "src"]
addopts = "--strict-markers"
markers = [
"num_tests: number of tests to create in the database (default: 10)",
"num_tests(num=10, *, status=..., country_code=...): create num tests in the database using status and/or country_code. Random data is chosen for country_code or status if either is not set",
]

[tool.coverage.paths]
Expand Down
40 changes: 26 additions & 14 deletions backend/src/mirrors_qa_backend/cryptography.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# pyright: strict, reportGeneralTypeIssues=false
import datetime
from pathlib import Path

import jwt
import paramiko
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey

from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError
from mirrors_qa_backend.settings import Settings


def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) -> bool:
Expand Down Expand Up @@ -44,13 +43,26 @@ def sign_message(private_key: RSAPrivateKey, message: bytes) -> bytes:
)


def generate_access_token(worker_id: str) -> str:
issue_time = datetime.datetime.now(datetime.UTC)
expire_time = issue_time + datetime.timedelta(hours=Settings.TOKEN_EXPIRY)
payload = {
"iss": "mirrors-qa-backend", # issuer
"exp": expire_time.timestamp(), # expiration time
"iat": issue_time.timestamp(), # issued at
"subject": worker_id,
}
return jwt.encode(payload, key=Settings.JWT_SECRET, algorithm="HS256")
def load_private_key_from_path(private_key_fpath: Path) -> RSAPrivateKey:
with private_key_fpath.open("rb") as key_file:
return serialization.load_pem_private_key(
key_file.read(), password=None
) # pyright: ignore[reportReturnType]


def generate_public_key(private_key: RSAPrivateKey) -> RSAPublicKey:
return private_key.public_key()


def serialize_public_key(public_key: RSAPublicKey) -> bytes:
return public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)


def get_public_key_fingerprint(public_key: RSAPublicKey) -> str:
"""Compute the SHA256 fingerprint of the public key"""
return paramiko.RSAKey(
key=public_key
).fingerprint # pyright: ignore[reportUnknownMemberType, UnknownVariableType]
11 changes: 6 additions & 5 deletions backend/src/mirrors_qa_backend/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from sqlalchemy.orm import sessionmaker

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import mirrors, models
from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status
from mirrors_qa_backend.extract import get_current_mirrors
from mirrors_qa_backend.settings import Settings

Expand Down Expand Up @@ -44,16 +45,16 @@ def initialize_mirrors() -> None:
if nb_mirrors == 0:
logger.info("No mirrors exist in database.")
if not current_mirrors:
logger.info(f"No mirrors were found on {Settings.MIRRORS_URL!r}")
logger.info(f"No mirrors were found on {Settings.MIRRORS_URL}")
return
result = mirrors.create_or_update_status(session, current_mirrors)
result = create_or_update_mirror_status(session, current_mirrors)
logger.info(
f"Registered {result.nb_mirrors_added} mirrors "
f"from {Settings.MIRRORS_URL!r}"
f"from {Settings.MIRRORS_URL}"
)
else:
logger.info(f"Found {nb_mirrors} mirrors in database.")
result = mirrors.create_or_update_status(session, current_mirrors)
result = create_or_update_mirror_status(session, current_mirrors)
logger.info(
f"Added {result.nb_mirrors_added} mirrors. "
f"Disabled {result.nb_mirrors_disabled} mirrors."
Expand Down
16 changes: 16 additions & 0 deletions backend/src/mirrors_qa_backend/db/country.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db.models import Country


def get_countries(session: OrmSession, *country_codes: str) -> list[Country]:
return list(
session.scalars(select(Country).where(Country.code.in_(country_codes))).all()
)


def get_country_or_none(session: OrmSession, country_code: str) -> Country | None:
return session.scalars(
select(Country).where(Country.code == country_code)
).one_or_none()
4 changes: 4 additions & 0 deletions backend/src/mirrors_qa_backend/db/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ def __init__(self, message: str, *args: object) -> None:

class EmptyMirrorsError(Exception):
"""An empty list was used to update the mirrors in the database."""


class DuplicatePrimaryKeyError(Exception):
"""A database record with the same primary key exists."""
12 changes: 6 additions & 6 deletions backend/src/mirrors_qa_backend/db/mirrors.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
db_mirror.country = country
session.add(db_mirror)
logger.debug(
f"Registered new mirror: {db_mirror.id!r} for country: {country.name!r}"
f"Registered new mirror: {db_mirror.id} for country: {country.name}"
)
nb_created += 1
return nb_created


def create_or_update_status(
def create_or_update_mirror_status(
session: OrmSession, mirrors: list[schemas.Mirror]
) -> MirrorsUpdateResult:
"""Updates the status of mirrors in the database and creates any new mirrors.
Expand Down Expand Up @@ -96,16 +96,16 @@ def create_or_update_status(
for db_mirror_id, db_mirror in db_mirrors.items():
if db_mirror_id not in current_mirrors:
logger.debug(
f"Disabling mirror: {db_mirror.id!r} for "
f"country: {db_mirror.country.name!r}"
f"Disabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
db_mirror.enabled = False
session.add(db_mirror)
result.nb_mirrors_disabled += 1
elif not db_mirror.enabled: # re-enable mirror if it was disabled
logger.debug(
f"Re-enabling mirror: {db_mirror.id!r} for "
f"country: {db_mirror.country.name!r}"
f"Re-enabling mirror: {db_mirror.id} for "
f"country: {db_mirror.country.name}"
)
db_mirror.enabled = True
session.add(db_mirror)
Expand Down
10 changes: 9 additions & 1 deletion backend/src/mirrors_qa_backend/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class Country(Base):
cascade="all, delete-orphan",
)

tests: Mapped[list[Test]] = relationship(back_populates="country", init=False)

__table_args__ = (UniqueConstraint("name", "code"),)


Expand Down Expand Up @@ -131,7 +133,11 @@ class Test(Base):
ip_address: Mapped[IPv4Address | None] = mapped_column(default=None)
# autonomous system based on IP
asn: Mapped[str | None] = mapped_column(default=None)
country: Mapped[str | None] = mapped_column(default=None) # country based on IP
country_code: Mapped[str | None] = mapped_column(
ForeignKey("country.code"),
init=False,
default=None,
)
location: Mapped[str | None] = mapped_column(default=None) # city based on IP
latency: Mapped[int | None] = mapped_column(default=None) # milliseconds
download_size: Mapped[int | None] = mapped_column(default=None) # bytes
Expand All @@ -142,3 +148,5 @@ class Test(Base):
)

worker: Mapped[Worker | None] = relationship(back_populates="tests", init=False)

country: Mapped[Country | None] = relationship(back_populates="tests", init=False)
71 changes: 64 additions & 7 deletions backend/src/mirrors_qa_backend/db/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from ipaddress import IPv4Address
from uuid import UUID

from sqlalchemy import UnaryExpression, asc, desc, func, select
from sqlalchemy import UnaryExpression, asc, desc, func, select, update
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.country import get_country_or_none
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum
from mirrors_qa_backend.settings import Settings
Expand All @@ -24,7 +25,7 @@ def filter_test(
test: models.Test,
*,
worker_id: str | None = None,
country: str | None = None,
country_code: str | None = None,
statuses: list[StatusEnum] | None = None,
) -> bool:
"""Checks if a test has the same attribute as the provided attribute.
Expand All @@ -34,7 +35,7 @@ def filter_test(
"""
if worker_id is not None and test.worker_id != worker_id:
return False
if country is not None and test.country != country:
if country_code is not None and test.country_code != country_code:
return False
if statuses is not None and test.status not in statuses:
return False
Expand All @@ -51,7 +52,7 @@ def list_tests(
session: OrmSession,
*,
worker_id: str | None = None,
country: str | None = None,
country_code: str | None = None,
statuses: list[StatusEnum] | None = None,
page_num: int = 1,
page_size: int = Settings.MAX_PAGE_SIZE,
Expand Down Expand Up @@ -87,7 +88,7 @@ def list_tests(
select(func.count().over().label("total_records"), models.Test)
.where(
(models.Test.worker_id == worker_id) | (worker_id is None),
(models.Test.country == country) | (country is None),
(models.Test.country_code == country_code) | (country_code is None),
(models.Test.status.in_(statuses)),
)
.order_by(*order_by)
Expand All @@ -113,7 +114,7 @@ def create_or_update_test(
error: str | None = None,
ip_address: IPv4Address | None = None,
asn: str | None = None,
country: str | None = None,
country_code: str | None = None,
location: str | None = None,
latency: int | None = None,
download_size: int | None = None,
Expand All @@ -135,7 +136,9 @@ def create_or_update_test(
test.error = error if error else test.error
test.ip_address = ip_address if ip_address else test.ip_address
test.asn = asn if asn else test.asn
test.country = country if country else test.country
test.country = (
get_country_or_none(session, country_code) if country_code else test.country
)
test.location = location if location else test.location
test.latency = latency if latency else test.latency
test.download_size = download_size if download_size else test.download_size
Expand All @@ -144,5 +147,59 @@ def create_or_update_test(
test.started_on = started_on if started_on else test.started_on

session.add(test)
session.flush()

return test


def create_test(
session: OrmSession,
*,
worker_id: str | None = None,
status: StatusEnum = StatusEnum.PENDING,
error: str | None = None,
ip_address: IPv4Address | None = None,
asn: str | None = None,
country_code: str | None = None,
location: str | None = None,
latency: int | None = None,
download_size: int | None = None,
duration: int | None = None,
speed: float | None = None,
started_on: datetime.datetime | None = None,
) -> models.Test:
return create_or_update_test(
session,
test_id=None,
worker_id=worker_id,
status=status,
error=error,
ip_address=ip_address,
asn=asn,
country_code=country_code,
location=location,
latency=latency,
download_size=download_size,
duration=duration,
speed=speed,
started_on=started_on,
)


def expire_tests(
session: OrmSession, interval: datetime.timedelta
) -> list[models.Test]:
"""Change the status of PENDING tests created before the interval to MISSED"""
end = datetime.datetime.now() - interval
begin = datetime.datetime.fromtimestamp(0)
return list(
session.scalars(
update(models.Test)
.where(
models.Test.requested_on.between(begin, end),
models.Test.status == StatusEnum.PENDING,
)
.values(status=StatusEnum.MISSED)
.returning(models.Test)
).all()
)
Loading

0 comments on commit 9074157

Please sign in to comment.