Skip to content

Commit

Permalink
Merge branch 'main' into bigquery-keyfile-dict-use-env-vars
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana authored Aug 16, 2023
2 parents b6b957a + c676429 commit a08455d
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 20 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha || github.ref }}
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
Expand Down
41 changes: 25 additions & 16 deletions cosmos/dbt/selector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
import copy

from typing import TYPE_CHECKING

Expand Down Expand Up @@ -43,6 +44,10 @@ def __init__(self, project_dir: Path, statement: str):
self.other: list[str] = []
self.load_from_statement(statement)

@property
def is_empty(self) -> bool:
return not (self.paths or self.tags or self.config or self.other)

def load_from_statement(self, statement: str) -> None:
"""
Load in-place select parameters.
Expand Down Expand Up @@ -84,27 +89,30 @@ def select_nodes_ids_by_intersection(nodes: dict[str, DbtNode], config: Selector
https://docs.getdbt.com/reference/node-selection/yaml-selectors
"""
selected_nodes = set()
for node_id, node in nodes.items():
if config.tags and not (sorted(node.tags) == sorted(config.tags)):
continue

supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
if config.config:
config_tag = config.config.get("tags")
if config_tag and config_tag not in supported_node_config.get("tags", []):
if not config.is_empty:
for node_id, node in nodes.items():
if config.tags and not (sorted(node.tags) == sorted(config.tags)):
continue

# Remove 'tags' as they've already been filtered for
config.config.pop("tags", None)
supported_node_config.pop("tags", None)
supported_node_config = {key: value for key, value in node.config.items() if key in SUPPORTED_CONFIG}
config_tag = config.config.get("tags")
if config.config:
if config_tag and config_tag not in supported_node_config.get("tags", []):
continue

# Remove 'tags' as they've already been filtered for
config_copy = copy.deepcopy(config.config)
config_copy.pop("tags", None)
supported_node_config.pop("tags", None)

if not (config.config.items() <= supported_node_config.items()):
continue
if not (config_copy.items() <= supported_node_config.items()):
continue

if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))):
continue
if config.paths and not (set(config.paths).issubset(set(node.file_path.parents))):
continue

selected_nodes.add(node_id)
selected_nodes.add(node_id)

return selected_nodes

