Skip to content

Commit

Permalink
fix(ingest/snowflake): fix type annotations + refactor get_connect_ar…
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored and cccs-Dustin committed Feb 1, 2023
1 parent 0f6e5f3 commit 528ed79
Showing 1 changed file with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional
from typing import Any, Dict, Optional

import pydantic
import snowflake.connector
Expand Down Expand Up @@ -153,7 +153,7 @@ class BaseSnowflakeConfig(BaseTimeWindowConfig):
default=True,
description="If enabled, populates the snowflake view->table and table->view lineages (no view->view lineage yet). Requires appropriate grants given to the role, and include_table_lineage to be True. view->table lineage requires Snowflake Enterprise Edition or above.",
)
connect_args: Optional[Dict] = pydantic.Field(
connect_args: Optional[Dict[str, Any]] = pydantic.Field(
default=None,
description="Connect args to pass to Snowflake SqlAlchemy driver",
exclude=True,
Expand Down Expand Up @@ -297,28 +297,28 @@ def get_sql_alchemy_url(
},
)

_computed_connect_args: Optional[dict] = None

def get_connect_args(self) -> dict:
"""
Builds connect args and updates self.connect_args so that
Subsequent calls to this method are efficient, i.e. do not read files again
Builds connect args, adding defaults and reading a private key from the file if needed.
Caches the results in a private instance variable to avoid reading the file multiple times.
"""

base_connect_args = {
if self._computed_connect_args is not None:
return self._computed_connect_args

connect_args: dict = {
# Improves performance and avoids timeout errors for larger query result
CLIENT_PREFETCH_THREADS: 10,
CLIENT_SESSION_KEEP_ALIVE: True,
}

if self.connect_args is None:
self.connect_args = base_connect_args
else:
# Let user override the default config values
base_connect_args.update(self.connect_args)
self.connect_args = base_connect_args
**(self.connect_args or {}),
}

if (
self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
and "private_key" not in self.connect_args.keys()
"private_key" not in connect_args
and self.authentication_type == "KEY_PAIR_AUTHENTICATOR"
):
if self.private_key is not None:
pkey_bytes = self.private_key.replace("\\n", "\n").encode()
Expand All @@ -337,13 +337,16 @@ def get_connect_args(self) -> dict:
backend=default_backend(),
)

pkb = p_key.private_bytes(
pkb: bytes = p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
self.connect_args.update({"private_key": pkb})
return self.connect_args

connect_args["private_key"] = pkb

self._computed_connect_args = connect_args
return connect_args


class SnowflakeConfig(BaseSnowflakeConfig, SQLAlchemyConfig):
Expand Down

0 comments on commit 528ed79

Please sign in to comment.