Skip to content

Commit

Permalink
fix asyncpg connect
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 committed Jan 8, 2024
1 parent 83f8391 commit 002d09b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
3 changes: 2 additions & 1 deletion retrieval_service/cloudsql.tests.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ steps:
name: python:3.11
dir: retrieval_service
env: # Set env var expected by tests
- "DB_HOST=${_DATABASE_HOST}"
- "DB_NAME=${_DATABASE_NAME}"
- "DB_PROJECT=$PROJECT_ID"
- "DB_REGION=${_CLOUDSQL_REGION}"
Expand All @@ -69,7 +70,7 @@ steps:
substitutions:
_DATABASE_NAME: test_${SHORT_SHA}
_DATABASE_USER: postgres
_DATABASE_HOST: 127.0.0.1
_DATABASE_HOST: "34.29.26.230"
_CLOUDSQL_REGION: "us-central1"
_CLOUDSQL_INSTANCE: "my-cloudsql-instance"

Expand Down
39 changes: 33 additions & 6 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,26 +55,51 @@ def db_project() -> str:
def db_region() -> str:
return get_env_var("DB_REGION", "region for cloud sql instance")


@pytest.fixture(scope="module")
def db_instance() -> str:
return get_env_var("DB_INSTANCE", "instance for cloud sql")

@pytest.fixture(scope="module")
async def create_db(db_user: str, db_name: str) -> AsyncGenerator[None, None]:
def db_host() -> str:
return get_env_var("DB_IP", "public ip for cloud sql instance")

@pytest.fixture(scope="module")
async def create_db(db_user: str, db_pass: str, db_name: str, db_host: str) -> AsyncGenerator[None, None]:
try:
conn = await asyncpg.connect(user=db_user, database=db_name)
print("actually in the function")
conn = await asyncpg.connect(
host=db_host,
port=5432,
user=db_user,
password=db_pass,
database=db_name,
timeout=500,
)
except asyncpg.InvalidCatalogNameError:
print("in the error")
# Database does not exist, create it.
sys_conn = await asyncpg.connect(
host=db_host,
port=5432,
database='template1',
password=db_pass,
user=db_user,
timeout=500,
)
await sys_conn.execute(f'CREATE DATABASE "{db_name}";')
conn = await asyncpg.connect(user=db_user, database=db_name)
conn = await asyncpg.connect(
host=db_host,
port=5432,
user=db_user,
password=db_pass,
database=db_name,
timeout=500,
)
await conn.execute("CREATE EXTENSION vector;")
print("created")
await sys_conn.close()
except Exception as error:
print("Error while connecting to db: {}".format(error))
print("run async generator")
yield
await conn.execute(f'DROP DATABASE IF EXISTS "{db_name}";')

Expand All @@ -88,7 +113,9 @@ async def ds(
db_region: str,
db_instance: str,
) -> AsyncGenerator[datastore.Client, None]:
t = create_db
t = await create_db.__anext__()
print("after create_db")
print(t)
cfg = cloudsql_postgres.Config(
kind="cloudsql-postgres",
user=db_user,
Expand Down

0 comments on commit 002d09b

Please sign in to comment.