Skip to content

Commit

Permalink
Ensure safer profile mappings by removing __getattr__
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaneve committed Jul 27, 2023
1 parent 5352d69 commit c7e1da1
Show file tree
Hide file tree
Showing 15 changed files with 46 additions and 96 deletions.
42 changes: 18 additions & 24 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,23 @@ def can_claim_connection(self) -> bool:
if self.conn.conn_type != self.airflow_connection_type:
return False

logger.info(dir(self.conn))
logger.info(self.conn.__dict__)
generated_profile = self.profile

for field in self.required_fields:
try:
if not getattr(self, field):
# if it's a secret field, check if we can get it
if field in self.secret_fields:
if not self.get_dbt_value(field):
logger.info(
"1 Not using mapping %s because %s is not set",
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
return False
except AttributeError:

# otherwise, check if it's in the generated profile
if not generated_profile.get(field):
logger.info(
"2 Not using mapping %s because %s is not set",
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
Expand Down Expand Up @@ -159,23 +161,15 @@ def get_dbt_value(self, name: str) -> Any:
# otherwise, we don't have it - return None
return None

def __getattr__(self, name: str) -> Any:
"If the attribute doesn't exist, try to get it from profile_args or the Airflow connection."
logger.info(
"Couldn't find attribute %s on %s. Trying to get it from `profile_args` or the Airflow connection.",
name,
self.__class__.__name__,
)

# if it doesn't exist, try to get it from profile_args or the Airflow connection
attempted_value = self.get_dbt_value(name)
if attempted_value is not None:
return attempted_value

raise AttributeError(
f"{self.__class__.__name__} has no attribute {name}. If this is a dbt profile field, "
f"ensure it's set either in the profile_args or the Airflow connection."
)
@property
def mapped_params(self) -> dict[str, Any]:
"Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values."
mapped_params = {}

for dbt_field in self.airflow_param_mapping:
mapped_params[dbt_field] = self.get_dbt_value(dbt_field)

return mapped_params

@classmethod
def filter_null(cls, args: dict[str, Any]) -> dict[str, Any]:
Expand Down
6 changes: 2 additions & 4 deletions cosmos/profiles/bigquery/service_account_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@ class GoogleCloudServiceAccountFileProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Generates profile. Defaults `threads` to 1."
return {
**self.mapped_params,
"type": "bigquery",
"method": "service-account",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile": self.keyfile,
"threads": 1,
**self.profile_args,
}
12 changes: 5 additions & 7 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
required_fields = [
"project",
"dataset",
"keyfile_dict",
"keyfile_json",
]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
"dataset": ["extra.dataset", "dataset"],
"dataset": "extra.dataset",
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
"keyfile_json": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

@property
Expand All @@ -38,11 +38,9 @@ def profile(self) -> dict[str, Any | None]:
we generate a temporary file and the DBT profile uses it.
"""
return {
**self.mapped_params,
"type": "bigquery",
"method": "service-account-json",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
"threads": 1,
**self.profile_args,
}
4 changes: 1 addition & 3 deletions cosmos/profiles/databricks/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ class DatabricksTokenProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Generates profile. The token is stored in an environment variable."
return {
**self.mapped_params,
"type": "databricks",
"schema": self.schema,
"host": self.host,
"http_path": self.http_path,
**self.profile_args,
# token should always get set as env var
"token": self.get_env_var_format("token"),
Expand Down
11 changes: 1 addition & 10 deletions cosmos/profiles/exasol/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,8 @@ class ExasolUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile_vars = {
**self.mapped_params,
"type": "exasol",
"threads": self.threads,
"dsn": self.dsn,
"user": self.user,
"dbname": self.dbname,
"schema": self.schema,
"encryption": self.conn.extra_dejson.get("encryption"),
"compression": self.conn.extra_dejson.get("compression"),
"connect_timeout": self.conn.extra_dejson.get("connection_timeout"),
"socket_timeout": self.conn.extra_dejson.get("socket_timeout"),
"protocol_version": self.conn.extra_dejson.get("protocol_version"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
11 changes: 2 additions & 9 deletions cosmos/profiles/postgres/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping):
"""

airflow_connection_type: str = "postgres"
default_port = 5432

required_fields = [
"host",
"user",
"password",
"port",
"dbname",
"schema",
]
Expand All @@ -41,14 +39,9 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
**self.mapped_params,
"type": "postgres",
"host": self.conn.host,
"user": self.conn.login,
"port": self.conn.port or self.default_port,
"dbname": self.dbname,
"schema": self.schema,
"keepalives_idle": self.conn.extra_dejson.get("keepalives_idle"),
"sslmode": self.conn.extra_dejson.get("sslmode"),
"port": 5432,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
14 changes: 4 additions & 10 deletions cosmos/profiles/redshift/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping):
"""

airflow_connection_type: str = "redshift"
default_port = 5432

required_fields = [
"host",
Expand All @@ -41,17 +40,12 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile = {
**self.mapped_params,
"type": "redshift",
"host": self.host,
"user": self.user,
"password": self.get_env_var_format("password"),
"port": self.port or self.default_port,
"dbname": self.dbname,
"schema": self.schema,
"connection_timeout": self.conn.extra_dejson.get("timeout"),
"sslmode": self.conn.extra_dejson.get("sslmode"),
"region": self.conn.extra_dejson.get("region"),
"port": 5439,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(profile)
7 changes: 1 addition & 6 deletions cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,8 @@ def conn(self) -> Connection:
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "snowflake",
"account": self.account,
"user": self.user,
"schema": self.schema,
"database": self.database,
"role": self.conn.extra_dejson.get("role"),
"warehouse": self.conn.extra_dejson.get("warehouse"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
7 changes: 1 addition & 6 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,8 @@ def conn(self) -> Connection:
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "snowflake",
"account": self.account,
"user": self.user,
"schema": self.schema,
"database": self.database,
"role": self.conn.extra_dejson.get("role"),
"warehouse": self.conn.extra_dejson.get("warehouse"),
**self.profile_args,
# private_key should always get set as env var
"private_key_content": self.get_env_var_format("private_key_content"),
Expand Down
5 changes: 2 additions & 3 deletions cosmos/profiles/spark/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SparkThriftProfileMapping(BaseProfileMapping):

airflow_param_mapping = {
"host": "host",
"port": "port",
}

@property
Expand All @@ -31,11 +32,9 @@ def profile(self) -> dict[str, Any | None]:
Return a dbt Spark profile based on the Airflow Spark connection.
"""
profile_vars = {
**self.mapped_params,
"type": "spark",
"method": "thrift",
"schema": self.schema,
"host": self.host,
"port": self.conn.port,
**self.profile_args,
}

Expand Down
8 changes: 2 additions & 6 deletions cosmos/profiles/trino/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ class TrinoBaseProfileMapping(BaseProfileMapping):
}

