Skip to content

Commit

Permalink
add subcommand for updating mirror detail (#43)
Browse files Browse the repository at this point in the history
* add subcommand for updating mirror detail

* add type annotations for set
  • Loading branch information
elfkuzco authored Aug 23, 2024
1 parent 656e88d commit 0c088e1
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 17 deletions.
35 changes: 34 additions & 1 deletion backend/src/mirrors_qa_backend/cli/mirrors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
from mirrors_qa_backend import logger
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status
from mirrors_qa_backend.db.mirrors import (
create_or_update_mirror_status,
get_mirror,
update_mirror_countries_from_regions,
)
from mirrors_qa_backend.db.mirrors import (
update_mirror_country as update_db_mirror_country,
)
from mirrors_qa_backend.db.mirrors import (
update_mirror_region as update_db_mirror_region,
)
from mirrors_qa_backend.extract import get_current_mirrors


Expand All @@ -13,3 +23,26 @@ def update_mirrors() -> None:
f"Updated mirrors list. Added {results.nb_mirrors_added} mirror(s), "
f"disabled {results.nb_mirrors_disabled} mirror(s)"
)


def update_mirror_other_countries(mirror_id: str, region_codes: set[str]) -> None:
with Session.begin() as session:
mirror = update_mirror_countries_from_regions(
session, get_mirror(session, mirror_id), region_codes
)

logger.info(
f"Set {len(mirror.other_countries)} countries " # pyright: ignore[reportGeneralTypeIssues,reportArgumentType]
f"for mirror {mirror.id}"
)


def update_mirror_region(mirror_id: str, region_code: str) -> None:
"""Update the region the mirror server is located."""
with Session.begin() as session:
update_db_mirror_region(session, get_mirror(session, mirror_id), region_code)


def update_mirror_country(mirror_id: str, country_code: str) -> None:
with Session.begin() as session:
update_db_mirror_country(session, get_mirror(session, mirror_id), country_code)
80 changes: 74 additions & 6 deletions backend/src/mirrors_qa_backend/db/mirrors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from dataclasses import dataclass
from itertools import chain

from sqlalchemy import select
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend import logger, schemas
from mirrors_qa_backend.db.country import get_country_or_none
from mirrors_qa_backend.db.country import get_country, get_country_or_none
from mirrors_qa_backend.db.exceptions import EmptyMirrorsError, RecordDoesNotExistError
from mirrors_qa_backend.db.models import Mirror
from mirrors_qa_backend.db.region import get_region_or_none
from mirrors_qa_backend.db.region import (
get_countries_for,
get_region,
get_region_or_none,
)


@dataclass
Expand All @@ -18,10 +23,14 @@ class MirrorsUpdateResult:
nb_mirrors_disabled: int = 0


def update_mirror_country(
def _update_mirror_country_and_region(
session: OrmSession, country_code: str, mirror: Mirror
) -> Mirror:
logger.debug("Updating mirror country information.")
"""Update the mirror country and region using the country code if they exist.
Used during mirror list update to set region and country as these fields
were missing in old DB schema.
"""
mirror.country = get_country_or_none(session, country_code)
if mirror.country and mirror.country.region_code:
mirror.region = get_region_or_none(session, mirror.country.region_code)
Expand Down Expand Up @@ -53,7 +62,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int:
session.add(db_mirror)

if mirror.country_code:
update_mirror_country(session, mirror.country_code, db_mirror)
_update_mirror_country_and_region(session, mirror.country_code, db_mirror)

logger.debug(f"Registered new mirror: {db_mirror.id}.")
nb_created += 1
Expand Down Expand Up @@ -113,7 +122,7 @@ def create_or_update_mirror_status(
if db_mirror_id in current_mirrors:
country_code = current_mirrors[db_mirror_id].country_code
if country_code:
update_mirror_country(session, country_code, db_mirror)
_update_mirror_country_and_region(session, country_code, db_mirror)
return result


Expand All @@ -132,3 +141,62 @@ def get_enabled_mirrors(session: OrmSession) -> list[Mirror]:
select(Mirror).where(Mirror.enabled == True) # noqa: E712
).all()
)


def update_mirror_countries_from_regions(
session: OrmSession, mirror: Mirror, region_codes: set[str]
) -> Mirror:
"""Update the list of other countries for a mirror with countries from region codes.
Sets the list of other countries to empty if no regions are provided.
Because otherCountries overrides the mirror region and country choice as per
MirroBrain configuration, the list of countries in this mirror's region is
added to the list.
"""
if not region_codes:
mirror.other_countries = []
session.add(mirror)
return mirror

if mirror.region_code:
region_codes.add(mirror.region_code)

country_codes = [
country.code
for country in chain.from_iterable(
get_countries_for(session, region_code) for region_code in region_codes
)
]
if not country_codes:
raise ValueError("No countries found in provided regions.")

mirror.other_countries = country_codes
session.add(mirror)
return mirror


def update_mirror_region(
session: OrmSession, mirror: Mirror, region_code: str
) -> Mirror:
"""Updates the region the mirror server is located.
Assumes the region exists in the DB.
This does not update the country the mirror is located in.
"""
mirror.region = get_region(session, region_code)
session.add(mirror)
return mirror


