Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add cloud sql cloudbuild workflow #143

Merged
merged 10 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions retrieval_service/cloudsql.tests.cloudbuild.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

steps:
- id: Install dependencies
name: python:3.11
dir: retrieval_service
entrypoint: pip
args:
[
"install",
"-r",
"requirements.txt",
"-r",
"requirements-test.txt",
"--user",
]

- id: Update config
name: python:3.11
dir: retrieval_service
secretEnv:
- PGUSER
- PGPASSWORD
entrypoint: /bin/bash
args:
- "-c"
- |
# Create config
cp example-config-cloudsql.yml config.yml
sed -i "s/my_database/${_DATABASE_NAME}/g" config.yml
sed -i "s/my-user/$$PGUSER/g" config.yml
sed -i "s/my-password/$$PGPASSWORD/g" config.yml
sed -i "s/my-project/$PROJECT_ID/g" config.yml
sed -i "s/my-region/${_CLOUDSQL_REGION}/g" config.yml
sed -i "s/my-instance/${_CLOUDSQL_INSTANCE}/g" config.yml

- id: Run Cloud SQL DB integration tests
name: python:3.11
dir: retrieval_service
env: # Set env var expected by tests
- "DB_NAME=${_DATABASE_NAME}"
- "DB_PROJECT=$PROJECT_ID"
- "DB_REGION=${_CLOUDSQL_REGION}"
- "DB_INSTANCE=${_CLOUDSQL_INSTANCE}"
secretEnv:
- PGUSER
- PGPASSWORD
entrypoint: /bin/bash
args:
- "-c"
- |
# Set env var expected by tests
export DB_USER=$$PGUSER
export DB_PASS=$$PGPASSWORD
python -m pytest datastore/providers/cloudsql_postgres_test.py

substitutions:
_DATABASE_NAME: test_${SHORT_SHA}
_DATABASE_USER: postgres
_CLOUDSQL_REGION: "us-central1"
_CLOUDSQL_INSTANCE: "my-cloudsql-instance"

availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloudsql_pass/versions/latest
env: PGPASSWORD
- versionName: projects/$PROJECT_ID/secrets/cloudsql_user/versions/latest
env: PGUSER

options:
substitutionOption: 'ALLOW_LOOSE'
dynamic_substitutions: true
104 changes: 96 additions & 8 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, AsyncGenerator, List

import asyncpg
import pytest
import pytest_asyncio
from csv_diff import compare, load_csv # type: ignore
from google.cloud.sql.connector import Connector

import models

Expand All @@ -39,11 +43,6 @@ def db_pass() -> str:
return get_env_var("DB_PASS", "password for the postgres user")


@pytest.fixture(scope="module")
def db_name() -> str:
return get_env_var("DB_NAME", "name of a postgres database")


@pytest.fixture(scope="module")
def db_project() -> str:
return get_env_var("DB_PROJECT", "project id for google cloud")
Expand All @@ -59,15 +58,47 @@ def db_instance() -> str:
return get_env_var("DB_INSTANCE", "instance for cloud sql")


@pytest.fixture(scope="module")
async def create_db(
db_user: str, db_pass: str, db_project: str, db_region: str, db_instance: str
) -> AsyncGenerator[str, None]:
db_name = get_env_var("DB_NAME", "name of a postgres database")
loop = asyncio.get_running_loop()
connector = Connector(loop=loop)
# Database does not exist, create it.
sys_conn: asyncpg.Connection = await connector.connect_async(
f"{db_project}:{db_region}:{db_instance}",
"asyncpg",
user=f"{db_user}",
password=f"{db_pass}",
db="postgres",
)
await sys_conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";')
await sys_conn.execute(f'CREATE DATABASE "{db_name}";')
await sys_conn.close()
Comment on lines +76 to +78
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should probably use:

async with sys_conn:
    await sys_conn.execute()

so it closes if there is an error

Copy link
Collaborator Author

@Yuan325 Yuan325 Jan 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't be able to do it this way cause 'Connection' object does not support the asynchronous context manager protocol. (MagicStack/asyncpg#583)

conn: asyncpg.Connection = await connector.connect_async(
f"{db_project}:{db_region}:{db_instance}",
"asyncpg",
user=f"{db_user}",
password=f"{db_pass}",
db=f"{db_name}",
)
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
yield db_name
await conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";')
await conn.close()


@pytest_asyncio.fixture(scope="module")
async def ds(
create_db: AsyncGenerator[str, None],
db_user: str,
db_pass: str,
db_name: str,
db_project: str,
db_region: str,
db_instance: str,
) -> AsyncGenerator[datastore.Client, None]:
db_name = await create_db.__anext__()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to do anext like that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is the only way that I can get it to work since create_db is returning an async generator (async with yield)

cfg = cloudsql_postgres.Config(
kind="cloudsql-postgres",
user=db_user,
Expand All @@ -77,13 +108,70 @@ async def ds(
region=db_region,
instance=db_instance,
)
t = create_db
ds = await datastore.create(cfg)

airports_ds_path = "../data/airport_dataset.csv"
amenities_ds_path = "../data/amenity_dataset.csv"
flights_ds_path = "../data/flights_dataset.csv"
airports, amenities, flights = await ds.load_dataset(
airports_ds_path, amenities_ds_path, flights_ds_path
)
await ds.initialize_data(airports, amenities, flights)

if ds is None:
raise TypeError("datastore creation failure")
yield ds
print("after yield")
await ds.close()
print("closed database")


async def test_export_dataset(ds: cloudsql_postgres.Client):
airports, amenities, flights = await ds.export_data()

airports_ds_path = "../data/airport_dataset.csv"
amenities_ds_path = "../data/amenity_dataset.csv"
flights_ds_path = "../data/flights_dataset.csv"

airports_new_path = "../data/airport_dataset.csv.new"
amenities_new_path = "../data/amenity_dataset.csv.new"
flights_new_path = "../data/flights_dataset.csv.new"

await ds.export_dataset(
airports,
amenities,
flights,
airports_new_path,
amenities_new_path,
flights_new_path,
)

diff_airports = compare(
load_csv(open(airports_ds_path), "id"), load_csv(open(airports_new_path), "id")
)
assert diff_airports["added"] == []
assert diff_airports["removed"] == []
assert diff_airports["changed"] == []
assert diff_airports["columns_added"] == []
assert diff_airports["columns_removed"] == []

diff_amenities = compare(
load_csv(open(amenities_ds_path), "id"),
load_csv(open(amenities_new_path), "id"),
)
assert diff_amenities["added"] == []
assert diff_amenities["removed"] == []
assert diff_amenities["changed"] == []
assert diff_amenities["columns_added"] == []
assert diff_amenities["columns_removed"] == []

diff_flights = compare(
load_csv(open(flights_ds_path), "id"), load_csv(open(flights_new_path), "id")
)
assert diff_flights["added"] == []
assert diff_flights["removed"] == []
assert diff_flights["changed"] == []
assert diff_flights["columns_added"] == []
assert diff_flights["columns_removed"] == []


async def test_get_airport_by_id(ds: cloudsql_postgres.Client):
Expand Down
10 changes: 10 additions & 0 deletions retrieval_service/example-config-cloudsql.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
host: 0.0.0.0
datastore:
# Example for Cloud SQL
kind: "cloudsql-postgres"
project: "my-project"
region: "my-region"
instance: "my-instance"
database: "my_database"
user: "my-user"
password: "my-password"