Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure safer profile mappings by removing __getattr__ #407

Merged
merged 1 commit into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 18 additions & 24 deletions cosmos/profiles/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,23 @@ def can_claim_connection(self) -> bool:
if self.conn.conn_type != self.airflow_connection_type:
return False

logger.info(dir(self.conn))
logger.info(self.conn.__dict__)
generated_profile = self.profile

for field in self.required_fields:
try:
if not getattr(self, field):
# if it's a secret field, check if we can get it
if field in self.secret_fields:
if not self.get_dbt_value(field):
logger.info(
"1 Not using mapping %s because %s is not set",
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
return False
except AttributeError:

# otherwise, check if it's in the generated profile
if not generated_profile.get(field):
logger.info(
"2 Not using mapping %s because %s is not set",
"Not using mapping %s because %s is not set",
self.__class__.__name__,
field,
)
Expand Down Expand Up @@ -159,23 +161,15 @@ def get_dbt_value(self, name: str) -> Any:
# otherwise, we don't have it - return None
return None

def __getattr__(self, name: str) -> Any:
"If the attribute doesn't exist, try to get it from profile_args or the Airflow connection."
logger.info(
"Couldn't find attribute %s on %s. Trying to get it from `profile_args` or the Airflow connection.",
name,
self.__class__.__name__,
)

# if it doesn't exist, try to get it from profile_args or the Airflow connection
attempted_value = self.get_dbt_value(name)
if attempted_value is not None:
return attempted_value

raise AttributeError(
f"{self.__class__.__name__} has no attribute {name}. If this is a dbt profile field, "
f"ensure it's set either in the profile_args or the Airflow connection."
)
@property
def mapped_params(self) -> dict[str, Any]:
"Turns the self.airflow_param_mapping into a dictionary of dbt fields and their values."
mapped_params = {}

for dbt_field in self.airflow_param_mapping:
mapped_params[dbt_field] = self.get_dbt_value(dbt_field)

return mapped_params

@classmethod
def filter_null(cls, args: dict[str, Any]) -> dict[str, Any]:
Expand Down
6 changes: 2 additions & 4 deletions cosmos/profiles/bigquery/service_account_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,9 @@ class GoogleCloudServiceAccountFileProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Generates profile. Defaults `threads` to 1."
return {
**self.mapped_params,
"type": "bigquery",
"method": "service-account",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile": self.keyfile,
"threads": 1,
**self.profile_args,
}
12 changes: 5 additions & 7 deletions cosmos/profiles/bigquery/service_account_keyfile_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ class GoogleCloudServiceAccountDictProfileMapping(BaseProfileMapping):
required_fields = [
"project",
"dataset",
"keyfile_dict",
"keyfile_json",
]

airflow_param_mapping = {
"project": "extra.project",
# multiple options for dataset because of older Airflow versions
"dataset": ["extra.dataset", "dataset"],
"dataset": "extra.dataset",
# multiple options for keyfile_dict param name because of older Airflow versions
"keyfile_dict": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
"keyfile_json": ["extra.keyfile_dict", "keyfile_dict", "extra__google_cloud_platform__keyfile_dict"],
}

@property
Expand All @@ -38,11 +38,9 @@ def profile(self) -> dict[str, Any | None]:
we generate a temporary file and the DBT profile uses it.
"""
return {
**self.mapped_params,
"type": "bigquery",
"method": "service-account-json",
"project": self.project,
"dataset": self.dataset,
"threads": self.profile_args.get("threads") or 1,
"keyfile_json": self.keyfile_dict,
"threads": 1,
**self.profile_args,
}
4 changes: 1 addition & 3 deletions cosmos/profiles/databricks/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ class DatabricksTokenProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Generates profile. The token is stored in an environment variable."
return {
**self.mapped_params,
"type": "databricks",
"schema": self.schema,
"host": self.host,
"http_path": self.http_path,
**self.profile_args,
# token should always get set as env var
"token": self.get_env_var_format("token"),
Expand Down
11 changes: 1 addition & 10 deletions cosmos/profiles/exasol/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,8 @@ class ExasolUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile_vars = {
**self.mapped_params,
"type": "exasol",
"threads": self.threads,
"dsn": self.dsn,
"user": self.user,
"dbname": self.dbname,
"schema": self.schema,
"encryption": self.conn.extra_dejson.get("encryption"),
"compression": self.conn.extra_dejson.get("compression"),
"connect_timeout": self.conn.extra_dejson.get("connection_timeout"),
"socket_timeout": self.conn.extra_dejson.get("socket_timeout"),
"protocol_version": self.conn.extra_dejson.get("protocol_version"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
11 changes: 2 additions & 9 deletions cosmos/profiles/postgres/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping):
"""

airflow_connection_type: str = "postgres"
default_port = 5432

required_fields = [
"host",
"user",
"password",
"port",
"dbname",
"schema",
]
Expand All @@ -41,14 +39,9 @@ class PostgresUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile. The password is stored in an environment variable."
profile = {
**self.mapped_params,
"type": "postgres",
"host": self.conn.host,
"user": self.conn.login,
"port": self.conn.port or self.default_port,
"dbname": self.dbname,
"schema": self.schema,
"keepalives_idle": self.conn.extra_dejson.get("keepalives_idle"),
"sslmode": self.conn.extra_dejson.get("sslmode"),
"port": 5432,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
14 changes: 4 additions & 10 deletions cosmos/profiles/redshift/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping):
"""

airflow_connection_type: str = "redshift"
default_port = 5432

required_fields = [
"host",
Expand All @@ -41,17 +40,12 @@ class RedshiftUserPasswordProfileMapping(BaseProfileMapping):
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile = {
**self.mapped_params,
"type": "redshift",
"host": self.host,
"user": self.user,
"password": self.get_env_var_format("password"),
"port": self.port or self.default_port,
"dbname": self.dbname,
"schema": self.schema,
"connection_timeout": self.conn.extra_dejson.get("timeout"),
"sslmode": self.conn.extra_dejson.get("sslmode"),
"region": self.conn.extra_dejson.get("region"),
"port": 5439,
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

return self.filter_null(profile)
7 changes: 1 addition & 6 deletions cosmos/profiles/snowflake/user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,8 @@ def conn(self) -> Connection:
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "snowflake",
"account": self.account,
"user": self.user,
"schema": self.schema,
"database": self.database,
"role": self.conn.extra_dejson.get("role"),
"warehouse": self.conn.extra_dejson.get("warehouse"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
Expand Down
7 changes: 1 addition & 6 deletions cosmos/profiles/snowflake/user_privatekey.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,8 @@ def conn(self) -> Connection:
def profile(self) -> dict[str, Any | None]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "snowflake",
"account": self.account,
"user": self.user,
"schema": self.schema,
"database": self.database,
"role": self.conn.extra_dejson.get("role"),
"warehouse": self.conn.extra_dejson.get("warehouse"),
**self.profile_args,
# private_key should always get set as env var
"private_key_content": self.get_env_var_format("private_key_content"),
Expand Down
5 changes: 2 additions & 3 deletions cosmos/profiles/spark/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SparkThriftProfileMapping(BaseProfileMapping):

airflow_param_mapping = {
"host": "host",
"port": "port",
}

@property
Expand All @@ -31,11 +32,9 @@ def profile(self) -> dict[str, Any | None]:
Return a dbt Spark profile based on the Airflow Spark connection.
"""
profile_vars = {
**self.mapped_params,
"type": "spark",
"method": "thrift",
"schema": self.schema,
"host": self.host,
"port": self.conn.port,
**self.profile_args,
}

Expand Down
8 changes: 2 additions & 6 deletions cosmos/profiles/trino/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,11 @@ class TrinoBaseProfileMapping(BaseProfileMapping):
}

