diff --git a/backend/src/mirrors_qa_backend/cli/mirrors.py b/backend/src/mirrors_qa_backend/cli/mirrors.py index 3e14fa6..fe76e67 100644 --- a/backend/src/mirrors_qa_backend/cli/mirrors.py +++ b/backend/src/mirrors_qa_backend/cli/mirrors.py @@ -1,6 +1,16 @@ from mirrors_qa_backend import logger from mirrors_qa_backend.db import Session -from mirrors_qa_backend.db.mirrors import create_or_update_mirror_status +from mirrors_qa_backend.db.mirrors import ( + create_or_update_mirror_status, + get_mirror, + update_mirror_countries_from_regions, +) +from mirrors_qa_backend.db.mirrors import ( + update_mirror_country as update_db_mirror_country, +) +from mirrors_qa_backend.db.mirrors import ( + update_mirror_region as update_db_mirror_region, +) from mirrors_qa_backend.extract import get_current_mirrors @@ -13,3 +23,26 @@ def update_mirrors() -> None: f"Updated mirrors list. Added {results.nb_mirrors_added} mirror(s), " f"disabled {results.nb_mirrors_disabled} mirror(s)" ) + + +def update_mirror_other_countries(mirror_id: str, region_codes: set[str]) -> None: + with Session.begin() as session: + mirror = update_mirror_countries_from_regions( + session, get_mirror(session, mirror_id), region_codes + ) + + logger.info( + f"Set {len(mirror.other_countries)} countries " # pyright: ignore[reportGeneralTypeIssues,reportArgumentType] + f"for mirror {mirror.id}" + ) + + +def update_mirror_region(mirror_id: str, region_code: str) -> None: + """Update the region the mirror server is located.""" + with Session.begin() as session: + update_db_mirror_region(session, get_mirror(session, mirror_id), region_code) + + +def update_mirror_country(mirror_id: str, country_code: str) -> None: + with Session.begin() as session: + update_db_mirror_country(session, get_mirror(session, mirror_id), country_code) diff --git a/backend/src/mirrors_qa_backend/db/mirrors.py b/backend/src/mirrors_qa_backend/db/mirrors.py index 3f6e4bf..b315a27 100644 --- a/backend/src/mirrors_qa_backend/db/mirrors.py +++ b/backend/src/mirrors_qa_backend/db/mirrors.py @@ -1,13 +1,18 @@ from dataclasses import dataclass +from itertools import chain from sqlalchemy import select from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend import logger, schemas -from mirrors_qa_backend.db.country import get_country_or_none +from mirrors_qa_backend.db.country import get_country, get_country_or_none from mirrors_qa_backend.db.exceptions import EmptyMirrorsError, RecordDoesNotExistError from mirrors_qa_backend.db.models import Mirror -from mirrors_qa_backend.db.region import get_region_or_none +from mirrors_qa_backend.db.region import ( + get_countries_for, + get_region, + get_region_or_none, +) @dataclass @@ -18,10 +23,14 @@ class MirrorsUpdateResult: nb_mirrors_disabled: int = 0 -def update_mirror_country( +def _update_mirror_country_and_region( session: OrmSession, country_code: str, mirror: Mirror ) -> Mirror: - logger.debug("Updating mirror country information.") + """Update the mirror country and region using the country code if they exist. + + Used during mirror list update to set region and country as these fields + were missing in old DB schema. + """ mirror.country = get_country_or_none(session, country_code) if mirror.country and mirror.country.region_code: mirror.region = get_region_or_none(session, mirror.country.region_code) @@ -53,7 +62,7 @@ def create_mirrors(session: OrmSession, mirrors: list[schemas.Mirror]) -> int: session.add(db_mirror) if mirror.country_code: - update_mirror_country(session, mirror.country_code, db_mirror) + _update_mirror_country_and_region(session, mirror.country_code, db_mirror) logger.debug(f"Registered new mirror: {db_mirror.id}.") nb_created += 1 @@ -113,7 +122,7 @@ def create_or_update_mirror_status( if db_mirror_id in current_mirrors: country_code = current_mirrors[db_mirror_id].country_code if country_code: - update_mirror_country(session, country_code, db_mirror) + _update_mirror_country_and_region(session, country_code, db_mirror) return result @@ -132,3 +141,62 @@ def get_enabled_mirrors(session: OrmSession) -> list[Mirror]: select(Mirror).where(Mirror.enabled == True) # noqa: E712 ).all() ) + + +def update_mirror_countries_from_regions( + session: OrmSession, mirror: Mirror, region_codes: set[str] +) -> Mirror: + """Update the list of other countries for a mirror with countries from region codes. + + Sets the list of other countries to empty if no regions are provided. + + Because otherCountries overrides the mirror region and country choice as per + MirroBrain configuration, the list of countries in this mirror's region is + added to the list. + """ + if not region_codes: + mirror.other_countries = [] + session.add(mirror) + return mirror + + if mirror.region_code: + region_codes.add(mirror.region_code) + + country_codes = [ + country.code + for country in chain.from_iterable( + get_countries_for(session, region_code) for region_code in region_codes + ) + ] + if not country_codes: + raise ValueError("No countries found in provided regions.") + + mirror.other_countries = country_codes + session.add(mirror) + return mirror + + +def update_mirror_region( + session: OrmSession, mirror: Mirror, region_code: str +) -> Mirror: + """Updates the region the mirror server is located. + + Assumes the region exists in the DB. + This does not update the country the mirror is located in. + """ + mirror.region = get_region(session, region_code) + session.add(mirror) + return mirror + + +def update_mirror_country( + session: OrmSession, mirror: Mirror, country_code: str +) -> Mirror: + """Update the country the mirror server is located. + + Assumes the country exists in the DB. + This does not update the region the mirror is located in. + """ + mirror.country = get_country(session, country_code) + session.add(mirror) + return mirror diff --git a/backend/src/mirrors_qa_backend/entrypoint.py b/backend/src/mirrors_qa_backend/entrypoint.py index b022086..de5b67d 100644 --- a/backend/src/mirrors_qa_backend/entrypoint.py +++ b/backend/src/mirrors_qa_backend/entrypoint.py @@ -10,12 +10,18 @@ create_regions_and_countries, extract_country_regions_from_csv, ) -from mirrors_qa_backend.cli.mirrors import update_mirrors +from mirrors_qa_backend.cli.mirrors import ( + update_mirror_country, + update_mirror_other_countries, + update_mirror_region, + update_mirrors, +) from mirrors_qa_backend.cli.scheduler import main as start_scheduler from mirrors_qa_backend.cli.worker import create_worker, update_worker from mirrors_qa_backend.settings.scheduler import SchedulerSettings UPDATE_MIRRORS_CLI = "update-mirrors" +UPDATE_MIRROR_CLI = "update-mirror" CREATE_WORKER_CLI = "create-worker" UPDATE_WORKER_CLI = "update-worker" SCHEDULER_CLI = "scheduler" @@ -115,6 +121,38 @@ def main(): ), ) + update_mirror_cli = subparsers.add_parser( + UPDATE_MIRROR_CLI, help="Update details of a mirror." + ) + update_mirror_cli.add_argument( + "mirror_id", help="ID of the mirror.", metavar="mirror-id" + ) + update_mirror_cli_opts = update_mirror_cli.add_mutually_exclusive_group( + required=True + ) + update_mirror_cli_opts.add_argument( + "--regions", + help=( + "Comma seperated two-letter region codes whose countries " + "should be sent to this mirror. The mirror's default region " + "is added to this list.\nSet to empty to remove additional " + "countries." + ), + type=lambda regions: regions.split(","), + dest="regions", + metavar="codes", + ) + update_mirror_cli_opts.add_argument( + "--region", + help="Two-letter region code of where the mirror server is located.", + metavar="code", + ) + update_mirror_cli_opts.add_argument( + "--country", + help="ISO 3166-1 alpha-2 country code of where the mirror server is located.", + metavar="code", + ) + args = parser.parse_args() if args.verbose: logger.setLevel(logging.DEBUG) @@ -170,6 +208,36 @@ def main(): logger.error(f"error while creating regions: {exc!s}") sys.exit(1) logger.info("Created regions and associated countries.") + elif args.cli_name == UPDATE_MIRROR_CLI: + if args.regions: + try: + logger.debug("Updating mirror region.") + update_mirror_other_countries( + args.mirror_id, region_codes={code for code in args.regions if code} + ) + except Exception as exc: + logger.error(f"error whle updating region for mirror: {exc!s}") + sys.exit(1) + logger.info("Updated additional regions for mirror.") + elif args.region: + logger.debug("Updating default region for mirror.") + try: + update_mirror_region(args.mirror_id, args.region) + except Exception as exc: + logger.error(f"error while updating default region for mirror: {exc!s}") + sys.exit(1) + logger.info("Updated default region for mirror") + elif args.country: + logger.debug("Updating default country for mirror.") + try: + update_mirror_country(args.mirror_id, args.country) + except Exception as exc: + logger.error( + f"error while updating default country for mirror: {exc!s}" + ) + sys.exit(1) + logger.info("Updated default country for mirror.") + else: args.print_help() diff --git a/backend/tests/cli/test_mirror.py b/backend/tests/cli/test_mirror.py index 93238c4..c609b89 100644 --- a/backend/tests/cli/test_mirror.py +++ b/backend/tests/cli/test_mirror.py @@ -1,10 +1,12 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models -from mirrors_qa_backend.db.mirrors import update_mirror_country +from mirrors_qa_backend.db.mirrors import ( + _update_mirror_country_and_region, # pyright: ignore[reportPrivateUsage] +) -def test_update_mirror_region_and_country( +def test_update_mirror_country_and_region( dbsession: OrmSession, db_mirror: models.Mirror ): @@ -16,7 +18,7 @@ def test_update_mirror_region_and_country( country.region = region dbsession.add(country) - db_mirror = update_mirror_country(dbsession, country.code, db_mirror) + db_mirror = _update_mirror_country_and_region(dbsession, country.code, db_mirror) assert db_mirror.country is not None assert db_mirror.country == country assert db_mirror.region == region diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 0092fde..33055a3 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -16,7 +16,7 @@ from mirrors_qa_backend.cryptography import sign_message from mirrors_qa_backend.db import Session from mirrors_qa_backend.db.country import create_country -from mirrors_qa_backend.db.models import Base, Mirror, Test, Worker +from mirrors_qa_backend.db.models import Base, Country, Mirror, Region, Test, Worker from mirrors_qa_backend.db.worker import update_worker_countries from mirrors_qa_backend.enums import StatusEnum from mirrors_qa_backend.serializer import serialize_mirror @@ -188,3 +188,39 @@ def new_schema_mirror() -> schemas.Mirror: as_only=None, other_countries=None, ) + + +@pytest.fixture +def africa_region(dbsession: OrmSession) -> Region: + """Set up a region in Africa and add some default countries.""" + region = Region(code="af", name="Africa") + countries = [ + Country(code="ng", name="Nigeria"), + ] + region.countries = countries + dbsession.add(region) + return region + + +@pytest.fixture +def europe_region(dbsession: OrmSession) -> Region: + """Set up a region in Europe and add some default countries.""" + region = Region(code="eu", name="Europe") + countries = [ + Country(code="fr", name="France"), + ] + region.countries = countries + dbsession.add(region) + return region + + +@pytest.fixture +def asia_region(dbsession: OrmSession) -> Region: + """Set up a region in Asia and add some default countries.""" + region = Region(code="as", name="Asia") + countries = [ + Country(code="jp", name="Japan"), + ] + region.countries = countries + dbsession.add(region) + return region diff --git a/backend/tests/db/test_mirrors.py b/backend/tests/db/test_mirrors.py index 6a674df..efca289 100644 --- a/backend/tests/db/test_mirrors.py +++ b/backend/tests/db/test_mirrors.py @@ -3,14 +3,21 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend import schemas -from mirrors_qa_backend.db import count_from_stmt, models +from mirrors_qa_backend.db import count_from_stmt from mirrors_qa_backend.db.exceptions import EmptyMirrorsError -from mirrors_qa_backend.db.mirrors import create_mirrors, create_or_update_mirror_status +from mirrors_qa_backend.db.mirrors import ( + create_mirrors, + create_or_update_mirror_status, + update_mirror_countries_from_regions, + update_mirror_country, + update_mirror_region, +) +from mirrors_qa_backend.db.models import Country, Mirror, Region from mirrors_qa_backend.serializer import serialize_mirror def test_db_empty(dbsession: OrmSession): - assert count_from_stmt(dbsession, select(models.Country)) == 0 + assert count_from_stmt(dbsession, select(Country)) == 0 def test_create_no_mirrors(dbsession: OrmSession): @@ -39,7 +46,7 @@ def test_register_new_mirror( def test_disable_old_mirror( dbsession: OrmSession, - db_mirror: models.Mirror, # noqa: ARG001 [pytest fixture that saves a mirror] + db_mirror: Mirror, # noqa: ARG001 [pytest fixture that saves a mirror] new_schema_mirror: schemas.Mirror, ): result = create_or_update_mirror_status(dbsession, [new_schema_mirror]) @@ -60,7 +67,7 @@ def test_re_enable_existing_mirror( dbsession: OrmSession, ): # Create a mirror in the database with enabled set to False - db_mirror = models.Mirror( + db_mirror = Mirror( id="mirrors.dotsrc.org", base_url="https://mirrors.dotsrc.org/kiwix/", enabled=False, @@ -81,3 +88,57 @@ def test_re_enable_existing_mirror( result = create_or_update_mirror_status(dbsession, [schema_mirror]) assert result.nb_mirrors_added == 1 + + +def test_update_mirror_region( + dbsession: OrmSession, db_mirror: Mirror, africa_region: Region +): + update_mirror_region(dbsession, db_mirror, africa_region.code) + assert db_mirror.region == africa_region + + +def test_update_mirror_country(dbsession: OrmSession, db_mirror: Mirror): + country = Country(code="fr", name="France") + dbsession.add(country) + + update_mirror_country(dbsession, db_mirror, country.code) + assert db_mirror.country == country + + +def test_update_mirror_countries_from_empty_region( + dbsession: OrmSession, db_mirror: Mirror, africa_region: Region +): + + db_mirror.region = africa_region + db_mirror.country = africa_region.countries[0] + db_mirror.other_countries = ["us", "fr"] + dbsession.add(db_mirror) + + update_mirror_countries_from_regions(dbsession, db_mirror, set()) + assert db_mirror.other_countries == [] + + +def test_update_mirror_countries_from_regions( + dbsession: OrmSession, + db_mirror: Mirror, + africa_region: Region, + europe_region: Region, + asia_region: Region, +): + + regions = [asia_region, africa_region, europe_region] + expected_country_codes: set[str] = set() + region_codes: set[str] = set() + + for region in regions: + region_codes.add(region.code) + for country in region.countries: + expected_country_codes.add(country.code) + + db_mirror = update_mirror_countries_from_regions(dbsession, db_mirror, region_codes) + + assert db_mirror.other_countries is not None + assert len(expected_country_codes) == len(db_mirror.other_countries) + + for country_code in expected_country_codes: + assert country_code in db_mirror.other_countries