From 5abed7448e84b74e7d09144a7e8c8d8e2648fde1 Mon Sep 17 00:00:00 2001 From: Miles Adkins Date: Mon, 13 Mar 2023 15:58:08 -0500 Subject: [PATCH] feat: Added SnowflakeConnection caching Signed-off-by: Miles Adkins --- .../infra/materialization/snowflake_engine.py | 15 ++- .../feast/infra/offline_stores/snowflake.py | 24 ++-- .../infra/offline_stores/snowflake_source.py | 6 +- .../feast/infra/online_stores/snowflake.py | 16 +-- sdk/python/feast/infra/registry/snowflake.py | 28 ++-- .../infra/utils/snowflake/snowflake_utils.py | 126 ++++++++++-------- .../universal/data_sources/snowflake.py | 9 +- 7 files changed, 122 insertions(+), 102 deletions(-) diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index 3b183f97e6..8a63e00891 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -25,10 +25,10 @@ from feast.infra.online_stores.online_store import OnlineStore from feast.infra.registry.base_registry import BaseRegistry from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, _run_snowflake_field_mapping, assert_snowflake_feature_names, execute_snowflake_statement, - get_snowflake_conn, get_snowflake_online_store_path, package_snowpark_zip, ) @@ -121,7 +121,7 @@ def update( ): stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"' stage_path = f'{stage_context}."feast_{project}"' - with get_snowflake_conn(self.repo_config.batch_engine) as conn: + with GetSnowflakeConnection(self.repo_config.batch_engine) as conn: query = f"SHOW STAGES IN {stage_context}" cursor = execute_snowflake_statement(conn, query) stage_list = pd.DataFrame( @@ -173,7 +173,7 @@ def teardown_infra( ): stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"' - with get_snowflake_conn(self.repo_config.batch_engine) as conn: + with GetSnowflakeConnection(self.repo_config.batch_engine) as conn: query = f"DROP STAGE IF EXISTS {stage_path}" execute_snowflake_statement(conn, query) @@ -263,10 +263,11 @@ def _materialize_one( # Lets check and see if we can skip this query, because the table hasnt changed # since before the start date of this query - with get_snowflake_conn(self.repo_config.offline_store) as conn: + with GetSnowflakeConnection(self.repo_config.offline_store) as conn: query = f"""SELECT SYSTEM$LAST_CHANGE_COMMIT_TIME('{feature_view.batch_source.get_table_query_string()}') AS last_commit_change_time""" last_commit_change_time = ( - conn.cursor().execute(query).fetchall()[0][0] / 1_000_000_000 + execute_snowflake_statement(conn, query).fetchall()[0][0] + / 1_000_000_000 ) if last_commit_change_time < start_date.astimezone(tz=utc).timestamp(): return SnowflakeMaterializationJob( @@ -432,7 +433,7 @@ def materialize_to_snowflake_online_store( ) """ - with get_snowflake_conn(repo_config.batch_engine) as conn: + with GetSnowflakeConnection(repo_config.batch_engine) as conn: query_id = execute_snowflake_statement(conn, query).sfqid click.echo( @@ -450,7 +451,7 @@ def materialize_to_external_online_store( feature_names = [feature.name for feature in feature_view.features] - with get_snowflake_conn(repo_config.batch_engine) as conn: + with GetSnowflakeConnection(repo_config.batch_engine) as conn: query = materialization_sql cursor = execute_snowflake_statement(conn, query) for i, df in enumerate(cursor.fetch_pandas_batches()): diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 2401458be7..404927146a 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -48,8 +48,8 @@ ) from feast.infra.registry.base_registry import BaseRegistry from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, - get_snowflake_conn, write_pandas, write_parquet, ) @@ -74,13 +74,13 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel): """Offline store config for Snowflake""" type: Literal["snowflake.offline"] = "snowflake.offline" - """ Offline store type selector""" + """ Offline store type selector """ config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") """ Snowflake config path -- absolute path required (Cant use ~)""" account: Optional[str] = None - """ Snowflake deployment identifier -- drop .snowflakecomputing.com""" + """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ user: Optional[str] = None """ Snowflake user name """ @@ -89,7 +89,7 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel): """ Snowflake password """ role: Optional[str] = None - """ Snowflake role name""" + """ Snowflake role name """ warehouse: Optional[str] = None """ Snowflake warehouse name """ @@ -155,7 +155,8 @@ def pull_latest_from_table_or_query( if data_source.snowflake_options.warehouse: config.offline_store.warehouse = data_source.snowflake_options.warehouse - snowflake_conn = get_snowflake_conn(config.offline_store) + with GetSnowflakeConnection(config.offline_store) as conn: + snowflake_conn = conn start_date = start_date.astimezone(tz=utc) end_date = end_date.astimezone(tz=utc) @@ -208,7 +209,8 @@ def pull_all_from_table_or_query( if data_source.snowflake_options.warehouse: config.offline_store.warehouse = data_source.snowflake_options.warehouse - snowflake_conn = get_snowflake_conn(config.offline_store) + with GetSnowflakeConnection(config.offline_store) as conn: + snowflake_conn = conn start_date = start_date.astimezone(tz=utc) end_date = end_date.astimezone(tz=utc) @@ -241,7 +243,8 @@ def get_historical_features( for fv in feature_views: assert isinstance(fv.batch_source, SnowflakeSource) - snowflake_conn = get_snowflake_conn(config.offline_store) + with GetSnowflakeConnection(config.offline_store) as conn: + snowflake_conn = conn entity_schema = _get_entity_schema(entity_df, snowflake_conn, config) @@ -319,7 +322,8 @@ def write_logged_features( ): assert isinstance(logging_config.destination, SnowflakeLoggingDestination) - snowflake_conn = get_snowflake_conn(config.offline_store) + with GetSnowflakeConnection(config.offline_store) as conn: + snowflake_conn = conn if isinstance(data, Path): write_parquet( @@ -359,7 +363,8 @@ def offline_write_batch( if table.schema != pa_schema: table = table.cast(pa_schema) - snowflake_conn = get_snowflake_conn(config.offline_store) + with GetSnowflakeConnection(config.offline_store) as conn: + snowflake_conn = conn write_pandas( snowflake_conn, @@ -427,7 +432,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table: ).fetch_arrow_all() if pa_table: - return pa_table else: empty_result = execute_snowflake_statement(self.snowflake_conn, query) diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index cc5208a676..63533214ea 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -213,13 +213,13 @@ def get_table_column_names_and_types( """ from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, - get_snowflake_conn, ) assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig) - with get_snowflake_conn(config.offline_store) as conn: + with GetSnowflakeConnection(config.offline_store) as conn: query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5" cursor = execute_snowflake_statement(conn, query) @@ -250,7 +250,7 @@ def get_table_column_names_and_types( else: column = row["column_name"] - with get_snowflake_conn(config.offline_store) as conn: + with GetSnowflakeConnection(config.offline_store) as conn: query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}' result = execute_snowflake_statement( conn, query diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py index c4474dff38..c1a03a2862 100644 --- a/sdk/python/feast/infra/online_stores/snowflake.py +++ b/sdk/python/feast/infra/online_stores/snowflake.py @@ -13,8 +13,8 @@ from feast.infra.key_encoding_utils import serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, - get_snowflake_conn, get_snowflake_online_store_path, write_pandas_binary, ) @@ -29,13 +29,13 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): """Online store config for Snowflake""" type: Literal["snowflake.online"] = "snowflake.online" - """ Online store type selector""" + """ Online store type selector """ config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") """ Snowflake config path -- absolute path required (Can't use ~)""" account: Optional[str] = None - """ Snowflake deployment identifier -- drop .snowflakecomputing.com""" + """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ user: Optional[str] = None """ Snowflake user name """ @@ -44,7 +44,7 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): """ Snowflake password """ role: Optional[str] = None - """ Snowflake role name""" + """ Snowflake role name """ warehouse: Optional[str] = None """ Snowflake warehouse name """ @@ -114,7 +114,7 @@ def online_write_batch( # This combines both the data upload plus the overwrite in the same transaction online_path = get_snowflake_online_store_path(config, table) - with get_snowflake_conn(config.online_store, autocommit=False) as conn: + with GetSnowflakeConnection(config.online_store, autocommit=False) as conn: write_pandas_binary( conn, agg_df, @@ -178,7 +178,7 @@ def online_read( ) online_path = get_snowflake_online_store_path(config, table) - with get_snowflake_conn(config.online_store) as conn: + with GetSnowflakeConnection(config.online_store) as conn: query = f""" SELECT "entity_key", "feature_name", "value", "event_ts" @@ -220,7 +220,7 @@ def update( ): assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - with get_snowflake_conn(config.online_store) as conn: + with GetSnowflakeConnection(config.online_store) as conn: for table in tables_to_keep: online_path = get_snowflake_online_store_path(config, table) query = f""" @@ -248,7 +248,7 @@ def teardown( ): assert isinstance(config.online_store, SnowflakeOnlineStoreConfig) - with get_snowflake_conn(config.online_store) as conn: + with GetSnowflakeConnection(config.online_store) as conn: for table in tables: online_path = get_snowflake_online_store_path(config, table) query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"' diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 07709db696..12682bdca2 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -28,8 +28,8 @@ from feast.infra.registry import proto_registry_utils from feast.infra.registry.base_registry import BaseRegistry from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, - get_snowflake_conn, ) from feast.on_demand_feature_view import OnDemandFeatureView from feast.project_metadata import ProjectMetadata @@ -121,7 +121,7 @@ def __init__( f'"{self.registry_config.database}"."{self.registry_config.schema_}"' ) - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" with open(sql_function_file, "r") as file: sqlFile = file.read() @@ -177,7 +177,7 @@ def _refresh_cached_registry_if_necessary(self): self.refresh() def teardown(self): - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" with open(sql_function_file, "r") as file: sqlFile = file.read() @@ -284,7 +284,7 @@ def _apply_object( if hasattr(obj, "last_updated_timestamp"): obj.last_updated_timestamp = update_datetime - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT project_id @@ -405,7 +405,7 @@ def _delete_object( id_field_name: str, not_found_exception: Optional[Callable], ): - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" DELETE FROM {self.registry_path}."{table}" WHERE @@ -616,7 +616,7 @@ def _get_object( not_found_exception: Optional[Callable], ): self._maybe_init_project_metadata(project) - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT {proto_field_name} @@ -776,7 +776,7 @@ def _list_objects( proto_field_name: str, ): self._maybe_init_project_metadata(project) - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT {proto_field_name} @@ -839,7 +839,7 @@ def list_project_metadata( return proto_registry_utils.list_project_metadata( self.cached_registry_proto, project ) - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT metadata_key, @@ -869,7 +869,7 @@ def apply_user_metadata( ): fv_table_str = self._infer_fv_table(feature_view) fv_column_name = fv_table_str[:-1].lower() - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT project_id @@ -905,7 +905,7 @@ def get_user_metadata( ) -> Optional[bytes]: fv_table_str = self._infer_fv_table(feature_view) fv_column_name = fv_table_str[:-1].lower() - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT user_metadata @@ -971,7 +971,7 @@ def _get_all_projects(self) -> Set[str]: "STREAM_FEATURE_VIEWS", ] - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: for table in base_tables: query = ( f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"' @@ -984,7 +984,7 @@ def _get_all_projects(self) -> Set[str]: return projects def _get_last_updated_metadata(self, project: str): - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT metadata_value @@ -1029,7 +1029,7 @@ def _infer_fv_table(self, feature_view) -> str: return table def _maybe_init_project_metadata(self, project): - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT metadata_value @@ -1056,7 +1056,7 @@ def _maybe_init_project_metadata(self, project): usage.set_current_project_uuid(new_project_uuid) def _set_last_updated_metadata(self, last_updated: datetime, project: str): - with get_snowflake_conn(self.registry_config) as conn: + with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT project_id diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index 8023980eac..a4cda89a6f 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -39,6 +39,77 @@ getLogger("snowflake.connector.network").disabled = True logger = getLogger(__name__) +_cache = {} + + +class GetSnowflakeConnection: + def __init__(self, config: str, autocommit=True): + self.config = config + self.autocommit = autocommit + + def __enter__(self): + + assert self.config.type in [ + "snowflake.registry", + "snowflake.offline", + "snowflake.engine", + "snowflake.online", + ] + + if self.config.type not in _cache: + if self.config.type == "snowflake.registry": + config_header = "connections.feast_registry" + elif self.config.type == "snowflake.offline": + config_header = "connections.feast_offline_store" + if self.config.type == "snowflake.engine": + config_header = "connections.feast_batch_engine" + elif self.config.type == "snowflake.online": + config_header = "connections.feast_online_store" + + config_dict = dict(self.config) + + # read config file + config_reader = configparser.ConfigParser() + config_reader.read([config_dict["config_path"]]) + kwargs: Dict[str, Any] = {} + if config_reader.has_section(config_header): + kwargs = dict(config_reader[config_header]) + + kwargs.update((k, v) for k, v in config_dict.items() if v is not None) + + for k, v in kwargs.items(): + if k in ["role", "warehouse", "database", "schema_"]: + kwargs[k] = f'"{v}"' + + kwargs["schema"] = kwargs.pop("schema_") + + # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation + # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication + if "private_key" in kwargs: + kwargs["private_key"] = parse_private_key_path( + kwargs["private_key"], kwargs["private_key_passphrase"] + ) + + try: + _cache[self.config.type] = snowflake.connector.connect( + application="feast", + client_session_keep_alive=True, + autocommit=self.autocommit, + **kwargs, + ) + _cache[self.config.type].cursor().execute( + "ALTER SESSION SET TIMEZONE = 'UTC'", _is_internal=True + ) + + except KeyError as e: + raise SnowflakeIncompleteConfig(e) + + self.client = _cache[self.config.type] + return self.client + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + def assert_snowflake_feature_names(feature_view: FeatureView) -> None: for feature in feature_view.features: @@ -57,61 +128,6 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu return cursor -def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: - assert config.type in [ - "snowflake.registry", - "snowflake.offline", - "snowflake.engine", - "snowflake.online", - ] - - if config.type == "snowflake.registry": - config_header = "connections.feast_registry" - elif config.type == "snowflake.offline": - config_header = "connections.feast_offline_store" - if config.type == "snowflake.engine": - config_header = "connections.feast_batch_engine" - elif config.type == "snowflake.online": - config_header = "connections.feast_online_store" - - config_dict = dict(config) - - # read config file - config_reader = configparser.ConfigParser() - config_reader.read([config_dict["config_path"]]) - kwargs: Dict[str, Any] = {} - if config_reader.has_section(config_header): - kwargs = dict(config_reader[config_header]) - - kwargs.update((k, v) for k, v in config_dict.items() if v is not None) - - for k, v in kwargs.items(): - if k in ["role", "warehouse", "database", "schema_"]: - kwargs[k] = f'"{v}"' - - kwargs["schema"] = kwargs.pop("schema_") - - # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation - # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication - if "private_key" in kwargs: - kwargs["private_key"] = parse_private_key_path( - kwargs["private_key"], kwargs["private_key_passphrase"] - ) - - try: - conn = snowflake.connector.connect( - application="feast", - autocommit=autocommit, - **kwargs, - ) - - conn.cursor().execute("ALTER SESSION SET TIMEZONE = 'UTC'", _is_internal=True) - - return conn - except KeyError as e: - raise SnowflakeIncompleteConfig(e) - - def get_snowflake_online_store_path( config: RepoConfig, feature_view: FeatureView, diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py index f0a09b4d5b..257e46df19 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py @@ -13,8 +13,8 @@ SnowflakeLoggingDestination, ) from feast.infra.utils.snowflake.snowflake_utils import ( + GetSnowflakeConnection, execute_snowflake_statement, - get_snowflake_conn, write_pandas, ) from feast.repo_config import FeastConfigBaseModel @@ -54,11 +54,10 @@ def create_data_source( field_mapping: Dict[str, str] = None, ) -> DataSource: - snowflake_conn = get_snowflake_conn(self.offline_store_config) - destination_name = self.get_prefixed_table_name(destination_name) - write_pandas(snowflake_conn, df, destination_name, auto_create_table=True) + with GetSnowflakeConnection(self.offline_store_config) as conn: + write_pandas(conn, df, destination_name, auto_create_table=True) self.tables.append(destination_name) @@ -93,7 +92,7 @@ def get_prefixed_table_name(self, suffix: str) -> str: return f"{self.project_name}_{suffix}" def teardown(self): - with get_snowflake_conn(self.offline_store_config) as conn: + with GetSnowflakeConnection(self.offline_store_config) as conn: for table in self.tables: query = f'DROP TABLE IF EXISTS "{table}"' execute_snowflake_statement(conn, query)