fix: support for capital sequences and better sequence testing (#550)
* fix: support for capital sequences and better sequence testing

* fix: bandit security stuff
vjeeva authored Sep 9, 2024
1 parent e5612d6 commit 48915e8
Showing 6 changed files with 108 additions and 53 deletions.
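Background on the fix: PostgreSQL folds unquoted identifiers to lowercase, so a sequence created as "userS_id_seq" can only be referenced when double-quoted; the old code built names that broke for mixed-case sequences. A minimal sketch of the qualification rule the changes below apply (the helper name is illustrative, not part of this commit):

def qualify(schema: str, sequence: str) -> str:
    # Double-quote the sequence name so mixed-case identifiers survive
    # PostgreSQL's lowercase folding of unquoted names.
    return f'{schema}."{sequence}"'

# Unquoted, Postgres would fold userS_id_seq to users_id_seq and fail to find it:
assert qualify("public", "userS_id_seq") == 'public."userS_id_seq"'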
6 changes: 3 additions & 3 deletions Makefile
@@ -14,15 +14,15 @@ setup: ## Install development requirements. You should be in a virtualenv
poetry install && pre-commit install

test: ## Run tests
docker build . -t autodesk/pgbelt:latest && docker build tests/integration/files/postgres13-pglogical-docker/ -t autodesk/postgres-pglogical-docker:13 && docker-compose run tests
docker build . -t autodesk/pgbelt:latest && docker build tests/integration/files/postgres13-pglogical-docker/ -t autodesk/postgres-pglogical-docker:13 && docker compose run tests

tests: test

local-dev: ## Sets up docker containers for Postgres DBs and gets you into a docker container with pgbelt installed. DC: testdc, DB: testdb
docker build . -t autodesk/pgbelt:latest && docker build tests/integration/files/postgres13-pglogical-docker/ -t autodesk/postgres-pglogical-docker:13 && docker-compose run localtest
docker build . -t autodesk/pgbelt:latest && docker build tests/integration/files/postgres13-pglogical-docker/ -t autodesk/postgres-pglogical-docker:13 && docker compose run localtest

clean-docker: ## Stop and remove all docker containers and images made from local testing
docker stop $$(docker ps -aq --filter name=^/pgbelt) && docker rm $$(docker ps -aq --filter name=^/pgbelt) && docker-compose down --rmi all
docker stop $$(docker ps -aq --filter name=^/pgbelt) && docker rm $$(docker ps -aq --filter name=^/pgbelt) && docker compose down --rmi all

# Note: typer-cli has dependency conflict issues that don't affect it generating docs, see https://github.com/tiangolo/typer-cli/pull/120.
# We need to install the package with pip instead. Then, we run pre-commit to fix the formatting of the generated file.
2 changes: 1 addition & 1 deletion pgbelt/cmd/sync.py
@@ -30,7 +30,7 @@ async def _sync_sequences(
) -> None:

seq_vals = await dump_sequences(src_pool, targeted_sequences, schema, src_logger)
await load_sequences(dst_pool, seq_vals, dst_logger)
await load_sequences(dst_pool, seq_vals, schema, dst_logger)


@run_with_configs
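The net effect of this change is that _sync_sequences now threads the schema through both halves of the round trip. A rough sketch of the flow, assuming a public schema (the wrapper function is illustrative, not part of the commit):

from pgbelt.util.postgres import dump_sequences, load_sequences

async def sync_public_sequences(src_pool, dst_pool, src_logger, dst_logger):
    # Dump {sequence_name: last_value} for the targeted sequences from the source...
    seq_vals = await dump_sequences(src_pool, ["userS_id_seq"], "public", src_logger)
    # e.g. seq_vals == {"userS_id_seq": 16}
    # ...then replay each value into the destination under the same schema.
    await load_sequences(dst_pool, seq_vals, "public", dst_logger)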
28 changes: 10 additions & 18 deletions pgbelt/util/postgres.py
@@ -15,23 +15,13 @@ async def dump_sequences(
# Get all sequences in the schema
seqs = await pool.fetch(
f"""
SELECT '{schema}' || '.\"' || sequence_name
SELECT sequence_name
FROM information_schema.sequences
WHERE sequence_schema = '{schema}' || '\"';
WHERE sequence_schema = '{schema}';
"""
)

# Note: When in an exodus migration with a non-public schema, the sequence names must be prefixed with the schema name.
# This may not be done by the user, so we must do it here.
proper_sequence_names = None
if targeted_sequences is not None:
proper_sequence_names = []
for seq in targeted_sequences:
if f"{schema}." not in seq:
proper_sequence_names.append(f'{schema}."{seq}"')
else:
proper_sequence_names.append(seq)
targeted_sequences = proper_sequence_names
# Note: in exodus migrations, we expect the sequence names coming into targeted_sequences to not contain the schema name.

seq_vals = {}
final_seqs = []
@@ -42,14 +32,16 @@ async def dump_sequences(
final_seqs = [r[0] for r in seqs]

for seq in final_seqs:
res = await pool.fetchval(f"SELECT last_value FROM {seq};")
res = await pool.fetchval(f'SELECT last_value FROM {schema}."{seq}";')
seq_vals[seq.strip()] = res

logger.debug(f"Dumped sequences: {seq_vals}")
return seq_vals


async def load_sequences(pool: Pool, seqs: dict[str, int], logger: Logger) -> None:
async def load_sequences(
pool: Pool, seqs: dict[str, int], schema: str, logger: Logger
) -> None:
"""
given a dict of sequence named mapped to values, set each sequence to the
matching value
Expand All @@ -60,9 +52,9 @@ async def load_sequences(pool: Pool, seqs: dict[str, int], logger: Logger) -> No
logger.info("No sequences to load. Skipping sequence loading.")
return

logger.info(f"Loading sequences {list(seqs.keys())}...")
sql_template = "SELECT pg_catalog.setval('{}', {}, true);"
sql = "\n".join([sql_template.format(k, v) for k, v in seqs.items()])
logger.info(f"Loading sequences {list(seqs.keys())} from schema {schema}...")
sql_template = "SELECT pg_catalog.setval('{}.\"{}\"', {}, true);"
sql = "\n".join([sql_template.format(schema, k, v) for k, v in seqs.items()])
async with pool.acquire() as conn:
async with conn.transaction():
await conn.execute(sql)
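Taken together, dump_sequences now returns unqualified names keyed to their last values, and load_sequences re-qualifies and double-quotes them when writing. A standalone sketch of the statement building (the helper name is illustrative):

def build_setval_sql(schema: str, seqs: dict[str, int]) -> str:
    # Mirrors the template above: schema-qualify and double-quote each sequence
    # name, and pass is_called=true so the next nextval() returns value + 1.
    template = "SELECT pg_catalog.setval('{}.\"{}\"', {}, true);"
    return "\n".join(template.format(schema, name, val) for name, val in seqs.items())

print(build_setval_sql("public", {"userS_id_seq": 16}))
# -> SELECT pg_catalog.setval('public."userS_id_seq"', 16, true);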
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
@@ -55,7 +55,7 @@ async def _create_dbupgradeconfigs() -> dict[str, DbupgradeConfig]:
)
db_upgrade_config_kwargs["tables"] = ["UsersCapital"] if "exodus" in s else None
db_upgrade_config_kwargs["sequences"] = (
["users_id_seq"] if "exodus" in s else None
["userS_id_seq"] if "exodus" in s else None
)
config = DbupgradeConfig(**db_upgrade_config_kwargs)

4 changes: 2 additions & 2 deletions tests/integration/files/test_schema_data.sql
@@ -123,14 +123,14 @@ INSERT INTO public."UsersCapital2" (id, "hash_firstName", hash_lastname, gender)
-- Name: userS_id_seq; Type: SEQUENCE SET; Schema: public; Owner: owner
--

SELECT pg_catalog.setval('public."userS_id_seq"', 1, false);
SELECT pg_catalog.setval('public."userS_id_seq"', 16, false);


--
-- Name: users2_id_seq; Type: SEQUENCE SET; Schema: public; Owner: owner
--

SELECT pg_catalog.setval('public.users2_id_seq', 1, false);
SELECT pg_catalog.setval('public.users2_id_seq', 15, false);


--
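For reference, the third argument to setval is is_called: with false, as in this seed data, the next nextval() returns the value itself rather than value + 1. This is also why the test's dump comparison below has to special-case sequence lines. A small sketch of the semantics (the helper name is illustrative):

def next_nextval(last_value: int, is_called: bool) -> int:
    # PostgreSQL's rule: if is_called is true, nextval() advances past
    # last_value; if false, nextval() returns last_value itself.
    return last_value + 1 if is_called else last_value

assert next_nextval(16, False) == 16  # setval('seq', 16, false)
assert next_nextval(16, True) == 17   # setval('seq', 16, true)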
119 changes: 91 additions & 28 deletions tests/integration/test_integration.py
@@ -222,6 +222,57 @@ async def _filter_dump(dump: str, keywords_to_exclude: list[str]):
return "\n".join(commands)


async def _compare_sequences(
sequences: list[str], src_root_dsn: str, dst_root_dsn: str, schema_name: str
):
"""
Compare the sequences in the source and destination databases by asynchronously running
psql "SELECT last_value FROM sequence_name;" for each sequence in the list.
"""

std_kwargs = {
"stdin": subprocess.PIPE,
"stdout": subprocess.PIPE,
"stderr": subprocess.PIPE,
}
src_seq_fetch_processes = await asyncio.gather(
*[
asyncio.create_subprocess_exec(
"psql",
src_root_dsn,
"-c",
f"'SELECT last_value FROM {schema_name}.\"{sequence}\";'",
"-t",
**std_kwargs,
)
for sequence in sequences
]
)
dst_seq_fetch_processes = await asyncio.gather(
*[
asyncio.create_subprocess_exec(
"psql",
dst_root_dsn,
"-c",
f"'SELECT last_value FROM {schema_name}.\"{sequence}\";'",
"-t",
**std_kwargs,
)
for sequence in sequences
]
)

await asyncio.gather(*[p.wait() for p in src_seq_fetch_processes])
await asyncio.gather(*[p.wait() for p in dst_seq_fetch_processes])

for i in range(len(sequences)):
src_val = (await src_seq_fetch_processes[i].communicate())[0].strip()
dst_val = (await dst_seq_fetch_processes[i].communicate())[0].strip()

print(f"Sequence {sequences[i]} in source: {src_val}, destination: {dst_val}")
assert src_val == dst_val


async def _ensure_same_data(configs: dict[str, DbupgradeConfig]):
# Dump the databases and ensure they're the same
# Unfortunately we exclude the sequence lines, because for some reason is_called is true in the source dump yet false on the destination.
@@ -320,42 +371,54 @@ async def _ensure_same_data(configs: dict[str, DbupgradeConfig]):

assert src_table_data[table] == dst_table_data[table]

# We also need to ensure the sequences are the same
# I'm using the same code as in the sync_sequences function to do this because it has
# all the logic to handle exodus-style migrations and target the right sequences.
src_pool, dst_pool = await asyncio.gather(
create_pool(configs[setname].src.pglogical_uri, min_size=1),
create_pool(configs[setname].dst.root_uri, min_size=1),
)
src_seq_vals = await pgbelt.util.postgres.dump_sequences(
src_pool,
configs[setname].sequences,
configs[setname].schema_name,
pgbelt.util.logs.get_logger(
configs[setname].db,
configs[setname].dc,
"integration-sequences.src",
),
)
dst_seq_vals = await pgbelt.util.postgres.dump_sequences(
dst_pool,
configs[setname].sequences,
configs[setname].schema_name,
pgbelt.util.logs.get_logger(
configs[setname].db,
configs[setname].dc,
"integration-sequences.dst",
),
)
# Check that the sequences are the same by literally running psql "SELECT last_value FROM sequence_name;"

print(
f"Ensuring {setname} source and destination sequences are the same..."
)
assert src_seq_vals == dst_seq_vals

# In exodus-style migrations, the sequences to check are defined in the config.
await _compare_sequences(
configs[setname].sequences,
configs[setname].src.root_dsn,
configs[setname].dst.root_dsn,
configs[setname].schema_name,
)

else:
print(f"Ensuring {setname} source and destination dumps are the same...")
assert src_dumps_filtered[i] == dst_dumps_filtered[i]

print(
f"Ensuring {setname} source and destination sequences are the same..."
)

# First, get a list of all sequences in the source database in the specified schema.
# Synchronous because we need its output before the following commands anyway.
raw_sequences = subprocess.run(
[
"psql",
configs[setname].src.root_dsn,
"-c",
f"SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = '{configs[setname].schema_name}';",
"-t",
],
capture_output=True,
).stdout.decode("utf-8")
# psql -t pads each row with whitespace, so strip each name.
sequences = [line.strip() for line in raw_sequences.strip().split("\n") if line.strip()]

await _compare_sequences(
sequences, # In full migrations, we need to get the sequences from the source database
configs[setname].src.root_dsn,
configs[setname].dst.root_dsn,
configs[setname].schema_name,
)


async def _test_teardown_not_full(configs: dict[str, DbupgradeConfig]):
await pgbelt.cmd.teardown.teardown(db=None, dc=configs[list(configs.keys())[0]].dc)
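A hedged usage sketch of the new _compare_sequences helper (the DSNs are placeholders; it assumes psql is on PATH and the helper above is importable):

import asyncio

# Placeholder DSNs for illustration only.
SRC_DSN = "postgresql://user:pass@src-host:5432/testdb"
DST_DSN = "postgresql://user:pass@dst-host:5432/testdb"

async def main():
    # _compare_sequences is async, so it must be awaited for the psql
    # subprocesses and assertions to actually run.
    await _compare_sequences(["userS_id_seq"], SRC_DSN, DST_DSN, "public")

asyncio.run(main())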
