Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding namespace support for Snowflake #5486

Merged
merged 11 commits into from
Nov 14, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ class KeyfileCreds(BaseModel):

type: Optional[str] = None
project_id: str = Field(title="Project ID")
private_key_id: Optional[str] = Field(default=None, title="Private Key ID")
private_key_id: Optional[str] = Field(default=None, title="Private key ID")
private_key: Optional[str] = Field(
default=None, json_schema_extra={"sensitive": True}
default=None, title="Private key", json_schema_extra={"sensitive": True}
)
client_email: Optional[EmailStr] = None
client_email: Optional[EmailStr] = Field(None, title="Client email")
client_id: Optional[str] = Field(default=None, title="Client ID")
auth_uri: Optional[str] = Field(default=None, title="Auth URI")
token_uri: Optional[str] = Field(default=None, title="Token URI")
auth_provider_x509_cert_url: Optional[str] = Field(
default=None, title="Auth Provider X509 Cert URL"
default=None, title="Auth provider X509 cert URL"
)
client_x509_cert_url: Optional[str] = Field(
default=None, title="Client X509 Cert URL"
default=None, title="Client X509 cert URL"
)


Expand All @@ -42,8 +42,8 @@ class BigQuerySchema(ConnectionConfigSecretsSchema):
)
dataset: Optional[str] = Field(
default=None,
title="Default dataset",
description="The default BigQuery dataset that will be used if one isn't provided in the associated Fides datasets.",
title="Dataset",
description="Only provide a dataset to scope discovery monitors and privacy request automation to a specific BigQuery dataset. In most cases, this can be left blank.",
)

