Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added SnowflakeConnection caching #3531

Merged
merged 1 commit into from
Mar 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
_run_snowflake_field_mapping,
assert_snowflake_feature_names,
execute_snowflake_statement,
get_snowflake_conn,
get_snowflake_online_store_path,
package_snowpark_zip,
)
Expand Down Expand Up @@ -121,7 +121,7 @@ def update(
):
stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"'
stage_path = f'{stage_context}."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
query = f"SHOW STAGES IN {stage_context}"
cursor = execute_snowflake_statement(conn, query)
stage_list = pd.DataFrame(
Expand Down Expand Up @@ -173,7 +173,7 @@ def teardown_infra(
):

stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
query = f"DROP STAGE IF EXISTS {stage_path}"
execute_snowflake_statement(conn, query)

Expand Down Expand Up @@ -263,10 +263,11 @@ def _materialize_one(

# Let's check and see if we can skip this query, because the table hasn't changed
# since before the start date of this query
with get_snowflake_conn(self.repo_config.offline_store) as conn:
with GetSnowflakeConnection(self.repo_config.offline_store) as conn:
query = f"""SELECT SYSTEM$LAST_CHANGE_COMMIT_TIME('{feature_view.batch_source.get_table_query_string()}') AS last_commit_change_time"""
last_commit_change_time = (
conn.cursor().execute(query).fetchall()[0][0] / 1_000_000_000
execute_snowflake_statement(conn, query).fetchall()[0][0]
/ 1_000_000_000
)
if last_commit_change_time < start_date.astimezone(tz=utc).timestamp():
return SnowflakeMaterializationJob(
Expand Down Expand Up @@ -432,7 +433,7 @@ def materialize_to_snowflake_online_store(
)
"""

with get_snowflake_conn(repo_config.batch_engine) as conn:
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
query_id = execute_snowflake_statement(conn, query).sfqid

click.echo(
Expand All @@ -450,7 +451,7 @@ def materialize_to_external_online_store(

feature_names = [feature.name for feature in feature_view.features]

with get_snowflake_conn(repo_config.batch_engine) as conn:
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
query = materialization_sql
cursor = execute_snowflake_statement(conn, query)
for i, df in enumerate(cursor.fetch_pandas_batches()):
Expand Down
24 changes: 14 additions & 10 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
write_pandas,
write_parquet,
)
Expand All @@ -74,13 +74,13 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
"""Offline store config for Snowflake"""

type: Literal["snowflake.offline"] = "snowflake.offline"
""" Offline store type selector"""
""" Offline store type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake config path -- absolute path required (Can't use ~)"""

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """
Expand All @@ -89,7 +89,7 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name"""
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """
Expand Down Expand Up @@ -155,7 +155,8 @@ def pull_latest_from_table_or_query(
if data_source.snowflake_options.warehouse:
config.offline_store.warehouse = data_source.snowflake_options.warehouse

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

start_date = start_date.astimezone(tz=utc)
end_date = end_date.astimezone(tz=utc)
Expand Down Expand Up @@ -208,7 +209,8 @@ def pull_all_from_table_or_query(
if data_source.snowflake_options.warehouse:
config.offline_store.warehouse = data_source.snowflake_options.warehouse

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

start_date = start_date.astimezone(tz=utc)
end_date = end_date.astimezone(tz=utc)
Expand Down Expand Up @@ -241,7 +243,8 @@ def get_historical_features(
for fv in feature_views:
assert isinstance(fv.batch_source, SnowflakeSource)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

entity_schema = _get_entity_schema(entity_df, snowflake_conn, config)

Expand Down Expand Up @@ -319,7 +322,8 @@ def write_logged_features(
):
assert isinstance(logging_config.destination, SnowflakeLoggingDestination)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

if isinstance(data, Path):
write_parquet(
Expand Down Expand Up @@ -359,7 +363,8 @@ def offline_write_batch(
if table.schema != pa_schema:
table = table.cast(pa_schema)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

write_pandas(
snowflake_conn,
Expand Down Expand Up @@ -427,7 +432,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
).fetch_arrow_all()

if pa_table:

return pa_table
else:
empty_result = execute_snowflake_statement(self.snowflake_conn, query)
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,13 @@ def get_table_column_names_and_types(
"""
from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
)

assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)

with get_snowflake_conn(config.offline_store) as conn:
with GetSnowflakeConnection(config.offline_store) as conn:
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
cursor = execute_snowflake_statement(conn, query)

Expand Down Expand Up @@ -250,7 +250,7 @@ def get_table_column_names_and_types(
else:
column = row["column_name"]

with get_snowflake_conn(config.offline_store) as conn:
with GetSnowflakeConnection(config.offline_store) as conn:
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'
result = execute_snowflake_statement(
conn, query
Expand Down
16 changes: 8 additions & 8 deletions sdk/python/feast/infra/online_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
get_snowflake_online_store_path,
write_pandas_binary,
)
Expand All @@ -29,13 +29,13 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
"""Online store config for Snowflake"""

type: Literal["snowflake.online"] = "snowflake.online"
""" Online store type selector"""
""" Online store type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake config path -- absolute path required (Can't use ~)"""

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """
Expand All @@ -44,7 +44,7 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name"""
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """
Expand Down Expand Up @@ -114,7 +114,7 @@ def online_write_batch(

# This combines both the data upload plus the overwrite in the same transaction
online_path = get_snowflake_online_store_path(config, table)
with get_snowflake_conn(config.online_store, autocommit=False) as conn:
with GetSnowflakeConnection(config.online_store, autocommit=False) as conn:
write_pandas_binary(
conn,
agg_df,
Expand Down Expand Up @@ -178,7 +178,7 @@ def online_read(
)

online_path = get_snowflake_online_store_path(config, table)
with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
query = f"""
SELECT
"entity_key", "feature_name", "value", "event_ts"
Expand Down Expand Up @@ -220,7 +220,7 @@ def update(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables_to_keep:
online_path = get_snowflake_online_store_path(config, table)
query = f"""
Expand Down Expand Up @@ -248,7 +248,7 @@ def teardown(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables:
online_path = get_snowflake_online_store_path(config, table)
query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"'
Expand Down
28 changes: 14 additions & 14 deletions sdk/python/feast/infra/registry/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from feast.infra.registry import proto_registry_utils
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
)
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.project_metadata import ProjectMetadata
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
f'"{self.registry_config.database}"."{self.registry_config.schema_}"'
)

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql"
with open(sql_function_file, "r") as file:
sqlFile = file.read()
Expand Down Expand Up @@ -177,7 +177,7 @@ def _refresh_cached_registry_if_necessary(self):
self.refresh()

def teardown(self):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql"
with open(sql_function_file, "r") as file:
sqlFile = file.read()
Expand Down Expand Up @@ -284,7 +284,7 @@ def _apply_object(
if hasattr(obj, "last_updated_timestamp"):
obj.last_updated_timestamp = update_datetime

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down Expand Up @@ -405,7 +405,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
DELETE FROM {self.registry_path}."{table}"
WHERE
Expand Down Expand Up @@ -616,7 +616,7 @@ def _get_object(
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
{proto_field_name}
Expand Down Expand Up @@ -776,7 +776,7 @@ def _list_objects(
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
{proto_field_name}
Expand Down Expand Up @@ -839,7 +839,7 @@ def list_project_metadata(
return proto_registry_utils.list_project_metadata(
self.cached_registry_proto, project
)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_key,
Expand Down Expand Up @@ -869,7 +869,7 @@ def apply_user_metadata(
):
fv_table_str = self._infer_fv_table(feature_view)
fv_column_name = fv_table_str[:-1].lower()
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down Expand Up @@ -905,7 +905,7 @@ def get_user_metadata(
) -> Optional[bytes]:
fv_table_str = self._infer_fv_table(feature_view)
fv_column_name = fv_table_str[:-1].lower()
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
user_metadata
Expand Down Expand Up @@ -971,7 +971,7 @@ def _get_all_projects(self) -> Set[str]:
"STREAM_FEATURE_VIEWS",
]

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
for table in base_tables:
query = (
f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"'
Expand All @@ -984,7 +984,7 @@ def _get_all_projects(self) -> Set[str]:
return projects

def _get_last_updated_metadata(self, project: str):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_value
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def _infer_fv_table(self, feature_view) -> str:
return table

def _maybe_init_project_metadata(self, project):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_value
Expand All @@ -1056,7 +1056,7 @@ def _maybe_init_project_metadata(self, project):
usage.set_current_project_uuid(new_project_uuid)

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down
Loading