diff --git a/retrieval_service/cloudsql.tests.cloudbuild.yaml b/retrieval_service/cloudsql.tests.cloudbuild.yaml new file mode 100644 index 00000000..e7cb2f2e --- /dev/null +++ b/retrieval_service/cloudsql.tests.cloudbuild.yaml @@ -0,0 +1,84 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + - id: Install dependencies + name: python:3.11 + dir: retrieval_service + entrypoint: pip + args: + [ + "install", + "-r", + "requirements.txt", + "-r", + "requirements-test.txt", + "--user", + ] + + - id: Update config + name: python:3.11 + dir: retrieval_service + secretEnv: + - PGUSER + - PGPASSWORD + entrypoint: /bin/bash + args: + - "-c" + - | + # Create config + cp example-config-cloudsql.yml config.yml + sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml + sed -i "s/my-user/$$PGUSER/g" config.yml + sed -i "s/my-password/$$PGPASSWORD/g" config.yml + sed -i "s/my-project/$PROJECT_ID/g" config.yml + sed -i "s/my-region/${_CLOUDSQL_REGION}/g" config.yml + sed -i "s/my-instance/${_CLOUDSQL_INSTANCE}/g" config.yml + + - id: Run Cloud SQL DB integration tests + name: python:3.11 + dir: retrieval_service + env: # Set env var expected by tests + - "DB_NAME=${_DATABASE_NAME}" + - "DB_PROJECT=$PROJECT_ID" + - "DB_REGION=${_CLOUDSQL_REGION}" + - "DB_INSTANCE=${_CLOUDSQL_INSTANCE}" + secretEnv: + - PGUSER + - PGPASSWORD + entrypoint: /bin/bash + args: + - "-c" + - | + # Set env var expected by tests + export DB_USER=$$PGUSER + export DB_PASS=$$PGPASSWORD + python -m pytest datastore/providers/cloudsql_postgres_test.py + +substitutions: + _DATABASE_NAME: test_${SHORT_SHA} + _DATABASE_USER: postgres + _CLOUDSQL_REGION: "us-central1" + _CLOUDSQL_INSTANCE: "my-cloudsql-instance" + +availableSecrets: + secretManager: + - versionName: projects/$PROJECT_ID/secrets/cloudsql_pass/versions/latest + env: PGPASSWORD + - versionName: projects/$PROJECT_ID/secrets/cloudsql_user/versions/latest + env: PGUSER + +options: + substitutionOption: 'ALLOW_LOOSE' + dynamic_substitutions: true diff --git a/retrieval_service/datastore/providers/cloudsql_postgres_test.py b/retrieval_service/datastore/providers/cloudsql_postgres_test.py index 87408269..cf09c731 100644 --- a/retrieval_service/datastore/providers/cloudsql_postgres_test.py +++ b/retrieval_service/datastore/providers/cloudsql_postgres_test.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from datetime import datetime from ipaddress import IPv4Address from typing import Any, AsyncGenerator, List +import asyncpg import pytest import pytest_asyncio +from csv_diff import compare, load_csv # type: ignore +from google.cloud.sql.connector import Connector import models @@ -39,11 +43,6 @@ def db_pass() -> str: return get_env_var("DB_PASS", "password for the postgres user") -@pytest.fixture(scope="module") -def db_name() -> str: - return get_env_var("DB_NAME", "name of a postgres database") - - @pytest.fixture(scope="module") def db_project() -> str: return get_env_var("DB_PROJECT", "project id for google cloud") @@ -59,15 +58,47 @@ def db_instance() -> str: return get_env_var("DB_INSTANCE", "instance for cloud sql") +@pytest.fixture(scope="module") +async def create_db( + db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str +) -> AsyncGenerator[str, None]: + db_name = get_env_var("DB_NAME", "name of a postgres database") + loop = asyncio.get_running_loop() + connector = Connector(loop=loop) + # Database does not exist, create it. + sys_conn: asyncpg.Connection = await connector.connect_async( + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=f"{db_user}", + password=f"{db_pass}", + db="postgres", + ) + await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') + await sys_conn.execute(f'CREATE DATABASE "{db_name}";') + await sys_conn.close() + conn: asyncpg.Connection = await connector.connect_async( + f"{db_project}:{db_region}:{db_instance}", + "asyncpg", + user=f"{db_user}", + password=f"{db_pass}", + db=f"{db_name}", + ) + await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") + yield db_name + await conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') + await conn.close() + + @pytest_asyncio.fixture(scope="module") async def ds( + create_db: AsyncGenerator[str, None], db_user: str, db_pass: str, - db_name: str, db_project: str, db_region: str, db_instance: str, ) -> AsyncGenerator[datastore.Client, None]: + db_name = await create_db.__anext__() cfg = cloudsql_postgres.Config( kind="cloudsql-postgres", user=db_user, @@ -77,13 +108,70 @@ async def ds( region=db_region, instance=db_instance, ) + t = create_db ds = await datastore.create(cfg) + + airports_ds_path = "../data/airport_dataset.csv" + amenities_ds_path = "../data/amenity_dataset.csv" + flights_ds_path = "../data/flights_dataset.csv" + airports, amenities, flights = await ds.load_dataset( + airports_ds_path, amenities_ds_path, flights_ds_path + ) + await ds.initialize_data(airports, amenities, flights) + if ds is None: raise TypeError("datastore creation failure") yield ds - print("after yield") await ds.close() - print("closed database") + + +async def test_export_dataset(ds: cloudsql_postgres.Client): + airports, amenities, flights = await ds.export_data() + + airports_ds_path = "../data/airport_dataset.csv" + amenities_ds_path = "../data/amenity_dataset.csv" + flights_ds_path = "../data/flights_dataset.csv" + + airports_new_path = "../data/airport_dataset.csv.new" + amenities_new_path = "../data/amenity_dataset.csv.new" + flights_new_path = "../data/flights_dataset.csv.new" + + await ds.export_dataset( + airports, + amenities, + flights, + airports_new_path, + amenities_new_path, + flights_new_path, + ) + + diff_airports = compare( + load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id") + ) + assert diff_airports["added"] == [] + assert diff_airports["removed"] == [] + assert diff_airports["changed"] == [] + assert diff_airports["columns_added"] == [] + assert diff_airports["columns_removed"] == [] + + diff_amenities = compare( + load_csv(open(amenities_ds_path), "id"), + load_csv(open(amenities_new_path), "id"), + ) + assert diff_amenities["added"] == [] + assert diff_amenities["removed"] == [] + assert diff_amenities["changed"] == [] + assert diff_amenities["columns_added"] == [] + assert diff_amenities["columns_removed"] == [] + + diff_flights = compare( + load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id") + ) + assert diff_flights["added"] == [] + assert diff_flights["removed"] == [] + assert diff_flights["changed"] == [] + assert diff_flights["columns_added"] == [] + assert diff_flights["columns_removed"] == [] async def test_get_airport_by_id(ds: cloudsql_postgres.Client): diff --git a/retrieval_service/example-config-cloudsql.yml b/retrieval_service/example-config-cloudsql.yml new file mode 100644 index 00000000..a1ab6c96 --- /dev/null +++ b/retrieval_service/example-config-cloudsql.yml @@ -0,0 +1,10 @@ +host: 0.0.0.0 +datastore: + # Example for Cloud SQL + kind: "cloudsql-postgres" + project: "my-project" + region: "my-region" + instance: "my-instance" + database: "my_database" + user: "my-user" + password: "my-password"