Expand Down Expand Up @@ -166,9 +174,10 @@ def select_nodes(

nodes_ids = set(nodes.keys())

exclude_ids: set[str] = set()
for statement in exclude:
config = SelectorConfig(project_dir, statement)
exclude_ids = select_nodes_ids_by_intersection(nodes, config)
exclude_ids = exclude_ids.union(set(select_nodes_ids_by_intersection(nodes, config)))
subset_ids = set(nodes_ids) - set(exclude_ids)

return {id_: nodes[id_] for id_ in subset_ids}
5 changes: 5 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DbtBaseOperator(BaseOperator):
:param dbt_executable_path: Path to dbt executable can be used with venv
(i.e. /home/astro/.pyenv/versions/dbt_venv/bin/dbt)
:param dbt_cmd_flags: List of flags to pass to dbt command
:param dbt_cmd_global_flags: List of dbt global flags to be passed to the dbt command
"""

template_fields: Sequence[str] = ("env", "vars")
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
cancel_query_on_kill: bool = True,
dbt_executable_path: str = "dbt",
dbt_cmd_flags: list[str] | None = None,
dbt_cmd_global_flags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.project_dir = project_dir
Expand Down Expand Up @@ -132,6 +134,7 @@ def __init__(
else:
self.dbt_executable_path = dbt_executable_path
self.dbt_cmd_flags = dbt_cmd_flags
self.dbt_cmd_global_flags = dbt_cmd_global_flags or []
super().__init__(**kwargs)

def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]:
Expand Down Expand Up @@ -210,6 +213,8 @@ def build_cmd(
) -> Tuple[list[str | None], dict[str, str | bytes | os.PathLike[Any]]]:
dbt_cmd = [self.dbt_executable_path]

dbt_cmd.extend(self.dbt_cmd_global_flags)

if self.base_cmd:
dbt_cmd.extend(self.base_cmd)

Expand Down
96 changes: 93 additions & 3 deletions tests/dbt/test_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,42 @@

import pytest

from cosmos.dbt.selector import SelectorConfig
from cosmos.constants import DbtResourceType
from cosmos.dbt.graph import DbtNode
from cosmos.dbt.selector import select_nodes
from cosmos.exceptions import CosmosValueError

SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/")


@pytest.fixture
def selector_config():
project_dir = Path("/path/to/project")
statement = ""
return SelectorConfig(project_dir, statement)


@pytest.mark.parametrize(
"paths, tags, config, other, expected",
[
([], [], {}, [], True),
([Path("path1")], [], {}, [], False),
([], ["tag:has_child"], {}, [], False),
([], [], {"config.tags:test"}, [], False),
([], [], {}, ["other"], False),
([Path("path1")], ["tag:has_child"], {"config.tags:test"}, ["other"], False),
],
)
def test_is_empty_config(selector_config, paths, tags, config, other, expected):
selector_config.paths = paths
selector_config.tags = tags
selector_config.config = config
selector_config.other = other

assert selector_config.is_empty == expected


grandparent_node = DbtNode(
name="grandparent",
unique_id="grandparent",
Expand Down Expand Up @@ -37,10 +66,32 @@
config={"materialized": "table", "tags": ["is_child"]},
)

grandchild_1_test_node = DbtNode(
name="grandchild_1",
unique_id="grandchild_1",
resource_type=DbtResourceType.MODEL,
depends_on=["parent"],
file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_1.sql",
tags=["nightly"],
config={"materialized": "table", "tags": ["deprecated", "test"]},
)

grandchild_2_test_node = DbtNode(
name="grandchild_2",
unique_id="grandchild_2",
resource_type=DbtResourceType.MODEL,
depends_on=["parent"],
file_path=SAMPLE_PROJ_PATH / "gen3/models/grandchild_2.sql",
tags=["nightly"],
config={"materialized": "table", "tags": ["deprecated", "test2"]},
)

sample_nodes = {
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}


Expand All @@ -52,13 +103,19 @@ def test_select_nodes_by_select_tag():

def test_select_nodes_by_select_config():
selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.materialized:table"])
expected = {child_node.unique_id: child_node}
expected = {
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}
assert selected == expected


def test_select_nodes_by_select_config_tag():
selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child"])
expected = {child_node.unique_id: child_node}
expected = {
child_node.unique_id: child_node,
}
assert selected == expected


Expand All @@ -74,6 +131,21 @@ def test_select_nodes_by_select_union_config_tag():
assert selected == expected


def test_select_nodes_by_select_union_config_test_tags():
selected = select_nodes(
project_dir=SAMPLE_PROJ_PATH,
nodes=sample_nodes,
select=["config.tags:test", "config.tags:test2", "config.materialized:view"],
)
expected = {
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}
assert selected == expected


def test_select_nodes_by_select_intersection_config_tag():
selected = select_nodes(
project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, select=["config.tags:is_child,config.materialized:view"]
Expand All @@ -95,6 +167,8 @@ def test_select_nodes_by_select_union():
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}
assert selected == expected

Expand All @@ -106,7 +180,11 @@ def test_select_nodes_by_select_intersection():

def test_select_nodes_by_exclude_tag():
selected = select_nodes(project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["tag:has_child"])
expected = {child_node.unique_id: child_node}
expected = {
child_node.unique_id: child_node,
grandchild_1_test_node.unique_id: grandchild_1_test_node,
grandchild_2_test_node.unique_id: grandchild_2_test_node,
}
assert selected == expected


Expand All @@ -122,3 +200,15 @@ def test_select_nodes_by_select_union_exclude_tags():
)
expected = {}
assert selected == expected


def test_select_nodes_by_exclude_union_config_test_tags():
selected = select_nodes(
project_dir=SAMPLE_PROJ_PATH, nodes=sample_nodes, exclude=["config.tags:test", "config.tags:test2"]
)
expected = {
grandparent_node.unique_id: grandparent_node,
parent_node.unique_id: parent_node,
child_node.unique_id: child_node,
}
assert selected == expected
19 changes: 18 additions & 1 deletion tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,24 @@ def test_dbt_base_operator_add_user_supplied_flags() -> None:
cmd, _ = dbt_base_operator.build_cmd(
Context(execution_date=datetime(2023, 2, 15, 12, 30)),
)
assert "--full-refresh" in cmd
assert cmd[-2] == "run"
assert cmd[-1] == "--full-refresh"


def test_dbt_base_operator_add_user_supplied_global_flags() -> None:
dbt_base_operator = DbtLocalBaseOperator(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
base_cmd=["run"],
dbt_cmd_global_flags=["--cache-selected-only"],
)

cmd, _ = dbt_base_operator.build_cmd(
Context(execution_date=datetime(2023, 2, 15, 12, 30)),
)
assert cmd[-2] == "--cache-selected-only"
assert cmd[-1] == "run"


@pytest.mark.parametrize(
Expand Down

0 comments on commit a08455d

Please sign in to comment.