@property
def profile(self) -> dict[str, Any | None]:
def profile(self) -> dict[str, Any]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "trino",
"host": self.host,
"port": self.port,
"database": self.database,
"schema": self.schema,
"session_properties": self.conn.extra_dejson.get("session_properties"),
**self.profile_args,
}

Expand Down
3 changes: 1 addition & 2 deletions cosmos/profiles/trino/certificate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ def profile(self) -> dict[str, Any | None]:
"Gets profile."
common_profile_vars = super().profile
profile_vars = {
**self.mapped_params,
**common_profile_vars,
"method": "certificate",
"client_certificate": self.client_certificate,
"client_private_key": self.client_private_key,
**self.profile_args,
}

Expand Down
5 changes: 3 additions & 2 deletions cosmos/profiles/trino/ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def profile(self) -> dict[str, Any | None]:
"""
common_profile_vars = super().profile
profile_vars = {
**self.mapped_params,
**common_profile_vars,
"method": "ldap",
"user": self.user,
"password": self.get_env_var_format("password"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

# remove any null values
Expand Down
6 changes: 3 additions & 3 deletions tests/profiles/exasol/test_exasol_user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:1000"
assert profile_mapping.get_dbt_value("dsn") == "my_host:1000"

# next, test with a host that doesn't include a port
conn = Connection(
Expand All @@ -207,7 +207,7 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:8563" # should default to 8563
assert profile_mapping.get_dbt_value("dsn") == "my_host:8563" # should default to 8563

# lastly, test with a port override
conn = Connection(
Expand All @@ -222,4 +222,4 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:1000"
assert profile_mapping.get_dbt_value("dsn") == "my_host:1000"
1 change: 0 additions & 1 deletion tests/profiles/postgres/test_pg_user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_connection_claiming() -> None:
"host": "my_host",
"login": "my_user",
"password": "my_password",
"port": 5432,
"schema": "my_database",
}

Expand Down

0 comments on commit c7e1da1

Please sign in to comment.