@property
def profile(self) -> dict[str, Any | None]:
def profile(self) -> dict[str, Any]:
"Gets profile."
profile_vars = {
**self.mapped_params,
"type": "trino",
"host": self.host,
"port": self.port,
"database": self.database,
"schema": self.schema,
"session_properties": self.conn.extra_dejson.get("session_properties"),
**self.profile_args,
}

Expand Down
3 changes: 1 addition & 2 deletions cosmos/profiles/trino/certificate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ def profile(self) -> dict[str, Any | None]:
"Gets profile."
common_profile_vars = super().profile
profile_vars = {
**self.mapped_params,
**common_profile_vars,
"method": "certificate",
"client_certificate": self.client_certificate,
"client_private_key": self.client_private_key,
**self.profile_args,
}

Expand Down
5 changes: 3 additions & 2 deletions cosmos/profiles/trino/ldap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def profile(self) -> dict[str, Any | None]:
"""
common_profile_vars = super().profile
profile_vars = {
**self.mapped_params,
**common_profile_vars,
"method": "ldap",
"user": self.user,
"password": self.get_env_var_format("password"),
**self.profile_args,
# password should always get set as env var
"password": self.get_env_var_format("password"),
}

# remove any null values
Expand Down
6 changes: 3 additions & 3 deletions tests/profiles/exasol/test_exasol_user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:1000"
assert profile_mapping.get_dbt_value("dsn") == "my_host:1000"

# next, test with a host that doesn't include a port
conn = Connection(
Expand All @@ -207,7 +207,7 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:8563" # should default to 8563
assert profile_mapping.get_dbt_value("dsn") == "my_host:8563" # should default to 8563

# lastly, test with a port override
conn = Connection(
Expand All @@ -222,4 +222,4 @@ def test_dsn_formatting() -> None:

with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn):
profile_mapping = ExasolUserPasswordProfileMapping(conn, {"schema": "my_schema", "threads": 1})
assert profile_mapping.dsn == "my_host:1000"
assert profile_mapping.get_dbt_value("dsn") == "my_host:1000"
1 change: 0 additions & 1 deletion tests/profiles/postgres/test_pg_user_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_connection_claiming() -> None:
"host": "my_host",
"login": "my_user",
"password": "my_password",
"port": 5432,
"schema": "my_database",
}

Expand Down