From 002d09b19bf1ea627d8668c943da36e761ff285f Mon Sep 17 00:00:00 2001 From: Yuan Teoh Date: Mon, 8 Jan 2024 00:05:40 -0800 Subject: [PATCH] fix asyncpg connect --- .../cloudsql.tests.cloudbuild.yaml | 3 +- .../providers/cloudsql_postgres_test.py | 39 ++++++++++++++++--- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/retrieval_service/cloudsql.tests.cloudbuild.yaml b/retrieval_service/cloudsql.tests.cloudbuild.yaml index d5591e547..0e57d194c 100644 --- a/retrieval_service/cloudsql.tests.cloudbuild.yaml +++ b/retrieval_service/cloudsql.tests.cloudbuild.yaml @@ -50,6 +50,7 @@ steps: name: python:3.11 dir: retrieval_service env: # Set env var expected by tests + - "DB_HOST=${_DATABASE_HOST}" - "DB_NAME=${_DATABASE_NAME}" - "DB_PROJECT=$PROJECT_ID" - "DB_REGION=${_CLOUDSQL_REGION}" @@ -69,7 +70,7 @@ steps: substitutions: _DATABASE_NAME: test_${SHORT_SHA} _DATABASE_USER: postgres - _DATABASE_HOST: 127.0.0.1 + _DATABASE_HOST: "34.29.26.230" _CLOUDSQL_REGION: "us-central1" _CLOUDSQL_INSTANCE: "my-cloudsql-instance" diff --git a/retrieval_service/datastore/providers/cloudsql_postgres_test.py b/retrieval_service/datastore/providers/cloudsql_postgres_test.py index b44b90cb7..4e8a90688 100644 --- a/retrieval_service/datastore/providers/cloudsql_postgres_test.py +++ b/retrieval_service/datastore/providers/cloudsql_postgres_test.py @@ -55,26 +55,51 @@ def db_project() -> str: def db_region() -> str: return get_env_var("DB_REGION", "region for cloud sql instance") - @pytest.fixture(scope="module") def db_instance() -> str: return get_env_var("DB_INSTANCE", "instance for cloud sql") @pytest.fixture(scope="module") -async def create_db(db_user: str, db_name: str) -> AsyncGenerator[None, None]: +def db_host() -> str: + return get_env_var("DB_IP", "public ip for cloud sql instance") + +@pytest.fixture(scope="module") +async def create_db(db_user: str, db_pass: str, db_name: str, db_host: str) -> AsyncGenerator[None, None]: try: - conn = await asyncpg.connect(user=db_user, database=db_name) + print("actually in the function") + conn = await asyncpg.connect( + host=db_host, + port=5432, + user=db_user, + password=db_pass, + database=db_name, + timeout=500, + ) except asyncpg.InvalidCatalogNameError: + print("in the error") # Database does not exist, create it. sys_conn = await asyncpg.connect( + host=db_host, + port=5432, database='template1', + password=db_pass, user=db_user, + timeout=500, ) await sys_conn.execute(f'CREATE DATABASE "{db_name}";') - conn = await asyncpg.connect(user=db_user, database=db_name) + conn = await asyncpg.connect( + host=db_host, + port=5432, + user=db_user, + password=db_pass, + database=db_name, + timeout=500, + ) await conn.execute("CREATE EXTENSION vector;") - print("created") await sys_conn.close() + except Exception as error: + print("Error while connecting to db: {}".format(error)) + print("run async generator") yield await conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";') @@ -88,7 +113,9 @@ async def ds( db_region: str, db_instance: str, ) -> AsyncGenerator[datastore.Client, None]: - t = create_db + t = await create_db.__anext__() + print("after create_db") + print(t) cfg = cloudsql_postgres.Config( kind="cloudsql-postgres", user=db_user,