_required_components: ClassVar[List[str]] = ["keyfile_creds"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,15 @@ class SnowflakeSchema(ConnectionConfigSecretsSchema):
title="Warehouse",
description="The name of the Snowflake warehouse where your queries will be executed.",
)
database_name: str = Field(
database_name: Optional[str] = Field(
default=None,
title="Database",
description="The name of the Snowflake database you want to connect to.",
description="Only provide a database name to scope discovery monitors and privacy request automation to a specific database. In most cases, this can be left blank.",
)
schema_name: str = Field(
schema_name: Optional[str] = Field(
default=None,
title="Schema",
description="The name of the Snowflake schema within the selected database.",
description="Only provide a schema to scope discovery monitors and privacy request automation to a specific schema. In most cases, this can be left blank.",
)
role_name: Optional[str] = Field(
title="Role",
Expand All @@ -67,8 +69,6 @@ class SnowflakeSchema(ConnectionConfigSecretsSchema):
"account_identifier",
"user_login_name",
"warehouse_name",
"database_name",
"schema_name",
]

@model_validator(mode="after")
Expand Down
17 changes: 17 additions & 0 deletions src/fides/api/schemas/namespace_meta/snowflake_namespace_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Literal

from fides.api.schemas.namespace_meta.namespace_meta import NamespaceMeta


class SnowflakeNamespaceMeta(NamespaceMeta):
    """
    Namespace metadata used to qualify table names in Snowflake queries.

    Attributes:
        connection_type: Fixed to the literal "snowflake".
        database_name: Name of the Snowflake database.
        schema: Name of the schema inside that database.
    """

    # Literal-typed with a default, so the only accepted value is "snowflake".
    connection_type: Literal["snowflake"] = "snowflake"

    # Both parts of the namespace are required strings.
    database_name: str
    # NOTE(review): the ignore presumably silences a clash with a `schema`
    # attribute inherited from the base model -- confirm before removing.
    schema: str  # type: ignore[assignment]
32 changes: 29 additions & 3 deletions src/fides/api/service/connectors/query_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
BigQueryNamespaceMeta,
)
from fides.api.schemas.namespace_meta.namespace_meta import NamespaceMeta
from fides.api.schemas.namespace_meta.snowflake_namespace_meta import (
SnowflakeNamespaceMeta,
)
from fides.api.schemas.policy import ActionType
from fides.api.service.masking.strategy.masking_strategy import MaskingStrategy
from fides.api.service.masking.strategy.masking_strategy_nullify import (
Expand Down Expand Up @@ -775,6 +778,8 @@ class MicrosoftSQLServerQueryConfig(QueryStringWithoutTuplesOverrideQueryConfig)
class SnowflakeQueryConfig(SQLQueryConfig):
"""Generates SQL in Snowflake's custom dialect."""

namespace_meta_schema = SnowflakeNamespaceMeta

def generate_raw_query(
self, field_list: List[str], filters: Dict[str, List[Any]]
) -> Optional[TextClause]:
Expand All @@ -791,13 +796,34 @@ def format_clause_for_query(
"""Returns field names in clauses surrounded by quotation marks as required by Snowflake syntax."""
return f'"{string_path}" {operator} (:{operand})'

def _generate_table_name(self) -> str:
    """
    Build the double-quoted table reference used in generated Snowflake SQL.

    Without namespace meta the result is just the quoted collection name.
    With namespace meta the quoted schema is prepended, and the quoted
    database name is prepended in front of that when it is non-empty.
    """
    # The collection name is quoted in every case.
    quoted_table = f'"{self.node.collection.name}"'

    if self.namespace_meta:
        meta = cast(SnowflakeNamespaceMeta, self.namespace_meta)
        parts = [f'"{meta.schema}"', quoted_table]

        database = meta.database_name
        if database:
            # Fully qualified form: "database"."schema"."table"
            parts.insert(0, f'"{database}"')

        return ".".join(parts)

    return quoted_table

def get_formatted_query_string(
self,
field_list: str,
clauses: List[str],
) -> str:
"""Returns a query string with double quotation mark formatting as required by Snowflake syntax."""
return f'SELECT {field_list} FROM "{self.node.collection.name}" WHERE ({" OR ".join(clauses)})'
return f'SELECT {field_list} FROM {self._generate_table_name()} WHERE ({" OR ".join(clauses)})'

def format_key_map_for_update_stmt(self, fields: List[str]) -> List[str]:
"""Adds the appropriate formatting for update statements in this datastore."""
Expand All @@ -809,8 +835,8 @@ def get_update_stmt(
update_clauses: List[str],
pk_clauses: List[str],
) -> str:
"""Returns a parameterised update statement in Snowflake dialect."""
return f'UPDATE "{self.node.address.collection}" SET {",".join(update_clauses)} WHERE {" AND ".join(pk_clauses)}'
"""Returns a parameterized update statement in Snowflake dialect."""
return f'UPDATE {self._generate_table_name()} SET {", ".join(update_clauses)} WHERE {" AND ".join(pk_clauses)}'


class RedshiftQueryConfig(SQLQueryConfig):
Expand Down
6 changes: 5 additions & 1 deletion src/fides/api/service/connectors/sql_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,11 @@ def get_connect_args(self) -> Dict[str, Any]:

def query_config(self, node: ExecutionNode) -> SQLQueryConfig:
"""Query wrapper corresponding to the input execution_node."""
return SnowflakeQueryConfig(node)

db: Session = Session.object_session(self.configuration)
return SnowflakeQueryConfig(
node, SQLConnector.get_namespace_meta(db, node.address.dataset)
)


class MicrosoftSQLServerConnector(SQLConnector):
Expand Down
68 changes: 68 additions & 0 deletions tests/fixtures/snowflake_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,45 @@ def snowflake_connection_config(
connection_config.delete(db)


@pytest.fixture(scope="function")
def snowflake_connection_config_without_default_dataset_or_schema(
    db: Session,
    integration_config: Dict[str, str],
    snowflake_connection_config_without_secrets: ConnectionConfig,
) -> Generator:
    """
    Yields a Snowflake ConnectionConfig whose secrets carry no database or schema.

    Secrets are attached only when the account identifier, login name, password
    and warehouse name can all be resolved, each taken from the "snowflake"
    section of the integration config or, failing that, from its
    SNOWFLAKE_TEST_* environment variable. The connection config is deleted
    after the consuming test has finished.
    """
    connection_config = snowflake_connection_config_without_secrets

    # Secret name -> environment variable consulted when the integration
    # config has no (truthy) value for it.
    env_fallbacks = {
        "account_identifier": "SNOWFLAKE_TEST_ACCOUNT_IDENTIFIER",
        "user_login_name": "SNOWFLAKE_TEST_USER_LOGIN_NAME",
        "password": "SNOWFLAKE_TEST_PASSWORD",
        "warehouse_name": "SNOWFLAKE_TEST_WAREHOUSE_NAME",
    }
    credentials = {
        secret_name: integration_config.get("snowflake", {}).get(secret_name)
        or os.environ.get(env_var)
        for secret_name, env_var in env_fallbacks.items()
    }

    if all(credentials.values()):
        # database_name and schema_name are deliberately not supplied.
        secrets = SnowflakeSchema(**credentials)
        connection_config.secrets = secrets.model_dump(mode="json")
        connection_config.save(db=db)

    yield connection_config
    connection_config.delete(db)


@pytest.fixture(scope="function")
def snowflake_connection_config_with_keypair(
db: Session,
Expand Down Expand Up @@ -190,3 +229,32 @@ def snowflake_example_test_dataset_config(
yield dataset_config
dataset_config.delete(db=db)
ctl_dataset.delete(db=db)


@pytest.fixture
def snowflake_example_test_dataset_config_with_namespace_meta(
    snowflake_connection_config_without_default_dataset_or_schema: ConnectionConfig,
    db: Session,
    example_datasets: List[Dict],
) -> Generator:
    """
    Yields a DatasetConfig whose dataset declares a Snowflake namespace.

    The dataset at index 2 of ``example_datasets`` is given a ``fides_meta``
    namespace (database "FIDESOPS_TEST", schema "TEST") and linked to the
    connection config that has no database or schema in its secrets. Both the
    DatasetConfig and the backing CtlDataset are deleted after the test.
    """
    connection_config = snowflake_connection_config_without_default_dataset_or_schema

    # Build a new dict instead of assigning into example_datasets[2]: that
    # entry belongs to the `example_datasets` fixture, and mutating it would
    # expose the injected namespace to every other consumer of the same list.
    # NOTE(review): index 2 is presumably the Snowflake example dataset --
    # confirm against the `example_datasets` fixture.
    dataset = {
        **example_datasets[2],
        "fides_meta": {
            "namespace": {"database_name": "FIDESOPS_TEST", "schema": "TEST"}
        },
    }
    fides_key = dataset["fides_key"]

    ctl_dataset = CtlDataset.create_from_dataset_dict(db, dataset)

    dataset_config = DatasetConfig.create(
        db=db,
        data={
            "connection_config_id": connection_config.id,
            "fides_key": fides_key,
            "ctl_dataset_id": ctl_dataset.id,
        },
    )
    yield dataset_config
    dataset_config.delete(db=db)
    ctl_dataset.delete(db=db)
40 changes: 17 additions & 23 deletions tests/ops/api/v1/endpoints/test_connection_template_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,8 @@ def test_get_connection_secret_schema_bigquery(
"allOf": [{"$ref": "#/definitions/KeyfileCreds"}],
},
"dataset": {
"title": "Default dataset",
"description": "The default BigQuery dataset that will be used if one isn't provided in the associated Fides datasets.",
"title": "Dataset",
"description": "Only provide a dataset to scope discovery monitors and privacy request automation to a specific BigQuery dataset. In most cases, this can be left blank.",
"type": "string",
},
},
Expand All @@ -796,28 +796,28 @@ def test_get_connection_secret_schema_bigquery(
"type": {"title": "Type", "type": "string"},
"project_id": {"title": "Project ID", "type": "string"},
"private_key_id": {
"title": "Private Key ID",
"title": "Private key ID",
"type": "string",
},
"private_key": {
"title": "Private Key",
"title": "Private key",
"sensitive": True,
"type": "string",
},
"client_email": {
"title": "Client Email",
"title": "Client email",
"type": "string",
"format": "email",
},
"client_id": {"title": "Client ID", "type": "string"},
"auth_uri": {"title": "Auth URI", "type": "string"},
"token_uri": {"title": "Token URI", "type": "string"},
"auth_provider_x509_cert_url": {
"title": "Auth Provider X509 Cert URL",
"title": "Auth provider X509 cert URL",
"type": "string",
},
"client_x509_cert_url": {
"title": "Client X509 Cert URL",
"title": "Client X509 cert URL",
"type": "string",
},
},
Expand Down Expand Up @@ -1477,24 +1477,22 @@ def test_get_connection_secret_schema_snowflake(
base_url.format(connection_type="snowflake"), headers=auth_header
)
assert resp.json() == {
"title": "SnowflakeSchema",
"description": "Schema to validate the secrets needed to connect to Snowflake",
"type": "object",
"properties": {
"account_identifier": {
"title": "Account Name",
"description": "The unique identifier for your Snowflake account.",
"title": "Account Name",
"type": "string",
},
"user_login_name": {
"title": "Username",
"description": "The user account used to authenticate and access the database.",
"title": "Username",
"type": "string",
},
"password": {
"title": "Password",
"description": "The password used to authenticate and access the database. You can use a password or a private key, but not both.",
"sensitive": True,
"title": "Password",
"type": "string",
},
"private_key": {
Expand All @@ -1510,33 +1508,29 @@ def test_get_connection_secret_schema_snowflake(
"type": "string",
},
"warehouse_name": {
"title": "Warehouse",
"description": "The name of the Snowflake warehouse where your queries will be executed.",
"title": "Warehouse",
"type": "string",
},
"database_name": {
"description": "Only provide a database name to scope discovery monitors and privacy request automation to a specific database. In most cases, this can be left blank.",
"title": "Database",
"description": "The name of the Snowflake database you want to connect to.",
"type": "string",
},
"schema_name": {
"description": "Only provide a schema to scope discovery monitors and privacy request automation to a specific schema. In most cases, this can be left blank.",
"title": "Schema",
"description": "The name of the Snowflake schema within the selected database.",
"type": "string",
},
"role_name": {
"title": "Role",
"description": "The Snowflake role to assume for the session, if different than Username.",
"title": "Role",
"type": "string",
},
},
"required": [
"account_identifier",
"user_login_name",
"warehouse_name",
"database_name",
"schema_name",
],
"required": ["account_identifier", "user_login_name", "warehouse_name"],
"title": "SnowflakeSchema",
"type": "object",
}

def test_get_connection_secret_schema_hubspot(
Expand Down
5 changes: 4 additions & 1 deletion tests/ops/service/connectors/test_queryconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,10 @@ class NewSQLQueryConfig(SQLQueryConfig):
pass

with pytest.raises(MissingNamespaceSchemaException) as exc:
NewSQLQueryConfig(payment_card_node, NewSQLNamespaceMeta(schema="public"))
NewSQLQueryConfig(
payment_card_node,
NewSQLNamespaceMeta(schema="public"),
)
assert (
"NewSQLQueryConfig must define a namespace_meta_schema when namespace_meta is provided."
in str(exc)
Expand Down
Loading
Loading