diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0ea06541b..baf2cd57c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: - id: end-of-file-fixer - id: mixed-line-ending - id: pretty-format-json - args: ['--autofix'] + args: ["--autofix"] - id: trailing-whitespace - repo: https://github.com/codespell-project/codespell rev: v2.2.6 @@ -54,7 +54,7 @@ repos: - --py37-plus - --keep-runtime-typing - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.6 + rev: v0.1.7 hooks: - id: ruff args: @@ -63,7 +63,7 @@ repos: rev: 23.11.0 hooks: - id: black - args: [ "--config", "./pyproject.toml" ] + args: ["--config", "./pyproject.toml"] - repo: https://github.com/asottile/blacken-docs rev: 1.16.0 hooks: @@ -71,22 +71,26 @@ repos: alias: black additional_dependencies: [black>=22.10.0] - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.7.1' + rev: "v1.7.1" + hooks: - id: mypy name: mypy-python - additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil, apache-airflow] + args: [--config-file, "./pyproject.toml"] + additional_dependencies: + [ + types-PyYAML, + types-attrs, + attrs, + types-requests, + types-python-dateutil, + apache-airflow, + ] files: ^cosmos - - repo: https://github.com/pycqa/flake8 - rev: 6.1.0 - hooks: - - id: flake8 - entry: pflake8 - additional_dependencies: [pyproject-flake8] ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate skip: - - mypy # build of https://github.com/pre-commit/mirrors-mypy:types-PyYAML,types-attrs,attrs,types-requests, + - mypy # build of https://github.com/pre-commit/mirrors-mypy:types-PyYAML,types-attrs,attrs,types-requests, #types-python-dateutil,apache-airflow@v1.5.0 for python@python3 exceeds tier max size 250MiB: 262.6MiB diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ecade7112..252875cfc 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,7 +1,7 @@ Changelog ========= -1.3.0a2 (2023-11-23) +1.3.0a3 (2023-12-07) -------------------- Features @@ -10,6 +10,24 @@ Features * Add ``ProfileMapping`` for Snowflake encrypted private key path by @ivanstillfront in #608 * Add support for Snowflake encrypted private key environment variable by @DanMawdsleyBA in #649 * Add ``DbtDocsGCSOperator`` for uploading dbt docs to GCS by @jbandoro in #616 +* Add support to select using (some) graph operators when using ``LoadMode.CUSTOM`` and ``LoadMode.DBT_MANIFEST`` by @tatiana in #728 +* Add cosmos/propagate_logs Airflow config support for disabling log pr… by @agreenburg in #648 +* Add operator_args ``full_refresh`` as a templated field by @joppevos in #623 +* Expose environment variables and dbt variables in ``ProjectConfig`` by @jbandoro in #735 + +Enhancements + +* Make Pydantic an optional dependency by @pixie79 in #736 +* Create a symbolic link to ``dbt_packages`` when ``dbt_deps`` is False when using ``LoadMode.DBT_LS`` by @DanMawdsleyBA in #730 +* Support no ``profile_config`` for ``ExecutionMode.KUBERNETES`` and ``ExecutionMode.DOCKER`` by @MrBones757 and @tatiana in #681 and #731 +* Add ``aws_session_token`` for Athena mapping by @benjamin-awd in #663 + +Others + +* Replace flake8 for Ruff by @joppevos in #743 +* Reduce code complexity to 8 by @joppevos in #738 +* Update conflict matrix between Airflow and dbt versions by @tatiana in #731 +* Speed up integration tests by @jbandoro in #732 1.2.5 (2023-11-23) @@ -46,14 +64,13 @@ Others * Docs: add 
execution config to MWAA code example by @ugmuka in #674 * Docs: highlight DAG examples in docs by @iancmoritz and @jlaneve in #695 + 1.2.3 (2023-11-09) ------------------ -Features +Bug fix -* Add ``ProfileMapping`` for Vertica by @perttus in #540 -* Add ``ProfileMapping`` for Snowflake encrypted private key path by @ivanstillfront in #608 -* Add ``DbtDocsGCSOperator`` for uploading dbt docs to GCS by @jbandoro in #616 +* Fix reusing config across TaskGroups/DAGs by @tatiana in #664 1.2.2 (2023-11-06) diff --git a/cosmos/__init__.py b/cosmos/__init__.py index f1f204634..2d3c2f6ac 100644 --- a/cosmos/__init__.py +++ b/cosmos/__init__.py @@ -5,7 +5,7 @@ Contains dags, task groups, and operators. """ -__version__ = "1.3.0a2" +__version__ = "1.3.0a3" from cosmos.airflow.dag import DbtDag diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index af854d4f5..615c2a124 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -18,6 +18,7 @@ from cosmos.core.graph.entities import Task as TaskMetadata from cosmos.dbt.graph import DbtNode from cosmos.log import get_logger +from typing import Union logger = get_logger(__name__) @@ -271,7 +272,17 @@ def build_airflow_graph( for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task - # Create the Airflow task dependencies between non-test nodes + create_airflow_task_dependencies(nodes, tasks_map) + + +def create_airflow_task_dependencies( + nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]] +) -> None: + """ + Create the Airflow task dependencies between non-test nodes. + :param nodes: Dictionary mapping dbt nodes (node.unique_id to node) + :param tasks_map: Dictionary mapping dbt nodes (node.unique_id to Airflow task) + """ for node_id, node in nodes.items(): for parent_node_id in node.depends_on: # depending on the node type, it will not have mapped 1:1 to tasks_map diff --git a/cosmos/config.py b/cosmos/config.py index 40756d2bb..3b332931f 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -7,6 +7,7 @@ import tempfile from dataclasses import InitVar, dataclass, field from pathlib import Path +import warnings from typing import Any, Iterator, Callable from cosmos.constants import DbtResourceType, TestBehavior, ExecutionMode, LoadMode, TestIndirectSelection @@ -39,10 +40,11 @@ class RenderConfig: :param load_method: The parsing method for loading the dbt model. Defaults to AUTOMATIC :param select: A list of dbt select arguments (e.g. 'config.materialized:incremental') :param exclude: A list of dbt exclude arguments (e.g. 'tag:nightly') + :param selector: Name of a dbt YAML selector to use for parsing. Only supported when using ``load_method=LoadMode.DBT_LS``. :param dbt_deps: Configure to run dbt deps when using dbt ls for dag parsing :param node_converters: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. :param dbt_executable_path: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. - :param env_vars: A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``. + :param env_vars: (Deprecated since Cosmos 1.3 use ProjectConfig.env_vars) A dictionary of environment variables for rendering. Only supported when using ``LoadMode.DBT_LS``. :param dbt_project_path Configures the DBT project location accessible on the airflow controller for DAG rendering. 
Mutually Exclusive with ProjectConfig.dbt_project_path. Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM``. """ @@ -51,15 +53,21 @@ class RenderConfig: load_method: LoadMode = LoadMode.AUTOMATIC select: list[str] = field(default_factory=list) exclude: list[str] = field(default_factory=list) + selector: str | None = None dbt_deps: bool = True node_converters: dict[DbtResourceType, Callable[..., Any]] | None = None dbt_executable_path: str | Path = get_system_dbt() - env_vars: dict[str, str] = field(default_factory=dict) + env_vars: dict[str, str] | None = None dbt_project_path: InitVar[str | Path | None] = None project_path: Path | None = field(init=False) def __post_init__(self, dbt_project_path: str | Path | None) -> None: + if self.env_vars: + warnings.warn( + "RenderConfig.env_vars is deprecated since Cosmos 1.3 and will be removed in Cosmos 2.0. Use ProjectConfig.env_vars instead.", + DeprecationWarning, + ) self.project_path = Path(dbt_project_path) if dbt_project_path else None def validate_dbt_command(self, fallback_cmd: str | Path = "") -> None: @@ -96,6 +104,11 @@ class ProjectConfig: :param manifest_path: The absolute path to the dbt manifest file. Defaults to None :param project_name: Allows the user to define the project name. Required if dbt_project_path is not defined. Defaults to the folder name of dbt_project_path. + :param env_vars: Dictionary of environment variables that are used for both rendering and execution. Rendering with + env vars is only supported when using ``RenderConfig.LoadMode.DBT_LS`` load mode. + :param dbt_vars: Dictionary of dbt variables for the project. This argument overrides variables defined in your dbt_project.yml + file. The dictionary is dumped to a yaml string and passed to dbt commands as the --vars argument. Variables are only + supported for rendering when using ``RenderConfig.LoadMode.DBT_LS`` and ``RenderConfig.LoadMode.CUSTOM`` load mode. """ dbt_project_path: Path | None = None @@ -113,6 +126,8 @@ def __init__( snapshots_relative_path: str | Path = "snapshots", manifest_path: str | Path | None = None, project_name: str | None = None, + env_vars: dict[str, str] | None = None, + dbt_vars: dict[str, str] | None = None, ): # Since we allow dbt_project_path to be defined in ExecutionConfig and RenderConfig # dbt_project_path may not always be defined here. @@ -136,6 +151,9 @@ def __init__( if manifest_path: self.manifest_path = Path(manifest_path) + self.env_vars = env_vars + self.dbt_vars = dbt_vars + def validate_project(self) -> None: """ Validates necessary context is present for a project. 
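For context, the new ``ProjectConfig`` arguments introduced above are forwarded to dbt: ``env_vars`` is exported to the environment for both rendering and execution, and ``dbt_vars`` is dumped to a YAML string and passed to dbt commands as ``--vars``. A minimal sketch of the intended usage, assuming a local dbt project (the path and variable names below are illustrative, not part of this change):

    from cosmos import ProjectConfig

    project_config = ProjectConfig(
        dbt_project_path="/usr/local/airflow/dags/dbt/my_project",  # hypothetical project location
        env_vars={"DBT_TARGET_SCHEMA": "analytics"},  # exported for rendering (LoadMode.DBT_LS only) and execution
        dbt_vars={"run_date": "2023-12-07"},  # serialized to YAML and appended to dbt commands as --vars
    )

``RenderConfig.env_vars`` keeps working for now, but as the change above shows it emits a ``DeprecationWarning`` and is slated for removal in Cosmos 2.0.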
diff --git a/cosmos/converter.py b/cosmos/converter.py index 2142cc6e4..c2b31700b 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -3,9 +3,10 @@ from __future__ import annotations -import copy import inspect from typing import Any, Callable +import copy +from warnings import warn from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup @@ -21,6 +22,18 @@ logger = get_logger(__name__) +def migrate_to_new_interface( + execution_config: ExecutionConfig, project_config: ProjectConfig, render_config: RenderConfig +): + # We copy the configuration so the change does not affect other DAGs or TaskGroups + # that may reuse the same original configuration + render_config = copy.deepcopy(render_config) + execution_config = copy.deepcopy(execution_config) + render_config.project_path = project_config.dbt_project_path + execution_config.project_path = project_config.dbt_project_path + return execution_config, render_config + + def specific_kwargs(**kwargs: dict[str, Any]) -> dict[str, Any]: """ Extract kwargs specific to the cosmos.converter.DbtToAirflowConverter class initialization method. @@ -83,6 +96,88 @@ def validate_arguments( profile_config.validate_profiles_yml() +def validate_initial_user_config( + execution_config: ExecutionConfig, + profile_config: ProfileConfig | None, + project_config: ProjectConfig, + render_config: RenderConfig, + operator_args: dict[str, Any], +): + """ + Validates if the user set the fields as expected. + + :param execution_config: Configuration related to how to run dbt in Airflow tasks + :param profile_config: Configuration related to dbt database configuration (profile) + :param project_config: Configuration related to the overall dbt project + :param render_config: Configuration related to how to convert the dbt workflow into an Airflow DAG + :param operator_args: Arguments to pass to the underlying operators. + """ + if profile_config is None and execution_config.execution_mode not in ( + ExecutionMode.KUBERNETES, + ExecutionMode.DOCKER, + ): + raise CosmosValueError(f"The profile_config is mandatory when using {execution_config.execution_mode}") + + # Since we now support both project_config.dbt_project_path, render_config.project_path and execution_config.project_path + # We need to ensure that only one interface is being used. + if project_config.dbt_project_path and (render_config.project_path or execution_config.project_path): + raise CosmosValueError( + "ProjectConfig.dbt_project_path is mutually exclusive with RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path." + + "If using RenderConfig.dbt_project_path or ExecutionConfig.dbt_project_path, ProjectConfig.dbt_project_path should be None" + ) + + # Cosmos 2.0 will remove the ability to pass in operator_args with 'env' and 'vars' in place of ProjectConfig.env_vars and + # ProjectConfig.dbt_vars. + if "env" in operator_args: + warn( + "operator_args with 'env' is deprecated since Cosmos 1.3 and will be removed in Cosmos 2.0. Use ProjectConfig.env_vars instead.", + DeprecationWarning, + ) + if project_config.env_vars: + raise CosmosValueError( + "ProjectConfig.env_vars and operator_args with 'env' are mutually exclusive and only one can be used." + ) + if "vars" in operator_args: + warn( + "operator_args with 'vars' is deprecated since Cosmos 1.3 and will be removed in Cosmos 2.0. 
Use ProjectConfig.vars instead.", + DeprecationWarning, + ) + if project_config.dbt_vars: + raise CosmosValueError( + "ProjectConfig.dbt_vars and operator_args with 'vars' are mutually exclusive and only one can be used." + ) + # Cosmos 2.0 will remove the ability to pass RenderConfig.env_vars in place of ProjectConfig.env_vars, check that both are not set. + if project_config.env_vars and render_config.env_vars: + raise CosmosValueError( + "Both ProjectConfig.env_vars and RenderConfig.env_vars were provided. RenderConfig.env_vars is deprecated since Cosmos 1.3, " + "please use ProjectConfig.env_vars instead." + ) + + +def validate_adapted_user_config( + execution_config: ExecutionConfig | None, project_config: ProjectConfig, render_config: RenderConfig | None +): + """ + Validates if all the necessary fields required by Cosmos to render the DAG are set. + + :param execution_config: Configuration related to how to run dbt in Airflow tasks + :param project_config: Configuration related to the overall dbt project + :param render_config: Configuration related to how to convert the dbt workflow into an Airflow DAG + """ + # At this point, execution_config.project_path should always be non-null + if not execution_config.project_path: + raise CosmosValueError( + "ExecutionConfig.dbt_project_path is required for the execution of dbt tasks in all execution modes." + ) + + # We now have a guaranteed execution_config.project_path, but still need to process render_config.project_path + # We require render_config.project_path when we dont have a manifest + if not project_config.manifest_path and not render_config.project_path: + raise CosmosValueError( + "RenderConfig.dbt_project_path is required for rendering an airflow DAG from a DBT Graph if no manifest is provided." + ) + + class DbtToAirflowConverter: """ Logic common to build an Airflow DbtDag and DbtTaskGroup from a DBT project. @@ -101,7 +196,7 @@ class DbtToAirflowConverter: def __init__( self, project_config: ProjectConfig, - profile_config: ProfileConfig, + profile_config: ProfileConfig | None = None, execution_config: ExecutionConfig | None = None, render_config: RenderConfig | None = None, dag: DAG | None = None, @@ -113,44 +208,21 @@ def __init__( ) -> None: project_config.validate_project() - if not execution_config: - execution_config = ExecutionConfig() - if not render_config: - render_config = RenderConfig() + execution_config = execution_config or ExecutionConfig() + render_config = render_config or RenderConfig() + operator_args = operator_args or {} - # Since we now support both project_config.dbt_project_path, render_config.project_path and execution_config.project_path - # We need to ensure that only one interface is being used. - if project_config.dbt_project_path and (render_config.project_path or execution_config.project_path): - raise CosmosValueError( - "ProjectConfig.dbt_project_path is mutually exclusive with RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path." 
- + "If using RenderConfig.dbt_project_path or ExecutionConfig.dbt_project_path, ProjectConfig.dbt_project_path should be None" - ) + validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args) # If we are using the old interface, we should migrate it to the new interface # This is safe to do now since we have validated which config interface we're using if project_config.dbt_project_path: - # We copy the configuration so the change does not affect other DAGs or TaskGroups - # that may reuse the same original configuration - render_config = copy.deepcopy(render_config) - execution_config = copy.deepcopy(execution_config) - render_config.project_path = project_config.dbt_project_path - execution_config.project_path = project_config.dbt_project_path - - # At this point, execution_config.project_path should always be non-null - if not execution_config.project_path: - raise CosmosValueError( - "ExecutionConfig.dbt_project_path is required for the execution of dbt tasks in all execution modes." - ) + execution_config, render_config = migrate_to_new_interface(execution_config, project_config, render_config) - # We now have a guaranteed execution_config.project_path, but still need to process render_config.project_path - # We require render_config.project_path when we dont have a manifest - if not project_config.manifest_path and not render_config.project_path: - raise CosmosValueError( - "RenderConfig.dbt_project_path is required for rendering an airflow DAG from a DBT Graph if no manifest is provided." - ) + validate_adapted_user_config(execution_config, project_config, render_config) - if not operator_args: - operator_args = {} + env_vars = project_config.env_vars or operator_args.pop("env", None) + dbt_vars = project_config.dbt_vars or operator_args.pop("vars", None) # Previously, we were creating a cosmos.dbt.project.DbtProject # DbtProject has now been replaced with ProjectConfig directly @@ -167,7 +239,7 @@ def __init__( render_config=render_config, execution_config=execution_config, profile_config=profile_config, - operator_args=operator_args, + dbt_vars=dbt_vars, ) dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode) @@ -176,6 +248,8 @@ def __init__( "project_dir": execution_config.project_path, "profile_config": profile_config, "emit_datasets": render_config.emit_datasets, + "env": env_vars, + "vars": dbt_vars, } if execution_config.dbt_executable_path: task_args["dbt_executable_path"] = execution_config.dbt_executable_path diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index a890c137c..e943d9527 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -4,6 +4,7 @@ import json import os import tempfile +import yaml from dataclasses import dataclass, field from pathlib import Path from subprocess import PIPE, Popen @@ -134,13 +135,14 @@ def __init__( render_config: RenderConfig = RenderConfig(), execution_config: ExecutionConfig = ExecutionConfig(), profile_config: ProfileConfig | None = None, - operator_args: dict[str, Any] | None = None, + # dbt_vars only supported for LegacyDbtProject + dbt_vars: dict[str, str] | None = None, ): self.project = project self.render_config = render_config self.profile_config = profile_config self.execution_config = execution_config - self.operator_args = operator_args or {} + self.dbt_vars = dbt_vars or {} def load( self, @@ -190,6 +192,12 @@ def run_dbt_ls( if self.render_config.select: ls_command.extend(["--select", *self.render_config.select]) + 
if self.project.dbt_vars: + ls_command.extend(["--vars", yaml.dump(self.project.dbt_vars)]) + + if self.render_config.selector: + ls_command.extend(["--selector", self.render_config.selector]) + ls_command.extend(self.local_flags) stdout = run_command(ls_command, tmp_dir, env_vars) @@ -232,10 +240,10 @@ def load_via_dbt_ls(self) -> None: f"Content of the dbt project dir {self.render_config.project_path}: `{os.listdir(self.render_config.project_path)}`" ) tmpdir_path = Path(tmpdir) - create_symlinks(self.render_config.project_path, tmpdir_path) + create_symlinks(self.render_config.project_path, tmpdir_path, self.render_config.dbt_deps) with self.profile_config.ensure_profile(use_mock_values=True) as profile_values, environ( - self.render_config.env_vars + self.project.env_vars or self.render_config.env_vars or {} ): (profile_path, env_vars) = profile_values env = os.environ.copy() @@ -286,6 +294,11 @@ def load_via_custom_parser(self) -> None: """ logger.info("Trying to parse the dbt project `%s` using a custom Cosmos method...", self.project.project_name) + if self.render_config.selector: + raise CosmosLoadDbtException( + "RenderConfig.selector is not yet supported when loading dbt projects using the LoadMode.CUSTOM parser." + ) + if not self.render_config.project_path or not self.execution_config.project_path: raise CosmosLoadDbtException( "Unable to load dbt project without RenderConfig.dbt_project_path and ExecutionConfig.dbt_project_path" @@ -296,7 +309,7 @@ def load_via_custom_parser(self) -> None: dbt_root_path=self.render_config.project_path.parent.as_posix(), dbt_models_dir=self.project.models_path.stem if self.project.models_path else "models", dbt_seeds_dir=self.project.seeds_path.stem if self.project.seeds_path else "seeds", - operator_args=self.operator_args, + dbt_vars=self.dbt_vars, ) nodes = {} models = itertools.chain( @@ -344,6 +357,11 @@ def load_from_dbt_manifest(self) -> None: """ logger.info("Trying to parse the dbt project `%s` using a dbt manifest...", self.project.project_name) + if self.render_config.selector: + raise CosmosLoadDbtException( + "RenderConfig.selector is not yet supported when loading dbt projects using the LoadMode.DBT_MANIFEST parser." 
+ ) + if not self.project.is_manifest_available(): raise CosmosLoadDbtException(f"Unable to load manifest using {self.project.manifest_path}") diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py index 278b1a0f7..de506e02d 100644 --- a/cosmos/dbt/parser/project.py +++ b/cosmos/dbt/parser/project.py @@ -130,7 +130,7 @@ class DbtModel: name: str type: DbtModelType path: Path - operator_args: Dict[str, Any] = field(default_factory=dict) + dbt_vars: Dict[str, str] = field(default_factory=dict) config: DbtModelConfig = field(default_factory=DbtModelConfig) def __post_init__(self) -> None: @@ -141,7 +141,6 @@ def __post_init__(self) -> None: return config = DbtModelConfig() - self.var_args: Dict[str, Any] = self.operator_args.get("vars", {}) code = self.path.read_text() if self.type == DbtModelType.DBT_SNAPSHOT: @@ -203,7 +202,7 @@ def _parse_jinja_ref_node(self, base_node: jinja2.nodes.Call) -> str | None: and isinstance(node.args[0], jinja2.nodes.Const) and node.node.name == "var" ): - value += self.var_args[node.args[0].value] + value += self.dbt_vars[node.args[0].value] # type: ignore elif isinstance(first_arg, jinja2.nodes.Const): # and add it to the config value = first_arg.value @@ -272,20 +271,16 @@ class LegacyDbtProject: snapshots_dir: Path = field(init=False) seeds_dir: Path = field(init=False) - operator_args: Dict[str, Any] = field(default_factory=dict) + dbt_vars: Dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: """ Initializes the parser. """ - if self.dbt_root_path is None: - self.dbt_root_path = "/usr/local/airflow/dags/dbt" - if self.dbt_models_dir is None: - self.dbt_models_dir = "models" - if self.dbt_snapshots_dir is None: - self.dbt_snapshots_dir = "snapshots" - if self.dbt_seeds_dir is None: - self.dbt_seeds_dir = "seeds" + self.dbt_root_path = self.dbt_root_path or "/usr/local/airflow/dags/dbt" + self.dbt_models_dir = self.dbt_models_dir or "models" + self.dbt_snapshots_dir = self.dbt_snapshots_dir or "snapshots" + self.dbt_seeds_dir = self.dbt_seeds_dir or "seeds" # set the project and model dirs self.project_dir = Path(os.path.join(self.dbt_root_path, self.project_name)) @@ -325,7 +320,7 @@ def _handle_csv_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_SEED, path=path, - operator_args=self.operator_args, + dbt_vars=self.dbt_vars, ) # add the model to the project self.seeds[model_name] = model @@ -343,7 +338,7 @@ def _handle_sql_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_MODEL, path=path, - operator_args=self.operator_args, + dbt_vars=self.dbt_vars, ) # add the model to the project self.models[model.name] = model @@ -353,7 +348,7 @@ def _handle_sql_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_SNAPSHOT, path=path, - operator_args=self.operator_args, + dbt_vars=self.dbt_vars, ) # add the snapshot to the project self.snapshots[model.name] = model @@ -414,7 +409,7 @@ def _extract_model_tests( name=f"{test}_{column['name']}_{model_name}", type=DbtModelType.DBT_TEST, path=path, - operator_args=self.operator_args, + dbt_vars=self.dbt_vars, config=DbtModelConfig(upstream_models=set({model_name})), ) tests[test_model.name] = test_model diff --git a/cosmos/dbt/project.py b/cosmos/dbt/project.py index 14b2f5e4b..aff6ed03e 100644 --- a/cosmos/dbt/project.py +++ b/cosmos/dbt/project.py @@ -10,9 +10,12 @@ from typing import Generator -def create_symlinks(project_path: Path, tmp_dir: Path) -> None: +def create_symlinks(project_path: Path, tmp_dir: Path, 
ignore_dbt_packages: bool) -> None: """Helper function to create symlinks to the dbt project files.""" - ignore_paths = (DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "dbt_packages", "profiles.yml") + ignore_paths = [DBT_LOG_DIR_NAME, DBT_TARGET_DIR_NAME, "profiles.yml"] + if ignore_dbt_packages: + # this is linked to dbt deps so if dbt deps is true then ignore existing dbt_packages folder + ignore_paths.append("dbt_packages") for child_name in os.listdir(project_path): if child_name not in ignore_paths: os.symlink(project_path / child_name, tmp_dir / child_name) diff --git a/cosmos/dbt/selector.py b/cosmos/dbt/selector.py index c7316dc75..76ec31a54 100644 --- a/cosmos/dbt/selector.py +++ b/cosmos/dbt/selector.py @@ -1,7 +1,9 @@ from __future__ import annotations -from pathlib import Path import copy - +import re +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any from cosmos.constants import DbtResourceType @@ -16,11 +18,154 @@ PATH_SELECTOR = "path:" TAG_SELECTOR = "tag:" CONFIG_SELECTOR = "config." - +PLUS_SELECTOR = "+" +GRAPH_SELECTOR_REGEX = r"^([0-9]*\+)?([^\+]+)(\+[0-9]*)?$|" logger = get_logger(__name__) +@dataclass +class GraphSelector: + """ + Implements dbt graph operator selectors: + model_a + +model_b + model_c+ + +model_d+ + 2+model_e + model_f+3 + + https://docs.getdbt.com/reference/node-selection/graph-operators + """ + + node_name: str + precursors: str | None + descendants: str | None + + @property + def precursors_depth(self) -> int: + """ + Calculates the depth/degrees/generations of precursors (parents). + Return: + -1: if it should return all the generations of precursors + 0: if it shouldn't return any precursors + >0: upperbound number of parent generations + """ + if not self.precursors: + return 0 + if self.precursors == "+": + return -1 + else: + return int(self.precursors[:-1]) + + @property + def descendants_depth(self) -> int: + """ + Calculates the depth/degrees/generations of descendants (children). + Return: + -1: if it should return all the generations of children + 0: if it shouldn't return any children + >0: upperbound of children generations + """ + if not self.descendants: + return 0 + if self.descendants == "+": + return -1 + else: + return int(self.descendants[1:]) + + @staticmethod + def parse(text: str) -> GraphSelector | None: + """ + Parse a string and identify if there are graph selectors, including the desired node name, descendants and + precursors. Return a GraphSelector instance if the pattern matches. + """ + regex_match = re.search(GRAPH_SELECTOR_REGEX, text) + if regex_match: + precursors, node_name, descendants = regex_match.groups() + return GraphSelector(node_name, precursors, descendants) + return None + + def select_node_precursors(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None: + """ + Parse original nodes and add the precursor nodes related to this config to the selected_nodes set. + + :param nodes: Original dbt nodes list + :param root_id: Unique identifier of self.node_name + :param selected_nodes: Set where precursor nodes will be added to. 
+ """ + if self.precursors: + depth = self.precursors_depth + previous_generation = {root_id} + processed_nodes = set() + while depth and previous_generation: + new_generation: set[str] = set() + for node_id in previous_generation: + if node_id not in processed_nodes: + new_generation.update(set(nodes[node_id].depends_on)) + processed_nodes.add(node_id) + selected_nodes.update(new_generation) + previous_generation = new_generation + depth -= 1 + + def select_node_descendants(self, nodes: dict[str, DbtNode], root_id: str, selected_nodes: set[str]) -> None: + """ + Parse original nodes and add the descendant nodes related to this config to the selected_nodes set. + + :param nodes: Original dbt nodes list + :param root_id: Unique identifier of self.node_name + :param selected_nodes: Set where descendant nodes will be added to. + """ + if self.descendants: + children_by_node = defaultdict(set) + # Index nodes by parent id + # We could optimize by doing this only once for the dbt project and giving it + # as a parameter to the GraphSelector + for node_id, node in nodes.items(): + for parent_id in node.depends_on: + children_by_node[parent_id].add(node_id) + + depth = self.descendants_depth + previous_generation = {root_id} + processed_nodes = set() + while depth and previous_generation: + new_generation: set[str] = set() + for node_id in previous_generation: + if node_id not in processed_nodes: + new_generation.update(children_by_node[node_id]) + processed_nodes.add(node_id) + selected_nodes.update(new_generation) + previous_generation = new_generation + depth -= 1 + + def filter_nodes(self, nodes: dict[str, DbtNode]) -> set[str]: + """ + Given a dictionary with the original dbt project nodes, applies the current graph selector to + identify the subset of nodes that matches the selection criteria. + + :param nodes: dbt project nodes + :return: set of node ids that matches current graph selector + """ + selected_nodes: set[str] = set() + + # Index nodes by name, we can improve performance by doing this once + # for multiple GraphSelectors + node_by_name = {} + for node_id, node in nodes.items(): + node_by_name[node.name] = node_id + + if self.node_name in node_by_name: + root_id = node_by_name[self.node_name] + else: + logger.warn(f"Selector {self.node_name} not found.") + return selected_nodes + + selected_nodes.add(root_id) + self.select_node_precursors(nodes, root_id, selected_nodes) + self.select_node_descendants(nodes, root_id, selected_nodes) + return selected_nodes + + class SelectorConfig: """ Represents a select/exclude statement. 
@@ -43,11 +188,12 @@ def __init__(self, project_dir: Path | None, statement: str): self.tags: list[str] = [] self.config: dict[str, str] = {} self.other: list[str] = [] + self.graph_selectors: list[GraphSelector] = [] self.load_from_statement(statement) @property def is_empty(self) -> bool: - return not (self.paths or self.tags or self.config or self.other) + return not (self.paths or self.tags or self.config or self.graph_selectors or self.other) def load_from_statement(self, statement: str) -> None: """ @@ -61,27 +207,45 @@ def load_from_statement(self, statement: str) -> None: https://docs.getdbt.com/reference/node-selection/yaml-selectors """ items = statement.split(",") + for item in items: if item.startswith(PATH_SELECTOR): - index = len(PATH_SELECTOR) - if self.project_dir: - self.paths.append(self.project_dir / Path(item[index:])) - else: - self.paths.append(Path(item[index:])) + self._parse_path_selector(item) elif item.startswith(TAG_SELECTOR): - index = len(TAG_SELECTOR) - self.tags.append(item[index:]) + self._parse_tag_selector(item) elif item.startswith(CONFIG_SELECTOR): - index = len(CONFIG_SELECTOR) - key, value = item[index:].split(":") - if key in SUPPORTED_CONFIG: - self.config[key] = value + self._parse_config_selector(item) + else: + self._parse_unknown_selector(item) + + def _parse_unknown_selector(self, item: str) -> None: + if item: + graph_selector = GraphSelector.parse(item) + if graph_selector is not None: + self.graph_selectors.append(graph_selector) else: self.other.append(item) logger.warning("Unsupported select statement: %s", item) + def _parse_config_selector(self, item: str) -> None: + index = len(CONFIG_SELECTOR) + key, value = item[index:].split(":") + if key in SUPPORTED_CONFIG: + self.config[key] = value + + def _parse_tag_selector(self, item: str) -> None: + index = len(TAG_SELECTOR) + self.tags.append(item[index:]) + + def _parse_path_selector(self, item: str) -> None: + index = len(PATH_SELECTOR) + if self.project_dir: + self.paths.append(self.project_dir / Path(item[index:])) + else: + self.paths.append(Path(item[index:])) + def __repr__(self) -> str: - return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other})" + return f"SelectorConfig(paths={self.paths}, tags={self.tags}, config={self.config}, other={self.other}, graph_selectors={self.graph_selectors})" class NodeSelector: @@ -95,7 +259,9 @@ class NodeSelector: def __init__(self, nodes: dict[str, DbtNode], config: SelectorConfig) -> None: self.nodes = nodes self.config = config + self.selected_nodes: set[str] = set() + @property def select_nodes_ids_by_intersection(self) -> set[str]: """ Return a list of node ids which matches the configuration defined in config. @@ -107,14 +273,19 @@ def select_nodes_ids_by_intersection(self) -> set[str]: if self.config.is_empty: return set(self.nodes.keys()) - self.selected_nodes: set[str] = set() + selected_nodes: set[str] = set() self.visited_nodes: set[str] = set() for node_id, node in self.nodes.items(): if self._should_include_node(node_id, node): - self.selected_nodes.add(node_id) + selected_nodes.add(node_id) + + if self.config.graph_selectors: + nodes_by_graph_selector = self.select_by_graph_operator() + selected_nodes = selected_nodes.intersection(nodes_by_graph_selector) - return self.selected_nodes + self.selected_nodes = selected_nodes + return selected_nodes def _should_include_node(self, node_id: str, node: DbtNode) -> bool: "Checks if a single node should be included. 
Only runs once per node with caching." @@ -175,6 +346,22 @@ def _is_path_matching(self, node: DbtNode) -> bool: return self._should_include_node(node.depends_on[0], model_node) return False + def select_by_graph_operator(self) -> set[str]: + """ + Return a list of node ids which match the configuration defined in the config. + + Return all nodes that are parents (or parents from parents) of the root defined in the configuration. + + References: + https://docs.getdbt.com/reference/node-selection/syntax + https://docs.getdbt.com/reference/node-selection/yaml-selectors + """ + selected_nodes_by_selector: list[set[str]] = [] + + for graph_selector in self.config.graph_selectors: + selected_nodes_by_selector.append(graph_selector.filter_nodes(self.nodes)) + return set.intersection(*selected_nodes_by_selector) + def retrieve_by_label(statement_list: list[str], label: str) -> set[str]: """ @@ -189,7 +376,7 @@ def retrieve_by_label(statement_list: list[str], label: str) -> set[str]: for statement in statement_list: config = SelectorConfig(Path(), statement) item_values = getattr(config, label) - label_values = label_values.union(item_values) + label_values.update(item_values) return label_values @@ -213,35 +400,53 @@ def select_nodes( if not select and not exclude: return nodes - # validates select and exclude filters - filters = [["select", select], ["exclude", exclude]] - for filter_type, filter in filters: - for filter_parameter in filter: - if filter_parameter.startswith(PATH_SELECTOR) or filter_parameter.startswith(TAG_SELECTOR): - continue - elif any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]): - continue - else: - raise CosmosValueError(f"Invalid {filter_type} filter: {filter_parameter}") + validate_filters(exclude, select) + subset_ids = apply_select_filter(nodes, project_dir, select) + if select: + nodes = get_nodes_from_subset(nodes, subset_ids) + exclude_ids = apply_exclude_filter(nodes, project_dir, exclude) + subset_ids = set(nodes.keys()) - exclude_ids - subset_ids: set[str] = set() + return get_nodes_from_subset(nodes, subset_ids) - for statement in select: - config = SelectorConfig(project_dir, statement) - node_selector = NodeSelector(nodes, config) - select_ids = node_selector.select_nodes_ids_by_intersection() - subset_ids = subset_ids.union(set(select_ids)) - if select: - nodes = {id_: nodes[id_] for id_ in subset_ids} +def get_nodes_from_subset(nodes: dict[str, DbtNode], subset_ids: set[str]) -> dict[str, DbtNode]: + nodes = {id_: nodes[id_] for id_ in subset_ids} + return nodes - nodes_ids = set(nodes.keys()) +def apply_exclude_filter(nodes: dict[str, DbtNode], project_dir: Path | None, exclude: list[str]) -> set[str]: exclude_ids: set[str] = set() for statement in exclude: config = SelectorConfig(project_dir, statement) node_selector = NodeSelector(nodes, config) - exclude_ids = exclude_ids.union(set(node_selector.select_nodes_ids_by_intersection())) - subset_ids = set(nodes_ids) - set(exclude_ids) + exclude_ids.update(node_selector.select_nodes_ids_by_intersection) + return exclude_ids - return {id_: nodes[id_] for id_ in subset_ids} + +def apply_select_filter(nodes: dict[str, DbtNode], project_dir: Path | None, select: list[str]) -> set[str]: + subset_ids: set[str] = set() + for statement in select: + config = SelectorConfig(project_dir, statement) + node_selector = NodeSelector(nodes, config) + select_ids = node_selector.select_nodes_ids_by_intersection + subset_ids.update(select_ids) + return subset_ids + + +def 
validate_filters(exclude: list[str], select: list[str]) -> None: + """ + Validate select and exclude filters. + """ + filters = [["select", select], ["exclude", exclude]] + for filter_type, filter in filters: + for filter_parameter in filter: + if ( + filter_parameter.startswith(PATH_SELECTOR) + or filter_parameter.startswith(TAG_SELECTOR) + or PLUS_SELECTOR in filter_parameter + or any([filter_parameter.startswith(CONFIG_SELECTOR + config + ":") for config in SUPPORTED_CONFIG]) + ): + continue + elif ":" in filter_parameter: + raise CosmosValueError(f"Invalid {filter_type} filter: {filter_parameter}") diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index af0988a6a..b844716de 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -56,10 +56,12 @@ def __init__(self, profile_config: ProfileConfig | None = None, **kwargs: Any) - def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: env_vars_dict: dict[str, str] = dict() + for env_var_key, env_var_value in env.items(): + env_vars_dict[env_var_key] = str(env_var_value) for env_var in self.env_vars: env_vars_dict[env_var.name] = env_var.value - self.env_vars: list[Any] = convert_env_vars({**env, **env_vars_dict}) + self.env_vars: list[Any] = convert_env_vars(env_vars_dict) def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: self.build_kube_args(context, cmd_flags) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 6eea764ad..b0b572430 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -199,7 +199,7 @@ def run_command( self.project_dir, ) - create_symlinks(Path(self.project_dir), Path(tmp_project_dir)) + create_symlinks(Path(self.project_dir), Path(tmp_project_dir), self.install_deps) with self.profile_config.ensure_profile() as profile_values: (profile_path, env_vars) = profile_values @@ -548,6 +548,15 @@ def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.base_cmd = ["docs", "generate"] + self.check_static_flag() + + def check_static_flag(self) -> None: + flag = "--static" + if self.dbt_cmd_flags: + if flag in self.dbt_cmd_flags: + # For the --static flag we only upload the generated static_index.html file + self.required_files = ["static_index.html"] + class DbtDocsCloudLocalOperator(DbtDocsLocalOperator, ABC): """ @@ -578,7 +587,7 @@ def upload_to_cloud_storage(self, project_dir: str) -> None: class DbtDocsS3LocalOperator(DbtDocsCloudLocalOperator): """ - Executes `dbt docs generate` command and upload to S3 storage. Returns the S3 path to the generated documentation. + Executes `dbt docs generate` command and upload to S3 storage. :param connection_id: S3's Airflow connection ID :param bucket_name: S3's bucket name diff --git a/cosmos/profiles/athena/access_key.py b/cosmos/profiles/athena/access_key.py index a8f71c2b7..02de2be24 100644 --- a/cosmos/profiles/athena/access_key.py +++ b/cosmos/profiles/athena/access_key.py @@ -3,20 +3,33 @@ from typing import Any +from cosmos.exceptions import CosmosValueError + from ..base import BaseProfileMapping class AthenaAccessKeyProfileMapping(BaseProfileMapping): """ - Maps Airflow AWS connections to a dbt Athena profile using an access key id and secret access key. + Uses the Airflow AWS Connection provided to get_credentials() to generate the profile for dbt. 
- https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/aws.html + + + This behaves similarly to other provider operators such as the AWS Athena Operator. + Where you pass the aws_conn_id and the operator will generate the credentials for you. + + https://registry.astronomer.io/providers/amazon/versions/latest/modules/athenaoperator + + Information about the dbt Athena profile that is generated can be found here: + + https://github.com/dbt-athena/dbt-athena?tab=readme-ov-file#configuring-your-profile + https://docs.getdbt.com/docs/core/connect-data-platform/athena-setup """ airflow_connection_type: str = "aws" dbt_profile_type: str = "athena" is_community: bool = True + temporary_credentials = None required_fields = [ "aws_access_key_id", @@ -26,11 +39,7 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): "s3_staging_dir", "schema", ] - secret_fields = ["aws_secret_access_key", "aws_session_token"] airflow_param_mapping = { - "aws_access_key_id": "login", - "aws_secret_access_key": "password", - "aws_session_token": "extra.aws_session_token", "aws_profile_name": "extra.aws_profile_name", "database": "extra.database", "debug_query_state": "extra.debug_query_state", @@ -49,11 +58,43 @@ class AthenaAccessKeyProfileMapping(BaseProfileMapping): @property def profile(self) -> dict[str, Any | None]: "Gets profile. The password is stored in an environment variable." + + self.temporary_credentials = self._get_temporary_credentials() # type: ignore + profile = { **self.mapped_params, **self.profile_args, - # aws_secret_access_key and aws_session_token should always get set as env var + "aws_access_key_id": self.temporary_credentials.access_key, "aws_secret_access_key": self.get_env_var_format("aws_secret_access_key"), "aws_session_token": self.get_env_var_format("aws_session_token"), } + return self.filter_null(profile) + + @property + def env_vars(self) -> dict[str, str]: + "Overwrites the env_vars for athena, Returns a dictionary of environment variables that should be set based on the self.temporary_credentials." + + if self.temporary_credentials is None: + raise CosmosValueError(f"Could not find the athena credentials.") + + env_vars = {} + + env_secret_key_name = self.get_env_var_name("aws_secret_access_key") + env_session_token_name = self.get_env_var_name("aws_session_token") + + env_vars[env_secret_key_name] = str(self.temporary_credentials.secret_key) + env_vars[env_session_token_name] = str(self.temporary_credentials.token) + + return env_vars + + def _get_temporary_credentials(self): # type: ignore + """ + Helper function to retrieve temporary short lived credentials + Returns an object including access_key, secret_key and token + """ + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + hook = AwsGenericHook(self.conn_id) # type: ignore + credentials = hook.get_credentials() + return credentials diff --git a/cosmos/profiles/vertica/user_pass.py b/cosmos/profiles/vertica/user_pass.py index ccaaf301d..e016b612c 100644 --- a/cosmos/profiles/vertica/user_pass.py +++ b/cosmos/profiles/vertica/user_pass.py @@ -9,8 +9,14 @@ class VerticaUserPasswordProfileMapping(BaseProfileMapping): """ Maps Airflow Vertica connections using username + password authentication to dbt profiles. - https://docs.getdbt.com/reference/warehouse-setups/vertica-setup - https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html + .. 
note:: + Use Airflow connection ``schema`` for vertica ``database`` to keep it consistent with other connection types and profiles. \ + The Vertica Airflow provider hook `assumes this `_. + This seems to be a common approach also for `Postgres `_, \ + Redshift and Exasol since there is no ``database`` field in Airflow connection and ``schema`` is not required for the database connection. + .. seealso:: + https://docs.getdbt.com/reference/warehouse-setups/vertica-setup + https://airflow.apache.org/docs/apache-airflow-providers-vertica/stable/connections/vertica.html """ airflow_connection_type: str = "vertica" @@ -31,8 +37,7 @@ class VerticaUserPasswordProfileMapping(BaseProfileMapping): "username": "login", "password": "password", "port": "port", - "schema": "schema", - "database": "extra.database", + "database": "schema", "autocommit": "extra.autocommit", "backup_server_node": "extra.backup_server_node", "binary_transfer": "extra.binary_transfer", diff --git a/dev/dags/dbt/simple/models/top_animations.sql b/dev/dags/dbt/simple/models/top_animations.sql index 2b365b09c..cfae1c595 100644 --- a/dev/dags/dbt/simple/models/top_animations.sql +++ b/dev/dags/dbt/simple/models/top_animations.sql @@ -1,4 +1,8 @@ -{{ config(materialized='table') }} +{{ config( + materialized='table', + alias=var('animation_alias', 'top_animations') + ) +}} SELECT Title, Rating FROM {{ ref('movies_ratings_simplified') }} diff --git a/dev/dags/example_cosmos_sources.py b/dev/dags/example_cosmos_sources.py index 157b3adb3..0553b2f10 100644 --- a/dev/dags/example_cosmos_sources.py +++ b/dev/dags/example_cosmos_sources.py @@ -62,19 +62,24 @@ def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): node_converters={ DbtResourceType("source"): convert_source, # known dbt node type to Cosmos (part of DbtResourceType) DbtResourceType("exposure"): convert_exposure, # dbt node type new to Cosmos (will be added to DbtResourceType) - }, + } +) + +# `ProjectConfig` can pass dbt variables and environment variables to dbt commands. Below is an example of +# passing a required env var for the profiles.yml file and a dbt variable that is used for rendering and +# executing dbt models. +project_config = ProjectConfig( + DBT_ROOT_PATH / "simple", env_vars={"DBT_SQLITE_PATH": DBT_SQLITE_PATH}, + dbt_vars={"animation_alias": "top_5_animated_movies"}, ) example_cosmos_sources = DbtDag( # dbt/cosmos-specific parameters - project_config=ProjectConfig( - DBT_ROOT_PATH / "simple", - ), + project_config=project_config, profile_config=profile_config, render_config=render_config, - operator_args={"env": {"DBT_SQLITE_PATH": DBT_SQLITE_PATH}}, # normal dag parameters schedule_interval="@daily", start_date=datetime(2023, 1, 1), diff --git a/docs/configuration/generating-docs.rst b/docs/configuration/generating-docs.rst index 88459fd14..6112ebcee 100644 --- a/docs/configuration/generating-docs.rst +++ b/docs/configuration/generating-docs.rst @@ -83,6 +83,34 @@ You can use the :class:`~cosmos.operators.DbtDocsGCSOperator` to generate and up bucket_name="test_bucket", ) +Static Flag +~~~~~~~~~~~~~~~~~~~~~~~ + +All of the DbtDocs operators accept the ``--static`` flag. To learn more about the static flag, check out the `original PR on dbt-core `_. +The static flag is used to generate a single doc file that can be hosted directly from cloud storage. +By having a single documentation file, access control can be configured through Identity-Aware Proxy (IAP), making the documentation easy to host. + +.. 
note:: + The static flag is only available from dbt-core >=1.7 + +The following code snippet shows how to provide this flag with the default jaffle_shop project: + + +.. code-block:: python + + from cosmos.operators import DbtDocsGCSOperator + + # then, in your DAG code: + generate_dbt_docs_aws = DbtDocsGCSOperator( + task_id="generate_dbt_docs_gcs", + project_dir="path/to/jaffle_shop", + profile_config=profile_config, + # docs-specific arguments + connection_id="test_gcs", + bucket_name="test_bucket", + dbt_cmd_flags=["--static"], + ) + Custom Callback ~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/configuration/operator-args.rst b/docs/configuration/operator-args.rst index 9d533bf13..5ddbe6565 100644 --- a/docs/configuration/operator-args.rst +++ b/docs/configuration/operator-args.rst @@ -47,12 +47,12 @@ dbt-related - ``dbt_cmd_flags``: List of command flags to pass to ``dbt`` command, added after dbt subcommand - ``dbt_cmd_global_flags``: List of ``dbt`` `global flags `_ to be passed to the ``dbt`` command, before the subcommand - ``dbt_executable_path``: Path to dbt executable. -- ``env``: Declare, using a Python dictionary, values to be set as environment variables when running ``dbt`` commands. +- ``env``: (Deprecated since Cosmos 1.3 use ``ProjectConfig.env_vars`` instead) Declare, using a Python dictionary, values to be set as environment variables when running ``dbt`` commands. - ``fail_fast``: ``dbt`` exits immediately if ``dbt`` fails to process a resource. - ``models``: Specifies which nodes to include. - ``no_version_check``: If set, skip ensuring ``dbt``'s version matches the one specified in the ``dbt_project.yml``. - ``quiet``: run ``dbt`` in silent mode, only displaying its error logs. -- ``vars``: Supply variables to the project. This argument overrides variables defined in the ``dbt_project.yml``. +- ``vars``: (Deprecated since Cosmos 1.3 use ``ProjectConfig.dbt_vars`` instead) Supply variables to the project. This argument overrides variables defined in the ``dbt_project.yml``. - ``warn_error``: convert ``dbt`` warnings into errors. Airflow-related diff --git a/docs/configuration/project-config.rst b/docs/configuration/project-config.rst index c1d952f6e..c062a1de5 100644 --- a/docs/configuration/project-config.rst +++ b/docs/configuration/project-config.rst @@ -1,8 +1,8 @@ Project Config ================ -The ``cosmos.config.ProjectConfig`` allows you to specify information about where your dbt project is located. It -takes the following arguments: +The ``cosmos.config.ProjectConfig`` allows you to specify information about where your dbt project is located and project +variables that should be used for rendering and execution. It takes the following arguments: - ``dbt_project_path``: The full path to your dbt project. This directory should have a ``dbt_project.yml`` file - ``models_relative_path``: The path to your models directory, relative to the ``dbt_project_path``. This defaults to @@ -16,7 +16,13 @@ takes the following arguments: - ``project_name`` : The name of the project. If ``dbt_project_path`` is provided, the ``project_name`` defaults to the folder name containing ``dbt_project.yml``. If ``dbt_project_path`` is not provided, and ``manifest_path`` is provided, ``project_name`` is required as the name can not be inferred from ``dbt_project_path`` - +- ``dbt_vars``: (new in v1.3) A dictionary of dbt variables for the project rendering and execution. This argument overrides variables + defined in the dbt_project.yml file. 
The dictionary of variables is dumped to a yaml string and passed to dbt commands + as the --vars argument. Variables are only supported for rendering when using ``RenderConfig.LoadMode.DBT_LS`` and + ``RenderConfig.LoadMode.CUSTOM`` load mode. Variables using `Airflow templating `_ + will only be rendered at execution time, not at render time. +- ``env_vars``: (new in v1.3) A dictionary of environment variables used for rendering and execution. Rendering with + env vars is only supported when using ``RenderConfig.LoadMode.DBT_LS`` load mode. Project Config Example ---------------------- @@ -31,4 +37,10 @@ Project Config Example seeds_relative_path="data", snapshots_relative_path="snapshots", manifest_path="/path/to/manifests", + env_vars={"MY_ENV_VAR": "my_env_value"}, + dbt_vars={ + "my_dbt_var": "my_value", + "start_time": "{{ data_interval_start.strftime('%Y%m%d%H%M%S') }}", + "end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}", + }, ) diff --git a/docs/configuration/render-config.rst b/docs/configuration/render-config.rst index 5e1c23824..6d669d0a5 100644 --- a/docs/configuration/render-config.rst +++ b/docs/configuration/render-config.rst @@ -11,10 +11,11 @@ The ``RenderConfig`` class takes the following arguments: - ``test_behavior``: how to run tests. Defaults to running a model's tests immediately after the model is run. For more information, see the `Testing Behavior `_ section. - ``load_method``: how to load your dbt project. See `Parsing Methods `_ for more information. - ``select`` and ``exclude``: which models to include or exclude from your DAGs. See `Selecting & Excluding `_ for more information. +- ``selector``: (new in v1.3) name of a dbt YAML selector to use for DAG parsing. Only supported when using ``load_method=LoadMode.DBT_LS``. See `Selecting & Excluding `_ for more information. - ``dbt_deps``: A Boolean to run dbt deps when using dbt ls for dag parsing. Default True - ``node_converters``: a dictionary mapping a ``DbtResourceType`` into a callable. Users can control how to render dbt nodes in Airflow. Only supported when using ``load_method=LoadMode.DBT_MANIFEST`` or ``LoadMode.DBT_LS``. Find more information below. - ``dbt_executable_path``: The path to the dbt executable for dag generation. Defaults to dbt if available on the path. -- ``env_vars``: A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``. +- ``env_vars``: (available in v1.2.5, use``ProjectConfig.env_vars`` for v1.3.0 onwards) A dictionary of environment variables for rendering. Only supported when using ``load_method=LoadMode.DBT_LS``. - ``dbt_project_path``: Configures the DBT project location accessible on their airflow controller for DAG rendering - Required when using ``load_method=LoadMode.DBT_LS`` or ``load_method=LoadMode.CUSTOM`` Customizing how nodes are rendered (experimental) diff --git a/docs/configuration/selecting-excluding.rst b/docs/configuration/selecting-excluding.rst index fadea1485..01ee536b0 100644 --- a/docs/configuration/selecting-excluding.rst +++ b/docs/configuration/selecting-excluding.rst @@ -3,14 +3,22 @@ Selecting & Excluding ======================= -Cosmos allows you to filter to a subset of your dbt project in each ``DbtDag`` / ``DbtTaskGroup`` using the ``select`` and ``exclude`` parameters in the ``RenderConfig`` class. +Cosmos allows you to filter to a subset of your dbt project in each ``DbtDag`` / ``DbtTaskGroup`` using the ``select `` and ``exclude`` parameters in the ``RenderConfig`` class. 
+ + Since Cosmos 1.3, the ``selector`` parameter is also available in ``RenderConfig`` when using the ``LoadMode.DBT_LS`` to parse the dbt project into Airflow. + + +Using ``select`` and ``exclude`` +-------------------------------- The ``select`` and ``exclude`` parameters are lists, with values like the following: - ``tag:my_tag``: include/exclude models with the tag ``my_tag`` - ``config.materialized:table``: include/exclude models with the config ``materialized: table`` - ``path:analytics/tables``: include/exclude models in the ``analytics/tables`` directory - +- ``+node_name+1`` (graph operators): include/exclude the node with name ``node_name``, all its parents, and its first generation of children (`dbt graph selector docs `_) +- ``tag:my_tag,+node_name`` (intersection): include/exclude ``node_name`` and its parents if they have the tag ``my_tag`` (`dbt set operator docs `_) +- ``['tag:first_tag', 'tag:second_tag']`` (union): include/exclude nodes that have either ``tag:first_tag`` or ``tag:second_tag`` .. note:: @@ -51,3 +59,55 @@ Examples: select=["path:analytics/tables"], ) ) + + +.. code-block:: python + + from cosmos import DbtDag, RenderConfig + + jaffle_shop = DbtDag( + render_config=RenderConfig( + select=["tag:include_tag1", "tag:include_tag2"], # union + ) + ) + +.. code-block:: python + + from cosmos import DbtDag, RenderConfig + + jaffle_shop = DbtDag( + render_config=RenderConfig( + select=["tag:include_tag1,tag:include_tag2"], # intersection + ) + ) + +.. code-block:: python + + from cosmos import DbtDag, RenderConfig + + jaffle_shop = DbtDag( + render_config=RenderConfig( + exclude=["node_name+"], # node_name and its children + ) + ) + +Using ``selector`` +-------------------------------- +.. note:: + Only currently supported using the ``dbt_ls`` parsing method since Cosmos 1.3 where the selector is passed directly to the dbt CLI command. \ + If ``select`` and/or ``exclude`` are used with ``selector``, dbt will ignore the ``select`` and ``exclude`` parameters. + +The ``selector`` parameter is a string that references a `dbt YAML selector `_ already defined in a dbt project. + +Examples: + +.. 
code-block:: python + + from cosmos import DbtDag, RenderConfig, LoadMode + + jaffle_shop = DbtDag( + render_config=RenderConfig( + selector="my_selector", # this selector must be defined in your dbt project + load_method=LoadMode.DBT_LS, + ) + ) diff --git a/docs/contributors.rst b/docs/contributors.rst index 273358d3c..16f7dba17 100644 --- a/docs/contributors.rst +++ b/docs/contributors.rst @@ -11,6 +11,7 @@ Committers * Chris Hronek (`@chrishronek `_) * Harel Shein (`@harels `_) * Julian LaNeve (`@jlaneve `_) +* Justin Bandoro (`@jbandoro `_) * Tatiana Al-Chueyr (`@tatiana `_) diff --git a/pyproject.toml b/pyproject.toml index fe399daee..9d367c075 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ dbt-all = [ ] dbt-athena = [ "dbt-athena-community", + "apache-airflow-providers-amazon>=8.0.0", ] dbt-bigquery = [ "dbt-bigquery", @@ -110,7 +111,6 @@ tests = [ "mypy", "sqlalchemy-stubs", # Change when sqlalchemy is upgraded https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html ] - docker = [ "apache-airflow-providers-docker>=3.5.0", ] @@ -121,7 +121,6 @@ pydantic = [ "pydantic>=1.10.0,<2.0.0", ] - [project.entry-points.cosmos] provider_info = "cosmos:get_provider_info" @@ -138,6 +137,9 @@ include = [ "/cosmos", ] +[tool.hatch.build.targets.wheel] +packages = ["cosmos"] + ###################################### # TESTING ###################################### @@ -255,10 +257,10 @@ no_warn_unused_ignores = true [tool.ruff] line-length = 120 +[tool.ruff.lint] +select = ["C901"] +[tool.ruff.lint.mccabe] +max-complexity = 8 [tool.distutils.bdist_wheel] universal = true - -[tool.flake8] -max-complexity = 10 -select = "C" diff --git a/tests/dbt/parser/test_project.py b/tests/dbt/parser/test_project.py index 4f13a3eb3..31fe7e18d 100644 --- a/tests/dbt/parser/test_project.py +++ b/tests/dbt/parser/test_project.py @@ -219,6 +219,6 @@ def test_dbtmodelconfig_with_vars(tmp_path): name="some_name", type=DbtModelType.DBT_MODEL, path=path_with_sources, - operator_args={"vars": {"country_code": "us"}}, + dbt_vars={"country_code": "us"}, ) assert "stg_customers_us" in dbt_model.config.upstream_models diff --git a/tests/dbt/test_graph.py b/tests/dbt/test_graph.py index a424976a1..2816fd07a 100644 --- a/tests/dbt/test_graph.py +++ b/tests/dbt/test_graph.py @@ -1,7 +1,8 @@ import shutil import tempfile from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock +import yaml import pytest @@ -16,6 +17,7 @@ run_command, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from subprocess import Popen, PIPE DBT_PROJECTS_ROOT_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt" DBT_PROJECT_NAME = "jaffle_shop" @@ -42,6 +44,18 @@ def tmp_dbt_project_dir(): shutil.rmtree(tmp_dir, ignore_errors=True) # delete directory +@pytest.fixture +def postgres_profile_config() -> ProfileConfig: + return ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="airflow_db", + profile_args={"schema": "public"}, + ), + ) + + @pytest.mark.parametrize( "unique_id,expected_name, expected_select", [ @@ -218,7 +232,9 @@ def test_load( @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_popen, tmp_dbt_project_dir): +def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder( + mock_popen, tmp_dbt_project_dir, postgres_profile_config +): mock_popen().communicate.return_value = ("", "") 
mock_popen().returncode = 0 assert not (tmp_dbt_project_dir / "target").exists() @@ -231,14 +247,7 @@ def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_pop project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() assert not (tmp_dbt_project_dir / "target").exists() @@ -250,7 +259,7 @@ def test_load_via_dbt_ls_does_not_create_target_logs_in_original_folder(mock_pop @pytest.mark.integration -def test_load_via_dbt_ls_with_exclude(): +def test_load_via_dbt_ls_with_exclude(postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, select=["*customers*"], exclude=["*orders*"] @@ -260,14 +269,7 @@ def test_load_via_dbt_ls_with_exclude(): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() @@ -299,7 +301,7 @@ def test_load_via_dbt_ls_with_exclude(): @pytest.mark.integration @pytest.mark.parametrize("project_name", ("jaffle_shop", "jaffle_shop_python")) -def test_load_via_dbt_ls_without_exclude(project_name): +def test_load_via_dbt_ls_without_exclude(project_name, postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -307,14 +309,7 @@ def test_load_via_dbt_ls_without_exclude(project_name): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() @@ -411,7 +406,7 @@ def test_load_via_dbt_ls_with_sources(load_method): @pytest.mark.integration -def test_load_via_dbt_ls_without_dbt_deps(): +def test_load_via_dbt_ls_without_dbt_deps(postgres_profile_config): project_config = ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME, dbt_deps=False) execution_config = ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -419,14 +414,7 @@ def test_load_via_dbt_ls_without_dbt_deps(): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) with pytest.raises(CosmosLoadDbtException) as err_info: @@ -436,9 +424,48 @@ def test_load_via_dbt_ls_without_dbt_deps(): assert 
err_info.value.args[0] == expected +@pytest.mark.integration +def test_load_via_dbt_ls_without_dbt_deps_and_preinstalled_dbt_packages(tmp_dbt_project_dir, postgres_profile_config): + local_flags = [ + "--project-dir", + tmp_dbt_project_dir / DBT_PROJECT_NAME, + "--profiles-dir", + tmp_dbt_project_dir / DBT_PROJECT_NAME, + "--profile", + "default", + "--target", + "dev", + ] + + deps_command = ["dbt", "deps"] + deps_command.extend(local_flags) + process = Popen( + deps_command, + stdout=PIPE, + stderr=PIPE, + cwd=tmp_dbt_project_dir / DBT_PROJECT_NAME, + universal_newlines=True, + ) + stdout, stderr = process.communicate() + + project_config = ProjectConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + render_config = RenderConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, dbt_deps=False) + execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=postgres_profile_config, + ) + + dbt_graph.load_via_dbt_ls() # does not raise exception + + @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, tmp_dbt_project_dir): +def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr( + mock_popen, tmp_dbt_project_dir, postgres_profile_config +): mock_popen().communicate.return_value = ("", "Some stderr warnings") mock_popen().returncode = 0 @@ -449,14 +476,7 @@ def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, t project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) dbt_graph.load_via_dbt_ls() # does not raise exception @@ -464,7 +484,7 @@ def test_load_via_dbt_ls_with_zero_returncode_and_non_empty_stderr(mock_popen, t @pytest.mark.integration @patch("cosmos.dbt.graph.Popen") -def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): +def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen, postgres_profile_config): mock_popen().communicate.return_value = ("", "Some stderr message") mock_popen().returncode = 1 @@ -475,14 +495,7 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome stderr message" with pytest.raises(CosmosLoadDbtException, match=expected): @@ -491,7 +504,7 @@ def test_load_via_dbt_ls_with_non_zero_returncode(mock_popen): @pytest.mark.integration @patch("cosmos.dbt.graph.Popen.communicate", return_value=("Some Runtime Error", "")) -def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate): +def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate, postgres_profile_config): # It may seem strange, but at least until dbt 1.6.0, there are circumstances when it outputs errors to stdout project_config = 
ProjectConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) render_config = RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME) @@ -500,14 +513,7 @@ def test_load_via_dbt_ls_with_runtime_error_in_stdout(mock_popen_communicate): project=project_config, render_config=render_config, execution_config=execution_config, - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) expected = r"Unable to run \['.+dbt', 'deps', .*\] due to the error:\nSome Runtime Error" with pytest.raises(CosmosLoadDbtException, match=expected): @@ -627,7 +633,7 @@ def test_tag_selected_node_test_exist(): @pytest.mark.integration @pytest.mark.parametrize("load_method", ["load_via_dbt_ls", "load_from_dbt_manifest"]) -def test_load_dbt_ls_and_manifest_with_model_version(load_method): +def test_load_dbt_ls_and_manifest_with_model_version(load_method, postgres_profile_config): dbt_graph = DbtGraph( project=ProjectConfig( dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version", @@ -635,14 +641,7 @@ def test_load_dbt_ls_and_manifest_with_model_version(load_method): ), render_config=RenderConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version"), execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / "model_version"), - profile_config=ProfileConfig( - profile_name="default", - target_name="default", - profile_mapping=PostgresUserPasswordProfileMapping( - conn_id="airflow_db", - profile_args={"schema": "public"}, - ), - ), + profile_config=postgres_profile_config, ) getattr(dbt_graph, load_method)() expected_dbt_nodes = { @@ -722,3 +721,179 @@ def test_parse_dbt_ls_output(): nodes = parse_dbt_ls_output(Path("fake-project"), fake_ls_stdout) assert expected_nodes == nodes + + +@patch("cosmos.dbt.graph.Popen") +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency") +@patch("cosmos.config.RenderConfig.validate_dbt_command") +def test_load_via_dbt_ls_project_config_env_vars(mock_validate, mock_update_nodes, mock_popen, tmp_dbt_project_dir): + """Tests that the dbt ls command in the subprocess has the project config env vars set.""" + mock_popen().communicate.return_value = ("", "") + mock_popen().returncode = 0 + env_vars = {"MY_ENV_VAR": "my_value"} + project_config = ProjectConfig(env_vars=env_vars) + render_config = RenderConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=profile_config, + ) + dbt_graph.load_via_dbt_ls() + + assert "MY_ENV_VAR" in mock_popen.call_args.kwargs["env"] + assert mock_popen.call_args.kwargs["env"]["MY_ENV_VAR"] == "my_value" + + +@patch("cosmos.dbt.graph.Popen") +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency") +@patch("cosmos.config.RenderConfig.validate_dbt_command") +def test_load_via_dbt_ls_project_config_dbt_vars(mock_validate, mock_update_nodes, mock_popen, tmp_dbt_project_dir): + """Tests that the dbt ls command in the subprocess has "--vars" with the project config dbt_vars.""" + mock_popen().communicate.return_value = ("", "") 
+ mock_popen().returncode = 0 + dbt_vars = {"my_var1": "my_value1", "my_var2": "my_value2"} + project_config = ProjectConfig(dbt_vars=dbt_vars) + render_config = RenderConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=profile_config, + ) + dbt_graph.load_via_dbt_ls() + ls_command = mock_popen.call_args.args[0] + assert "--vars" in ls_command + assert ls_command[ls_command.index("--vars") + 1] == yaml.dump(dbt_vars) + + +@patch("cosmos.dbt.graph.Popen") +@patch("cosmos.dbt.graph.DbtGraph.update_node_dependency") +@patch("cosmos.config.RenderConfig.validate_dbt_command") +def test_load_via_dbt_ls_render_config_selector_arg_is_used( + mock_validate, mock_update_nodes, mock_popen, tmp_dbt_project_dir +): + """Tests that the dbt ls command in the subprocess has "--selector" with the RenderConfig.selector.""" + mock_popen().communicate.return_value = ("", "") + mock_popen().returncode = 0 + selector = "my_selector" + project_config = ProjectConfig() + render_config = RenderConfig( + dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, + load_method=LoadMode.DBT_LS, + selector=selector, + ) + profile_config = ProfileConfig( + profile_name="test", + target_name="test", + profiles_yml_filepath=DBT_PROJECTS_ROOT_DIR / DBT_PROJECT_NAME / "profiles.yml", + ) + execution_config = MagicMock() + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=profile_config, + ) + dbt_graph.load_via_dbt_ls() + ls_command = mock_popen.call_args.args[0] + assert "--selector" in ls_command + assert ls_command[ls_command.index("--selector") + 1] == selector + + +@pytest.mark.parametrize("load_method", [LoadMode.DBT_MANIFEST, LoadMode.CUSTOM]) +def test_load_method_with_unsupported_render_config_selector_arg(load_method): + """Tests that error is raised when RenderConfig.selector is used with LoadMode.DBT_MANIFEST or LoadMode.CUSTOM.""" + + expected_error_msg = ( + f"RenderConfig.selector is not yet supported when loading dbt projects using the {load_method} parser." + ) + dbt_graph = DbtGraph( + render_config=RenderConfig(load_method=load_method, selector="my_selector"), + project=MagicMock(), + ) + with pytest.raises(CosmosLoadDbtException, match=expected_error_msg): + dbt_graph.load(method=load_method) + + +@pytest.mark.sqlite +@pytest.mark.integration +def test_load_via_dbt_ls_with_project_config_vars(): + """ + Integration that tests that the dbt ls command is successful and that the node affected by the dbt_vars is + rendered correctly. 
+ """ + project_name = "simple" + dbt_graph = DbtGraph( + project=ProjectConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, + env_vars={"DBT_SQLITE_PATH": str(DBT_PROJECTS_ROOT_DIR / "data")}, + dbt_vars={"animation_alias": "top_5_animated_movies"}, + ), + render_config=RenderConfig( + dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name, + dbt_deps=False, + ), + execution_config=ExecutionConfig(dbt_project_path=DBT_PROJECTS_ROOT_DIR / project_name), + profile_config=ProfileConfig( + profile_name="simple", + target_name="dev", + profiles_yml_filepath=(DBT_PROJECTS_ROOT_DIR / project_name / "profiles.yml"), + ), + ) + dbt_graph.load_via_dbt_ls() + assert dbt_graph.nodes["model.simple.top_animations"].config["alias"] == "top_5_animated_movies" + + +@pytest.mark.integration +def test_load_via_dbt_ls_with_selector_arg(tmp_dbt_project_dir, postgres_profile_config): + """ + Tests that the dbt ls load method is successful if a selector arg is used with RenderConfig + and that the filtered nodes are expected. + """ + # Add a selectors yaml file to the project that will select the stg_customers model and all + # parents (raw_customers) + selectors_yaml = """ + selectors: + - name: stage_customers + definition: + method: fqn + value: stg_customers + parents: true + """ + with open(tmp_dbt_project_dir / DBT_PROJECT_NAME / "selectors.yml", "w") as f: + f.write(selectors_yaml) + + project_config = ProjectConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + execution_config = ExecutionConfig(dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME) + render_config = RenderConfig( + dbt_project_path=tmp_dbt_project_dir / DBT_PROJECT_NAME, + selector="stage_customers", + ) + + dbt_graph = DbtGraph( + project=project_config, + render_config=render_config, + execution_config=execution_config, + profile_config=postgres_profile_config, + ) + dbt_graph.load_via_dbt_ls() + + filtered_nodes = dbt_graph.filtered_nodes.keys() + assert len(filtered_nodes) == 4 + assert "model.jaffle_shop.stg_customers" in filtered_nodes + assert "seed.jaffle_shop.raw_customers" in filtered_nodes + # Two tests should be filtered + assert sum(node.startswith("test.jaffle_shop") for node in filtered_nodes) == 2 diff --git a/tests/dbt/test_project.py b/tests/dbt/test_project.py index ec5612904..000ad06bd 100644 --- a/tests/dbt/test_project.py +++ b/tests/dbt/test_project.py @@ -11,7 +11,7 @@ def test_create_symlinks(tmp_path): tmp_dir = tmp_path / "dbt-project" tmp_dir.mkdir() - create_symlinks(DBT_PROJECTS_ROOT_DIR / "jaffle_shop", tmp_dir) + create_symlinks(DBT_PROJECTS_ROOT_DIR / "jaffle_shop", tmp_dir, False) for child in tmp_dir.iterdir(): assert child.is_symlink() assert child.name not in ("logs", "target", "profiles.yml", "dbt_packages") diff --git a/tests/dbt/test_selector.py b/tests/dbt/test_selector.py index f7ece6391..1cf987124 100644 --- a/tests/dbt/test_selector.py +++ b/tests/dbt/test_selector.py @@ -46,47 +46,69 @@ def test_is_empty_config(selector_config, paths, tags, config, other, expected): tags=["has_child"], config={"materialized": "view", "tags": ["has_child"]}, ) + +another_grandparent_node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.another_grandparent_node", + resource_type=DbtResourceType.MODEL, + depends_on=[], + file_path=SAMPLE_PROJ_PATH / "gen1/models/another_grandparent_node.sql", + tags=[], + config={}, +) + parent_node = DbtNode( unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.parent", resource_type=DbtResourceType.MODEL, - 
depends_on=["grandparent"], + depends_on=[grandparent_node.unique_id, another_grandparent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen2/models/parent.sql", tags=["has_child", "is_child"], config={"materialized": "view", "tags": ["has_child", "is_child"]}, ) + child_node = DbtNode( unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.child", resource_type=DbtResourceType.MODEL, - depends_on=["parent"], + depends_on=[parent_node.unique_id], file_path=SAMPLE_PROJ_PATH / "gen3/models/child.sql", tags=["nightly", "is_child"], config={"materialized": "table", "tags": ["nightly", "is_child"]}, ) -grandchild_1_test_node = DbtNode( - unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.grandchild_1", +sibling1_node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.sibling1", resource_type=DbtResourceType.MODEL, - depends_on=["parent"], - file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_1.sql", + depends_on=[parent_node.unique_id], + file_path=SAMPLE_PROJ_PATH / "gen3/models/sibling1.sql", tags=["nightly", "deprecated", "test"], config={"materialized": "table", "tags": ["nightly", "deprecated", "test"]}, ) -grandchild_2_test_node = DbtNode( - unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.grandchild_2", +sibling2_node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.sibling2", resource_type=DbtResourceType.MODEL, - depends_on=["parent"], - file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_2.sql", + depends_on=[parent_node.unique_id], + file_path=SAMPLE_PROJ_PATH / "gen3/models/sibling2.sql", tags=["nightly", "deprecated", "test2"], config={"materialized": "table", "tags": ["nightly", "deprecated", "test2"]}, ) +orphaned_node = DbtNode( + unique_id=f"{DbtResourceType.MODEL.value}.{SAMPLE_PROJ_PATH.stem}.orphaned", + resource_type=DbtResourceType.MODEL, + depends_on=[], + file_path=SAMPLE_PROJ_PATH / "gen3/models/orphaned.sql", + tags=[], + config={}, +) + sample_nodes = { grandparent_node.unique_id: grandparent_node, + another_grandparent_node.unique_id: another_grandparent_node, parent_node.unique_id: parent_node, child_node.unique_id: child_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, + orphaned_node.unique_id: orphaned_node, } @@ -100,8 +122,8 @@ def test_select_nodes_by_select_config(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.materialized:table"]) expected = { child_node.unique_id: child_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, } assert selected == expected @@ -136,8 +158,8 @@ def test_select_nodes_by_select_union_config_test_tags(): expected = { grandparent_node.unique_id: grandparent_node, parent_node.unique_id: parent_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, } assert selected == expected @@ -176,8 +198,8 @@ def test_select_nodes_by_select_union(): grandparent_node.unique_id: grandparent_node, parent_node.unique_id: parent_node, child_node.unique_id: child_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - 
grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, } assert selected == expected @@ -191,8 +213,10 @@ def test_select_nodes_by_exclude_tag(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["tag:has_child"]) expected = { child_node.unique_id: child_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, + another_grandparent_node.unique_id: another_grandparent_node, + orphaned_node.unique_id: orphaned_node, } assert selected == expected @@ -217,8 +241,10 @@ def test_select_nodes_by_exclude_union_config_test_tags(): ) expected = { grandparent_node.unique_id: grandparent_node, + another_grandparent_node.unique_id: another_grandparent_node, parent_node.unique_id: parent_node, child_node.unique_id: child_node, + orphaned_node.unique_id: orphaned_node, } assert selected == expected @@ -227,15 +253,156 @@ def test_select_nodes_by_path_dir(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["path:gen3/models"]) expected = { child_node.unique_id: child_node, - grandchild_1_test_node.unique_id: grandchild_1_test_node, - grandchild_2_test_node.unique_id: grandchild_2_test_node, + sibling1_node.unique_id: sibling1_node, + sibling2_node.unique_id: sibling2_node, + orphaned_node.unique_id: orphaned_node, } assert selected == expected def test_select_nodes_by_path_file(): selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["path:gen2/models/parent.sql"]) - expected = { - parent_node.unique_id: parent_node, - } - assert selected == expected + expected = [parent_node.unique_id] + assert list(selected.keys()) == expected + + +def test_select_nodes_by_child_and_precursors(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+child"]) + expected = [ + another_grandparent_node.unique_id, + child_node.unique_id, + grandparent_node.unique_id, + parent_node.unique_id, + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_child_and_precursors_exclude_tags(): + selected = select_nodes( + project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+child"], exclude=["tag:has_child"] + ) + expected = [another_grandparent_node.unique_id, child_node.unique_id] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_child_and_precursors_partial_tree(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+parent"]) + expected = [another_grandparent_node.unique_id, grandparent_node.unique_id, parent_node.unique_id] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_precursors_with_orphaned_node(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+orphaned"]) + expected = [orphaned_node.unique_id] + assert list(selected.keys()) == expected + + +def test_select_nodes_by_child_and_first_degree_precursors(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["1+child"]) + expected = [ + child_node.unique_id, + parent_node.unique_id, + ] + assert sorted(selected.keys()) == expected + + +def test_select_nodes_by_child_and_second_degree_precursors(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["2+child"]) + expected = [ + 
another_grandparent_node.unique_id, + child_node.unique_id, + grandparent_node.unique_id, + parent_node.unique_id, + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_exact_node_name(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["child"]) + expected = [child_node.unique_id] + assert list(selected.keys()) == expected + + +def test_select_node_by_child_and_precursors_no_node(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["+modelDoesntExist"]) + expected = [] + assert list(selected.keys()) == expected + + +def test_select_node_by_descendants(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["grandparent+"]) + expected = [ + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_descendants_depth_first_degree(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["grandparent+1"]) + expected = [ + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_descendants_union(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["grandparent+1", "parent+1"]) + expected = [ + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_descendants_intersection(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["grandparent+1,parent+1"]) + expected = [ + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_descendants_intersection_with_tag(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["parent+1,tag:has_child"]) + expected = [ + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_select_node_by_descendants_and_tag_union(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["child", "tag:has_child"]) + expected = [ + "model.dbt-proj.child", + "model.dbt-proj.grandparent", + "model.dbt-proj.parent", + ] + assert sorted(selected.keys()) == expected + + +def test_exclude_by_graph_selector(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["+parent"]) + expected = [ + "model.dbt-proj.child", + "model.dbt-proj.orphaned", + "model.dbt-proj.sibling1", + "model.dbt-proj.sibling2", + ] + assert sorted(selected.keys()) == expected + + +def test_exclude_by_union_graph_selector_and_tag(): + selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["+parent", "tag:deprecated"]) + expected = [ + "model.dbt-proj.child", + "model.dbt-proj.orphaned", + ] + assert sorted(selected.keys()) == expected diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index b0a36b335..dd7d34a6d 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -488,3 +488,14 @@ def test_operator_execute_deps_parameters( mock_ensure_profile.return_value.__enter__.return_value = (Path("/path/to/profile"), {"ENV_VAR": "value"}) task.execute(context={"task_instance": MagicMock()}) assert mock_build_and_run_cmd.call_args_list[0].kwargs["command"] == 
expected_call_kwargs + + +def test_dbt_docs_local_operator_with_static_flag(): + # Check when static flag is passed, the required files are correctly adjusted to a single file + operator = DbtDocsLocalOperator( + task_id="fake-task", + project_dir="fake-dir", + profile_config=profile_config, + dbt_cmd_flags=["--static"], + ) + assert operator.required_files == ["static_index.html"] diff --git a/tests/profiles/athena/test_athena_access_key.py b/tests/profiles/athena/test_athena_access_key.py index 22c8efa2c..c224a9d4b 100644 --- a/tests/profiles/athena/test_athena_access_key.py +++ b/tests/profiles/athena/test_athena_access_key.py @@ -1,20 +1,49 @@ "Tests for the Athena profile." import json -from unittest.mock import patch - +from collections import namedtuple +import sys +from unittest.mock import MagicMock, patch import pytest from airflow.models.connection import Connection from cosmos.profiles import get_automatic_profile_mapping from cosmos.profiles.athena.access_key import AthenaAccessKeyProfileMapping +Credentials = namedtuple("Credentials", ["access_key", "secret_key", "token"]) + +mock_assumed_credentials = Credentials( + secret_key="my_aws_assumed_secret_key", + access_key="my_aws_assumed_access_key", + token="my_aws_assumed_token", +) + +mock_missing_credentials = Credentials(access_key=None, secret_key=None, token=None) + + +@pytest.fixture(autouse=True) +def mock_aws_module(): + mock_aws_hook = MagicMock() + + class MockAwsGenericHook: + def __init__(self, conn_id: str) -> None: + pass + + def get_credentials(self) -> Credentials: + return mock_assumed_credentials + + mock_aws_hook.AwsGenericHook = MockAwsGenericHook + + with patch.dict(sys.modules, {"airflow.providers.amazon.aws.hooks.base_aws": mock_aws_hook}): + yield mock_aws_hook + @pytest.fixture() def mock_athena_conn(): # type: ignore """ Sets the connection as an environment variable. 
""" + conn = Connection( conn_id="my_athena_connection", conn_type="aws", @@ -24,7 +53,7 @@ def mock_athena_conn(): # type: ignore { "aws_session_token": "token123", "database": "my_database", - "region_name": "my_region", + "region_name": "us-east-1", "s3_staging_dir": "s3://my_bucket/dbt/", "schema": "my_schema", } @@ -48,6 +77,7 @@ def test_athena_connection_claiming() -> None: # - region_name # - s3_staging_dir # - schema + potential_values = { "conn_type": "aws", "login": "my_aws_access_key_id", @@ -55,7 +85,7 @@ def test_athena_connection_claiming() -> None: "extra": json.dumps( { "database": "my_database", - "region_name": "my_region", + "region_name": "us-east-1", "s3_staging_dir": "s3://my_bucket/dbt/", "schema": "my_schema", } @@ -68,12 +98,14 @@ def test_athena_connection_claiming() -> None: del values[key] conn = Connection(**values) # type: ignore - print("testing with", values) - - with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - # should raise an InvalidMappingException - profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) - assert not profile_mapping.can_claim_connection() + with patch( + "cosmos.profiles.athena.access_key.AthenaAccessKeyProfileMapping._get_temporary_credentials", + return_value=mock_missing_credentials, + ): + with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): + # should raise an InvalidMappingException + profile_mapping = AthenaAccessKeyProfileMapping(conn, {}) + assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore @@ -88,6 +120,7 @@ def test_athena_profile_mapping_selected( """ Tests that the correct profile mapping is selected for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) @@ -100,13 +133,14 @@ def test_athena_profile_args( """ Tests that the profile values get set correctly for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) assert profile_mapping.profile == { "type": "athena", - "aws_access_key_id": mock_athena_conn.login, + "aws_access_key_id": mock_assumed_credentials.access_key, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": mock_athena_conn.extra_dejson.get("database"), @@ -122,9 +156,14 @@ def test_athena_profile_args_overrides( """ Tests that you can override the profile values for Athena. """ + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, - profile_args={"schema": "my_custom_schema", "database": "my_custom_db", "aws_session_token": "override_token"}, + profile_args={ + "schema": "my_custom_schema", + "database": "my_custom_db", + "aws_session_token": "override_token", + }, ) assert profile_mapping.profile_args == { @@ -135,7 +174,7 @@ def test_athena_profile_args_overrides( assert profile_mapping.profile == { "type": "athena", - "aws_access_key_id": mock_athena_conn.login, + "aws_access_key_id": mock_assumed_credentials.access_key, "aws_secret_access_key": "{{ env_var('COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY') }}", "aws_session_token": "{{ env_var('COSMOS_CONN_AWS_AWS_SESSION_TOKEN') }}", "database": "my_custom_db", @@ -151,10 +190,12 @@ def test_athena_profile_env_vars( """ Tests that the environment variables get set correctly for Athena. 
""" + profile_mapping = get_automatic_profile_mapping( mock_athena_conn.conn_id, ) + assert profile_mapping.env_vars == { - "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_athena_conn.password, - "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_athena_conn.extra_dejson.get("aws_session_token"), + "COSMOS_CONN_AWS_AWS_SECRET_ACCESS_KEY": mock_assumed_credentials.secret_key, + "COSMOS_CONN_AWS_AWS_SESSION_TOKEN": mock_assumed_credentials.token, } diff --git a/tests/profiles/vertica/test_vertica_user_pass.py b/tests/profiles/vertica/test_vertica_user_pass.py index 19771c799..6459dea96 100644 --- a/tests/profiles/vertica/test_vertica_user_pass.py +++ b/tests/profiles/vertica/test_vertica_user_pass.py @@ -23,8 +23,7 @@ def mock_vertica_conn(): # type: ignore login="my_user", password="my_password", port=5433, - schema="my_schema", - extra='{"database": "my_database"}', + schema="my_database", ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): @@ -43,8 +42,7 @@ def mock_vertica_conn_custom_port(): # type: ignore login="my_user", password="my_password", port=7472, - schema="my_schema", - extra='{"database": "my_database"}', + schema="my_database", ) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): @@ -69,8 +67,7 @@ def test_connection_claiming() -> None: "host": "my_host", "login": "my_user", "password": "my_password", - "schema": "my_schema", - "extra": '{"database": "my_database"}', + "schema": "my_database", } # if we're missing any of the values, it shouldn't claim @@ -82,20 +79,20 @@ def test_connection_claiming() -> None: print("testing with", values) with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = VerticaUserPasswordProfileMapping(conn) + profile_mapping = VerticaUserPasswordProfileMapping(conn, {"schema": "my_schema"}) assert not profile_mapping.can_claim_connection() - # also test when there's no database + # also test when there's no schema conn = Connection(**potential_values) # type: ignore conn.extra = "" with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = VerticaUserPasswordProfileMapping(conn) + profile_mapping = VerticaUserPasswordProfileMapping(conn, {}) assert not profile_mapping.can_claim_connection() # if we have them all, it should claim conn = Connection(**potential_values) # type: ignore with patch("airflow.hooks.base.BaseHook.get_connection", return_value=conn): - profile_mapping = VerticaUserPasswordProfileMapping(conn) + profile_mapping = VerticaUserPasswordProfileMapping(conn, {"schema": "my_schema"}) assert profile_mapping.can_claim_connection() @@ -107,7 +104,7 @@ def test_profile_mapping_selected( """ profile_mapping = get_automatic_profile_mapping( mock_vertica_conn.conn_id, - {"schema": "my_schema"}, + {"schema": "my_database"}, ) assert isinstance(profile_mapping, VerticaUserPasswordProfileMapping) @@ -145,8 +142,8 @@ def test_profile_args( "username": mock_vertica_conn.login, "password": "{{ env_var('COSMOS_CONN_VERTICA_PASSWORD') }}", "port": mock_vertica_conn.port, + "database": mock_vertica_conn.schema, "schema": "my_schema", - "database": mock_vertica_conn.extra_dejson.get("database"), } diff --git a/tests/test_config.py b/tests/test_config.py index 578a68f76..734303a3e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -174,3 +174,9 @@ def test_render_config_uses_default_if_exists(mock_which): render_config = RenderConfig(dbt_executable_path="user-dbt") 
render_config.validate_dbt_command("fallback-dbt-path") assert render_config.dbt_executable_path == "user-dbt" + + +def test_render_config_env_vars_deprecated(): + """RenderConfig.env_vars is deprecated since Cosmos 1.3, should warn user.""" + with pytest.deprecated_call(): + RenderConfig(env_vars={"VAR": "value"}) diff --git a/tests/test_converter.py b/tests/test_converter.py index 3bb5af163..d84249aae 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -1,13 +1,13 @@ from datetime import datetime from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, MagicMock from cosmos.profiles.postgres import PostgresUserPasswordProfileMapping import pytest from airflow.models import DAG -from cosmos.converter import DbtToAirflowConverter, validate_arguments -from cosmos.constants import DbtResourceType, ExecutionMode +from cosmos.converter import DbtToAirflowConverter, validate_arguments, validate_initial_user_config +from cosmos.constants import DbtResourceType, ExecutionMode, LoadMode from cosmos.config import ProjectConfig, ProfileConfig, ExecutionConfig, RenderConfig, CosmosConfigException from cosmos.dbt.graph import DbtNode from cosmos.exceptions import CosmosValueError @@ -35,6 +35,79 @@ def test_validate_arguments_tags(argument_key): assert err.value.args[0] == expected +@pytest.mark.parametrize( + "execution_mode", + (ExecutionMode.LOCAL, ExecutionMode.VIRTUALENV), +) +def test_validate_initial_user_config_no_profile(execution_mode): + execution_config = ExecutionConfig(execution_mode=execution_mode) + profile_config = None + project_config = ProjectConfig() + with pytest.raises(CosmosValueError) as err_info: + validate_initial_user_config(execution_config, profile_config, project_config, None, {}) + err_msg = f"The profile_config is mandatory when using {execution_mode}" + assert err_info.value.args[0] == err_msg + + +@pytest.mark.parametrize( + "execution_mode", + (ExecutionMode.DOCKER, ExecutionMode.KUBERNETES), +) +def test_validate_initial_user_config_expects_profile(execution_mode): + execution_config = ExecutionConfig(execution_mode=execution_mode) + profile_config = None + project_config = ProjectConfig() + assert validate_initial_user_config(execution_config, profile_config, project_config, None, {}) is None + + +@pytest.mark.parametrize("operator_args", [{"env": {"key": "value"}}, {"vars": {"key": "value"}}]) +def test_validate_user_config_operator_args_deprecated(operator_args): + """Deprecating warnings should be raised when using operator_args with "vars" or "env".""" + project_config = ProjectConfig() + execution_config = ExecutionConfig() + render_config = RenderConfig() + profile_config = MagicMock() + + with pytest.deprecated_call(): + validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args) + + +@pytest.mark.parametrize("project_config_arg, operator_arg", [("dbt_vars", "vars"), ("env_vars", "env")]) +def test_validate_user_config_fails_project_config_and_operator_args_overlap(project_config_arg, operator_arg): + """ + The validation should fail if a user specifies both a ProjectConfig and operator_args with dbt_vars/vars or env_vars/env + that overlap. 
+ """ + project_config = ProjectConfig( + project_name="fake-project", + dbt_project_path="/some/project/path", + **{project_config_arg: {"key": "value"}}, # type: ignore + ) + execution_config = ExecutionConfig() + render_config = RenderConfig() + profile_config = MagicMock() + operator_args = {operator_arg: {"key": "value"}} + + expected_error_msg = f"ProjectConfig.{project_config_arg} and operator_args with '{operator_arg}' are mutually exclusive and only one can be used." + with pytest.raises(CosmosValueError, match=expected_error_msg): + validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args) + + +def test_validate_user_config_fails_project_config_render_config_env_vars(): + """ + The validation should fail if a user specifies both ProjectConfig.env_vars and RenderConfig.env_vars. + """ + project_config = ProjectConfig(env_vars={"key": "value"}) + execution_config = ExecutionConfig() + render_config = RenderConfig(env_vars={"key": "value"}) + profile_config = MagicMock() + operator_args = {} + + expected_error_match = "Both ProjectConfig.env_vars and RenderConfig.env_vars were provided.*" + with pytest.raises(CosmosValueError, match=expected_error_match): + validate_initial_user_config(execution_config, profile_config, project_config, render_config, operator_args) + + def test_validate_arguments_schema_in_task_args(): profile_config = ProfileConfig( profile_name="test", @@ -302,3 +375,33 @@ def test_converter_fails_no_manifest_no_render_config(mock_load_dbt_graph, execu err_info.value.args[0] == "RenderConfig.dbt_project_path is required for rendering an airflow DAG from a DBT Graph if no manifest is provided." ) + + +@patch("cosmos.config.ProjectConfig.validate_project") +@patch("cosmos.converter.build_airflow_graph") +@patch("cosmos.dbt.graph.LegacyDbtProject") +def test_converter_project_config_dbt_vars_with_custom_load_mode( + mock_legacy_dbt_project, mock_validate_project, mock_build_airflow_graph +): + """Tests that if ProjectConfig.dbt_vars are used with RenderConfig.load_method of "custom" that the + expected dbt_vars are passed to LegacyDbtProject. + """ + project_config = ProjectConfig( + project_name="fake-project", dbt_project_path="/some/project/path", dbt_vars={"key": "value"} + ) + execution_config = ExecutionConfig() + render_config = RenderConfig(load_method=LoadMode.CUSTOM) + profile_config = MagicMock() + + with DAG("test-id", start_date=datetime(2022, 1, 1)) as dag: + DbtToAirflowConverter( + dag=dag, + nodes=nodes, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + operator_args={}, + ) + _, kwargs = mock_legacy_dbt_project.call_args + assert kwargs["dbt_vars"] == {"key": "value"}
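For context on what the converter test above exercises, the following is a hedged end-user sketch (not part of this patch; the DAG id, project path, and variable values are hypothetical) of ``ProjectConfig.dbt_vars`` combined with ``RenderConfig(load_method=LoadMode.CUSTOM)``:

.. code-block:: python

    from cosmos import DbtDag, LoadMode, ProfileConfig, ProjectConfig, RenderConfig
    from cosmos.profiles import PostgresUserPasswordProfileMapping

    my_dbt_dag = DbtDag(
        dag_id="my_dbt_dag",  # hypothetical DAG id
        project_config=ProjectConfig(
            dbt_project_path="/some/project/path",  # hypothetical project path
            dbt_vars={"key": "value"},  # handed to the custom parser and dumped to --vars at runtime
        ),
        render_config=RenderConfig(load_method=LoadMode.CUSTOM),
        profile_config=ProfileConfig(
            profile_name="default",
            target_name="default",
            profile_mapping=PostgresUserPasswordProfileMapping(
                conn_id="airflow_db", profile_args={"schema": "public"}
            ),
        ),
    )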