def update_mirror_country(
session: OrmSession, mirror: Mirror, country_code: str
) -> Mirror:
"""Update the country the mirror server is located.
Assumes the country exists in the DB.
This does not update the region the mirror is located in.
"""
mirror.country = get_country(session, country_code)
session.add(mirror)
return mirror
70 changes: 69 additions & 1 deletion backend/src/mirrors_qa_backend/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,18 @@
create_regions_and_countries,
extract_country_regions_from_csv,
)
from mirrors_qa_backend.cli.mirrors import update_mirrors
from mirrors_qa_backend.cli.mirrors import (
update_mirror_country,
update_mirror_other_countries,
update_mirror_region,
update_mirrors,
)
from mirrors_qa_backend.cli.scheduler import main as start_scheduler
from mirrors_qa_backend.cli.worker import create_worker, update_worker
from mirrors_qa_backend.settings.scheduler import SchedulerSettings

UPDATE_MIRRORS_CLI = "update-mirrors"
UPDATE_MIRROR_CLI = "update-mirror"
CREATE_WORKER_CLI = "create-worker"
UPDATE_WORKER_CLI = "update-worker"
SCHEDULER_CLI = "scheduler"
Expand Down Expand Up @@ -115,6 +121,38 @@ def main():
),
)

update_mirror_cli = subparsers.add_parser(
UPDATE_MIRROR_CLI, help="Update details of a mirror."
)
update_mirror_cli.add_argument(
"mirror_id", help="ID of the mirror.", metavar="mirror-id"
)
update_mirror_cli_opts = update_mirror_cli.add_mutually_exclusive_group(
required=True
)
update_mirror_cli_opts.add_argument(
"--regions",
help=(
"Comma seperated two-letter region codes whose countries "
"should be sent to this mirror. The mirror's default region "
"is added to this list.\nSet to empty to remove additional "
"countries."
),
type=lambda regions: regions.split(","),
dest="regions",
metavar="codes",
)
update_mirror_cli_opts.add_argument(
"--region",
help="Two-letter region code of where the mirror server is located.",
metavar="code",
)
update_mirror_cli_opts.add_argument(
"--country",
help="ISO 3166-1 alpha-2 country code of where the mirror server is located.",
metavar="code",
)

args = parser.parse_args()
if args.verbose:
logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -170,6 +208,36 @@ def main():
logger.error(f"error while creating regions: {exc!s}")
sys.exit(1)
logger.info("Created regions and associated countries.")
elif args.cli_name == UPDATE_MIRROR_CLI:
if args.regions:
try:
logger.debug("Updating mirror region.")
update_mirror_other_countries(
args.mirror_id, region_codes={code for code in args.regions if code}
)
except Exception as exc:
logger.error(f"error whle updating region for mirror: {exc!s}")
sys.exit(1)
logger.info("Updated additional regions for mirror.")
elif args.region:
logger.debug("Updating default region for mirror.")
try:
update_mirror_region(args.mirror_id, args.region)
except Exception as exc:
logger.error(f"error while updating default region for mirror: {exc!s}")
sys.exit(1)
logger.info("Updated default region for mirror")
elif args.country:
logger.debug("Updating default country for mirror.")
try:
update_mirror_country(args.mirror_id, args.country)
except Exception as exc:
logger.error(
f"error while updating default country for mirror: {exc!s}"
)
sys.exit(1)
logger.info("Updated default country for mirror.")

else:
args.print_help()

Expand Down
8 changes: 5 additions & 3 deletions backend/tests/cli/test_mirror.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.mirrors import update_mirror_country
from mirrors_qa_backend.db.mirrors import (
_update_mirror_country_and_region, # pyright: ignore[reportPrivateUsage]
)


def test_update_mirror_region_and_country(
def test_update_mirror_country_and_region(
dbsession: OrmSession, db_mirror: models.Mirror
):

Expand All @@ -16,7 +18,7 @@ def test_update_mirror_region_and_country(
country.region = region
dbsession.add(country)

db_mirror = update_mirror_country(dbsession, country.code, db_mirror)
db_mirror = _update_mirror_country_and_region(dbsession, country.code, db_mirror)
assert db_mirror.country is not None
assert db_mirror.country == country
assert db_mirror.region == region
38 changes: 37 additions & 1 deletion backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mirrors_qa_backend.cryptography import sign_message
from mirrors_qa_backend.db import Session
from mirrors_qa_backend.db.country import create_country
from mirrors_qa_backend.db.models import Base, Mirror, Test, Worker
from mirrors_qa_backend.db.models import Base, Country, Mirror, Region, Test, Worker
from mirrors_qa_backend.db.worker import update_worker_countries
from mirrors_qa_backend.enums import StatusEnum
from mirrors_qa_backend.serializer import serialize_mirror
Expand Down Expand Up @@ -188,3 +188,39 @@ def new_schema_mirror() -> schemas.Mirror:
as_only=None,
other_countries=None,
)


@pytest.fixture
def africa_region(dbsession: OrmSession) -> Region:
"""Set up a region in Africa and add some default countries."""
region = Region(code="af", name="Africa")
countries = [
Country(code="ng", name="Nigeria"),
]
region.countries = countries
dbsession.add(region)
return region


@pytest.fixture
def europe_region(dbsession: OrmSession) -> Region:
"""Set up a region in Europe and add some default countries."""
region = Region(code="eu", name="Europe")
countries = [
Country(code="fr", name="France"),
]
region.countries = countries
dbsession.add(region)
return region


@pytest.fixture
def asia_region(dbsession: OrmSession) -> Region:
"""Set up a region in Asia and add some default countries."""
region = Region(code="as", name="Asia")
countries = [
Country(code="jp", name="Japan"),
]
region.countries = countries
dbsession.add(region)
return region
Loading

0 comments on commit 0c088e1

Please sign in to comment.