Skip to content

Commit

Permalink
Allow use of sources as unit testing inputs (#9059)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Nov 15, 2023
1 parent 436dae6 commit c6be2d2
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 63 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231111-191150.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Support source inputs in unit tests
time: 2023-11-11T19:11:50.870494-05:00
custom:
Author: gshank
Issue: "8507"
18 changes: 14 additions & 4 deletions core/dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,13 @@
from dataclasses import dataclass, field
from typing import Optional, TypeVar, Any, Type, Dict, Iterator, Tuple, Set, Union, FrozenSet

from dbt.contracts.graph.nodes import SourceDefinition, ManifestNode, ResultNode, ParsedNode
from dbt.contracts.graph.nodes import (
SourceDefinition,
ManifestNode,
ResultNode,
ParsedNode,
UnitTestSourceDefinition,
)
from dbt.contracts.relation import (
RelationType,
ComponentName,
Expand Down Expand Up @@ -201,7 +207,9 @@ def quoted(self, identifier):
)

@classmethod
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
def create_from_source(
cls: Type[Self], source: Union[SourceDefinition, UnitTestSourceDefinition], **kwargs: Any
) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
Expand Down Expand Up @@ -263,8 +271,10 @@ def create_from(
node: ResultNode,
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
if not isinstance(node, SourceDefinition):
if node.resource_type == NodeType.Source or isinstance(node, UnitTestSourceDefinition):
if not (
isinstance(node, SourceDefinition) or isinstance(node, UnitTestSourceDefinition)
):
raise DbtInternalError(
"type mismatch, expected SourceDefinition but got {}".format(type(node))
)
Expand Down
25 changes: 24 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,29 @@ def resolve(self, source_name: str, table_name: str):
return self.Relation.create_from_source(target_source)


class RuntimeUnitTestSourceResolver(BaseSourceResolver):
    """Resolve ``source()`` calls while a unit test is being run.

    The source is looked up in the manifest; the node that is found is
    registered as a CTE on the current model and returned as an ephemeral
    relation.
    """

    def resolve(self, source_name: str, table_name: str):
        """Return the ephemeral relation standing in for the requested source.

        Raises:
            TargetNotFoundError: if the manifest lookup returns ``None`` or a
                ``Disabled`` wrapper.
        """
        found = self.manifest.resolve_source(
            source_name,
            table_name,
            self.current_project,
            self.model.package_name,
        )
        is_disabled = isinstance(found, Disabled)
        if found is None or is_disabled:
            # Codecov reported this raise as not covered by tests.
            raise TargetNotFoundError(
                node=self.model,
                target_name=f"{source_name}.{table_name}",
                target_kind="source",
                disabled=is_disabled,
            )

        # In a unit test the lookup does not yield a "real" source: it is a
        # ModelNode taking the source's place. Only set_cte is strictly needed
        # here, but returning nothing confuses typing, and the relation *is*
        # needed for the "this" property.
        self.model.set_cte(found.unique_id, None)
        return self.Relation.create_ephemeral_from_node(self.config, found)


# `metric` implementations
class ParseMetricResolver(BaseMetricResolver):
def resolve(self, name: str, package: Optional[str] = None) -> MetricReference:
Expand Down Expand Up @@ -746,7 +769,7 @@ class RuntimeUnitTestProvider(Provider):
DatabaseWrapper = RuntimeDatabaseWrapper
Var = UnitTestVar
ref = RuntimeUnitTestRefResolver
source = RuntimeSourceResolver # TODO: RuntimeUnitTestSourceResolver
source = RuntimeUnitTestSourceResolver
metric = RuntimeMetricResolver


Expand Down
12 changes: 11 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,9 +1062,19 @@ def test_node_type(self):
return "generic"


@dataclass
class UnitTestSourceDefinition(ModelNode):
    """A ModelNode that carries the extra fields needed to act as a source.

    It adds the source name and quoting settings on top of ``ModelNode`` and
    exposes a source-style ``search_name``.
    """

    # Name of the source this node belongs to; "undefined" unless supplied.
    source_name: str = "undefined"
    # Quoting settings; a fresh ``Quoting()`` is created per instance by default.
    quoting: Quoting = field(default_factory=Quoting)

    @property
    def search_name(self):
        """Return the lookup key in ``<source_name>.<name>`` form."""
        return "{}.{}".format(self.source_name, self.name)


@dataclass
class UnitTestNode(CompiledNode):
resource_type: NodeType = field(metadata={"restrict": [NodeType.Unit]})
resource_type: Literal[NodeType.Unit]
tested_node_unique_id: Optional[str] = None
this_input_node_unique_id: Optional[str] = None
overrides: Optional[UnitTestOverrides] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,31 @@
{% set default_row = {} %}

{%- if not column_name_to_data_types -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- set columns_in_relation = adapter.get_columns_in_relation(this) -%}
{%- set column_name_to_data_types = {} -%}
{%- for column in columns_in_relation -%}
{%- do column_name_to_data_types.update({column.name: column.dtype}) -%}
{%- endfor -%}
{%- endif -%}

{%- if not column_name_to_data_types -%}
{{ exceptions.raise_compiler_error("columns not available for" ~ model.name) }}
{{ exceptions.raise_compiler_error("columns not available for " ~ model.name) }}
{%- endif -%}

{%- for column_name, column_type in column_name_to_data_types.items() -%}
{%- do default_row.update({column_name: (safe_cast("null", column_type) | trim )}) -%}
{%- endfor -%}

{%- for row in rows -%}
{%- do format_row(row, column_name_to_data_types) -%}

{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
{%- do format_row(row, column_name_to_data_types) -%}
{%- set default_row_copy = default_row.copy() -%}
{%- do default_row_copy.update(row) -%}
select
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
{%- for column_name, column_value in default_row_copy.items() %} {{ column_value }} AS {{ column_name }}{% if not loop.last -%}, {%- endif %}
{%- endfor %}
{%- if not loop.last %}
union all
{% endif %}
{% endif %}
{%- endfor -%}

{%- if (rows | length) == 0 -%}
Expand Down
102 changes: 59 additions & 43 deletions core/dbt/parser/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
UnitTestDefinition,
DependsOn,
UnitTestConfig,
UnitTestSourceDefinition,
)
from dbt.contracts.graph.unparsed import UnparsedUnitTest
from dbt.exceptions import ParsingError, InvalidUnitTestGivenInput
Expand Down Expand Up @@ -105,44 +106,54 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
- {id: 2, b: 2}
"""
# Add the model "input" nodes, consisting of all referenced models in the unit test.
# This creates a model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model.
# This creates an ephemeral model for every input in every test, so there may be multiple
# input models substituting for the same input ref'd model. Note that since these are
# always "ephemeral" they just wrap the tested_node SQL in additional CTEs. No actual table
# or view is created.
for given in test_case.given:
# extract the original_input_node from the ref in the "input" key of the given list
original_input_node = self._get_original_input_node(given.input, tested_node)

original_input_node_columns = None
if (
original_input_node.resource_type == NodeType.Model
and original_input_node.config.contract.enforced
):
original_input_node_columns = {
column.name: column.data_type for column in original_input_node.columns
}

# TODO: include package_name?
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_unique_id = f"model.{package_name}.{input_name}"
input_node = ModelNode(
raw_code=self._build_fixture_raw_code(
given.get_rows(
self.root_project.project_root, self.root_project.fixture_paths
),
original_input_node_columns,
project_root = self.root_project.project_root
common_fields = {
"resource_type": NodeType.Model,
"package_name": package_name,
"original_file_path": original_input_node.original_file_path,
"config": ModelConfig(materialized="ephemeral"),
"database": original_input_node.database,
"alias": original_input_node.identifier,
"schema": original_input_node.schema,
"fqn": original_input_node.fqn,
"checksum": FileHash.empty(),
"raw_code": self._build_fixture_raw_code(
given.get_rows(project_root, self.root_project.fixture_paths), None
),
resource_type=NodeType.Model,
package_name=package_name,
path=original_input_node.path,
original_file_path=original_input_node.original_file_path,
unique_id=input_unique_id,
name=input_name,
config=ModelConfig(materialized="ephemeral"),
database=original_input_node.database,
schema=original_input_node.schema,
alias=original_input_node.alias,
fqn=input_unique_id.split("."),
checksum=FileHash.empty(),
)
}

if original_input_node.resource_type == NodeType.Model:
input_name = f"{unit_test_node.name}__{original_input_node.name}"
input_node = ModelNode(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
name=input_name,
path=original_input_node.path,
)
elif original_input_node.resource_type == NodeType.Source:
# We are reusing the database/schema/identifier from the original source,
# but that shouldn't matter since this acts as an ephemeral model which just
# wraps a CTE around the unit test node.
input_name = f"{unit_test_node.name}__{original_input_node.search_name}__{original_input_node.name}"
input_node = UnitTestSourceDefinition(
**common_fields,
unique_id=f"model.{package_name}.{input_name}",
name=original_input_node.name, # must be the same name for source lookup to work
path=input_name + ".sql", # for writing out compiled_code
source_name=original_input_node.source_name, # needed for source lookup
)
# Sources need to go in the sources dictionary in order to create the right lookup
self.unit_test_manifest.sources[input_node.unique_id] = input_node # type: ignore

# Both ModelNode and UnitTestSourceDefinition need to go in nodes dictionary
self.unit_test_manifest.nodes[input_node.unique_id] = input_node

# Populate this_input_node_unique_id if input fixture represents node being tested
Expand All @@ -153,6 +164,8 @@ def parse_unit_test_case(self, test_case: UnitTestDefinition):
unit_test_node.depends_on.nodes.append(input_node.unique_id)

def _build_fixture_raw_code(self, rows, column_name_to_data_types) -> str:
# We're not currently using column_name_to_data_types, but leaving here for
# possible future use.
return ("{{{{ get_fixture_sql({rows}, {column_name_to_data_types}) }}}}").format(
rows=rows, column_name_to_data_types=column_name_to_data_types
)
Expand All @@ -178,18 +191,21 @@ def _get_original_input_node(self, input: str, tested_node: ModelNode):
raise InvalidUnitTestGivenInput(input=input)

Check warning on line 191 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L190-L191

Added lines #L190 - L191 were not covered by tests

if statically_parsed["refs"]:
for ref in statically_parsed["refs"]:
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
ref = list(statically_parsed["refs"])[0]
name = ref.get("name")
package = ref.get("package")
version = ref.get("version")
# TODO: disabled lookup, versioned lookup, public models
original_input_node = self.manifest.ref_lookup.find(
name, package, version, self.manifest
)
elif statically_parsed["sources"]:
input_package_name, input_source_name = statically_parsed["sources"][0]
source = list(statically_parsed["sources"])[0]
input_source_name, input_name = source
original_input_node = self.manifest.source_lookup.find(
input_source_name, input_package_name, self.manifest
f"{input_source_name}.{input_name}",
None,
self.manifest,
)
else:
raise InvalidUnitTestGivenInput(input=input)

Check warning on line 211 in core/dbt/parser/unit_tests.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/parser/unit_tests.py#L211

Added line #L211 was not covered by tests
Expand Down
69 changes: 69 additions & 0 deletions tests/functional/unit_testing/test_ut_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
from dbt.tests.util import run_dbt

# CSV contents for the ``raw_customers.csv`` seed used by the test class below.
# NOTE(review): every email value reads "[email protected]" -- presumably a
# redaction placeholder rather than the real fixture data; confirm against
# the repository before relying on the email column.
raw_customers_csv = """id,first_name,last_name,email
1,Michael,Perez,[email protected]
2,Shawn,Mccoy,[email protected]
3,Kathleen,Payne,[email protected]
4,Jimmy,Cooper,[email protected]
5,Katherine,Rice,[email protected]
6,Sarah,Ryan,[email protected]
7,Martin,Mcdonald,[email protected]
8,Frank,Robinson,[email protected]
9,Jennifer,Franklin,[email protected]
10,Henry,Welch,[email protected]
"""

# Schema file for the test project: declares the ``seed_sources`` source (one
# table, ``raw_customers``) and a unit test for the ``customers`` model whose
# only ``given`` input is that source.
# YAML nesting is significant here: tables, columns and tests must be indented
# under their parents, otherwise they parse as unrelated top-level keys.
schema_sources_yml = """
sources:
  - name: seed_sources
    schema: "{{ target.schema }}"
    tables:
      - name: raw_customers
        columns:
          - name: id
            tests:
              - not_null:
                  severity: "{{ 'error' if target.name == 'prod' else 'warn' }}"
              - unique
          - name: first_name
          - name: last_name
          - name: email
unit_tests:
  - name: test_customers
    model: customers
    given:
      - input: source('seed_sources', 'raw_customers')
        rows:
          - {id: 1, first_name: Emily}
    expect:
      rows:
        - {id: 1, first_name: Emily}
"""

# SQL for the ``customers`` model: selects every column from the
# ``seed_sources.raw_customers`` source via the ``source()`` Jinja call.
customers_sql = """
select * from {{ source('seed_sources', 'raw_customers') }}
"""


class TestUnitTestSourceInput:
    """Functional test: a unit test may use a source as a ``given`` input."""

    @pytest.fixture(scope="class")
    def seeds(self):
        """Seed files for the test project, keyed by file name."""
        return {
            "raw_customers.csv": raw_customers_csv,
        }

    @pytest.fixture(scope="class")
    def models(self):
        """Model and schema files for the test project, keyed by file name."""
        return {
            "customers.sql": customers_sql,
            "sources.yml": schema_sources_yml,
        }

    def test_source_input(self, project):
        """Seed, build the model, then run the unit test that reads a source."""
        # The seed result is not inspected; run_dbt is called for its effect.
        run_dbt(["seed"])

        results = run_dbt(["run"])
        # Must be a real assertion: a bare comparison is evaluated and discarded.
        assert len(results) == 1

        results = run_dbt(["unit-test"])
        assert len(results) == 1

0 comments on commit c6be2d2

Please sign in to comment.