fix(dbt): fix issue of assertion error when stateful ingestion is used with dbt tests (#5540)

* fix(dbt): fix issue of dbt stateful ingestion with tests

Co-authored-by: MugdhaHardikar-GSLab <[email protected]>
Co-authored-by: MohdSiddique Bagwan <[email protected]>
Co-authored-by: Ravindra Lanka <[email protected]>
4 people authored Aug 3, 2022
1 parent 0cbcaf3 commit f1abdc9
Showing 8 changed files with 3,684 additions and 70 deletions.
2 changes: 0 additions & 2 deletions metadata-ingestion/src/datahub/emitter/mce_builder.py
@@ -216,15 +216,13 @@ def make_domain_urn(domain: str) -> str:


def make_ml_primary_key_urn(feature_table_name: str, primary_key_name: str) -> str:

return f"urn:li:mlPrimaryKey:({feature_table_name},{primary_key_name})"


def make_ml_feature_urn(
feature_table_name: str,
feature_name: str,
) -> str:

return f"urn:li:mlFeature:({feature_table_name},{feature_name})"


113 changes: 80 additions & 33 deletions metadata-ingestion/src/datahub/ingestion/source/dbt.py
@@ -47,6 +47,7 @@
resolve_trino_modified_type,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.dbt_state import DbtCheckpointState
from datahub.ingestion.source.state.sql_common_state import (
BaseSQLAlchemyCheckpointState,
)
@@ -1005,15 +1006,48 @@ def __init__(self, config: DBTConfig, ctx: PipelineContext, platform: str):
self.config.owner_extraction_pattern
)

def get_last_dbt_checkpoint(
self, job_id: JobId, checkpoint_state_class: Type[DbtCheckpointState]
) -> Optional[Checkpoint]:

last_checkpoint: Optional[Checkpoint]
is_conversion_required: bool = False
try:
# Best case: the last checkpoint state is already a DbtCheckpointState
last_checkpoint = self.get_last_checkpoint(job_id, checkpoint_state_class)
except Exception as e:
# Backward compatibility for the old dbt ingestion source, which saved dbt nodes in
# BaseSQLAlchemyCheckpointState
last_checkpoint = self.get_last_checkpoint(
job_id, BaseSQLAlchemyCheckpointState
)
logger.debug(
f"Found BaseSQLAlchemyCheckpointState as checkpoint state (got {e})."
)
is_conversion_required = True

if last_checkpoint is not None and is_conversion_required:
# Map the BaseSQLAlchemyCheckpointState to DbtCheckpointState
dbt_checkpoint_state: DbtCheckpointState = DbtCheckpointState()
dbt_checkpoint_state.encoded_node_urns = (
cast(BaseSQLAlchemyCheckpointState, last_checkpoint.state)
).encoded_table_urns
# The old dbt source did not support assertions
dbt_checkpoint_state.encoded_assertion_urns = []
last_checkpoint.state = dbt_checkpoint_state

return last_checkpoint
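
For orientation, a minimal usage sketch of the new helper from the caller's side, assuming a configured dbt source instance (named `source` here purely for illustration): whichever state class the previous run serialized, the checkpoint that comes back carries a DbtCheckpointState.

from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

# Hypothetical caller; `source` stands in for a configured dbt source instance.
last = source.get_last_dbt_checkpoint(
    source.get_default_ingestion_job_id(), DbtCheckpointState
)
if last is not None:
    # Old BaseSQLAlchemyCheckpointState runs are converted transparently above.
    assert isinstance(last.state, DbtCheckpointState)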

# TODO: Consider refactoring this logic out for use across sources as it is leading to a significant amount of
# code duplication.
def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]:
last_checkpoint = self.get_last_checkpoint(
self.get_default_ingestion_job_id(), BaseSQLAlchemyCheckpointState
last_checkpoint: Optional[Checkpoint] = self.get_last_dbt_checkpoint(
self.get_default_ingestion_job_id(), DbtCheckpointState
)
cur_checkpoint = self.get_current_checkpoint(
self.get_default_ingestion_job_id()
)

if (
self.config.stateful_ingestion
and self.config.stateful_ingestion.remove_stale_metadata
@@ -1024,7 +1058,7 @@ def gen_removed_entity_workunits(self) -> Iterable[MetadataWorkUnit]:
):
logger.debug("Checking for stale entity removal.")

def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]:
def get_soft_delete_item_workunit(urn: str, type: str) -> MetadataWorkUnit:

logger.info(f"Soft-deleting stale entity of type {type} - {urn}.")
mcp = MetadataChangeProposalWrapper(
@@ -1037,19 +1071,28 @@ def soft_delete_item(urn: str, type: str) -> Iterable[MetadataWorkUnit]:
wu = MetadataWorkUnit(id=f"soft-delete-{type}-{urn}", mcp=mcp)
self.report.report_workunit(wu)
self.report.report_stale_entity_soft_deleted(urn)
yield wu
return wu

last_checkpoint_state = cast(
BaseSQLAlchemyCheckpointState, last_checkpoint.state
)
cur_checkpoint_state = cast(
BaseSQLAlchemyCheckpointState, cur_checkpoint.state
)
last_checkpoint_state = cast(DbtCheckpointState, last_checkpoint.state)
cur_checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)

for table_urn in last_checkpoint_state.get_table_urns_not_in(
cur_checkpoint_state
):
yield from soft_delete_item(table_urn, "dataset")
urns_to_soft_delete_by_type: Dict = {
"dataset": [
node_urn
for node_urn in last_checkpoint_state.get_node_urns_not_in(
cur_checkpoint_state
)
],
"assertion": [
assertion_urn
for assertion_urn in last_checkpoint_state.get_assertion_urns_not_in(
cur_checkpoint_state
)
],
}
for entity_type in urns_to_soft_delete_by_type:
for urn in urns_to_soft_delete_by_type[entity_type]:
yield get_soft_delete_item_workunit(urn, entity_type)
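
To make the removal logic concrete, a small hedged example with made-up URNs: an entity recorded in the last checkpoint but absent from the current one is reported as stale and soft-deleted.

from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

# Illustrative only; the URN values are hypothetical.
last_state = DbtCheckpointState()
last_state.add_node_urn(
    "urn:li:dataset:(urn:li:dataPlatform:dbt,db.schema.old_model,PROD)"
)
last_state.add_assertion_urn("urn:li:assertion:abc123")

cur_state = DbtCheckpointState()  # the current run re-ingested neither entity

stale_datasets = list(last_state.get_node_urns_not_in(cur_state))
stale_assertions = list(last_state.get_assertion_urns_not_in(cur_state))
# Each stale URN is then wrapped in a soft-delete workunit as shown above.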

def load_file_as_json(self, uri: str) -> Any:
if re.match("^https?://", uri):
@@ -1155,7 +1198,7 @@ def string_map(input_map: Dict[str, Any]) -> Dict[str, str]:
}
)
)
self.save_checkpoint(node_datahub_urn)
self.save_checkpoint(node_datahub_urn, "assertion")

dpi_mcp = MetadataChangeProposalWrapper(
entityType="assertion",
@@ -1412,10 +1455,12 @@ def remove_duplicate_urns_from_checkpoint_state(self) -> None:
)

if cur_checkpoint is not None:
# Utilizing BaseSQLAlchemyCheckpointState class to save state
checkpoint_state = cast(BaseSQLAlchemyCheckpointState, cur_checkpoint.state)
checkpoint_state.encoded_table_urns = list(
set(checkpoint_state.encoded_table_urns)
checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)
checkpoint_state.encoded_node_urns = list(
set(checkpoint_state.encoded_node_urns)
)
checkpoint_state.encoded_assertion_urns = list(
set(checkpoint_state.encoded_assertion_urns)
)
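
The dedup itself is a plain set() round-trip on the encoded URN lists; a short sketch on a hypothetical state (ordering is not preserved, which is harmless for the set-difference checks used during stale-entity removal).

from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

state = DbtCheckpointState()
state.encoded_assertion_urns = ["abc123", "abc123", "def456"]  # made-up keys
state.encoded_assertion_urns = list(set(state.encoded_assertion_urns))
# -> two unique entries remain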

def create_platform_mces(
@@ -1458,7 +1503,7 @@ def create_platform_mces(
self.config.env,
mce_platform_instance,
)
self.save_checkpoint(node_datahub_urn)
self.save_checkpoint(node_datahub_urn, "dataset")

meta_aspects: Dict[str, Any] = {}
if self.config.enable_meta_mapping and node.meta:
@@ -1534,18 +1579,21 @@ def create_platform_mces(
self.report.report_workunit(wu)
yield wu

def save_checkpoint(self, node_datahub_urn: str) -> None:
if self.is_stateful_ingestion_configured():
cur_checkpoint = self.get_current_checkpoint(
self.get_default_ingestion_job_id()
)
def save_checkpoint(self, urn: str, entity_type: str) -> None:
# if stateful ingestion is not configured then return
if not self.is_stateful_ingestion_configured():
return

if cur_checkpoint is not None:
# Utilizing BaseSQLAlchemyCheckpointState class to save state
checkpoint_state = cast(
BaseSQLAlchemyCheckpointState, cur_checkpoint.state
)
checkpoint_state.add_table_urn(node_datahub_urn)
cur_checkpoint = self.get_current_checkpoint(
self.get_default_ingestion_job_id()
)
# if no checkpoint found then return
if cur_checkpoint is None:
return

# Cast and set the state
checkpoint_state = cast(DbtCheckpointState, cur_checkpoint.state)
checkpoint_state.set_checkpoint_urn(urn, entity_type)
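
A hedged standalone sketch of the dispatch this relies on: the entity_type string selects the handler in DbtCheckpointState.set_checkpoint_urn (the URN values below are made up).

from datahub.ingestion.source.state.dbt_state import DbtCheckpointState

state = DbtCheckpointState()
state.set_checkpoint_urn("urn:li:assertion:abc123", "assertion")
state.set_checkpoint_urn(
    "urn:li:dataset:(urn:li:dataPlatform:dbt,db.schema.my_model,PROD)", "dataset"
)
# "dataset" routes to add_node_urn, "assertion" to add_assertion_urn (see dbt_state.py below)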

def extract_query_tag_aspects(
self,
@@ -1900,8 +1948,7 @@ def create_checkpoint(self, job_id: JobId) -> Optional[Checkpoint]:
platform_instance_id=self.get_platform_instance_id(),
run_id=self.ctx.run_id,
config=self.config,
# Reusing BaseSQLAlchemyCheckpointState as it has needed functionality to support statefulness of DBT
state=BaseSQLAlchemyCheckpointState(),
state=DbtCheckpointState(),
)
return None

70 changes: 70 additions & 0 deletions metadata-ingestion/src/datahub/ingestion/source/state/dbt_state.py
@@ -0,0 +1,70 @@
import logging
from typing import Callable, Dict, Iterable, List

import pydantic

from datahub.emitter.mce_builder import make_assertion_urn
from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
from datahub.utilities.checkpoint_state_util import CheckpointStateUtil
from datahub.utilities.urns.urn import Urn

logger = logging.getLogger(__name__)


class DbtCheckpointState(CheckpointStateBase):
"""
Class for representing the checkpoint state for DBT sources.
Stores all nodes and assertions being ingested and is used to remove any stale entities.
"""

encoded_node_urns: List[str] = pydantic.Field(default_factory=list)
encoded_assertion_urns: List[str] = pydantic.Field(default_factory=list)

@staticmethod
def _get_assertion_lightweight_repr(assertion_urn: str) -> str:
"""Reduces the amount of text in the URNs for smaller state footprint."""
urn = Urn.create_from_string(assertion_urn)
key = urn.get_entity_id_as_string()
assert key is not None
return key

def add_assertion_urn(self, assertion_urn: str) -> None:
self.encoded_assertion_urns.append(
self._get_assertion_lightweight_repr(assertion_urn)
)

def get_assertion_urns_not_in(
self, checkpoint: "DbtCheckpointState"
) -> Iterable[str]:
"""
Dbt assertion are mapped to DataHub assertion concept
"""
difference = CheckpointStateUtil.get_encoded_urns_not_in(
self.encoded_assertion_urns, checkpoint.encoded_assertion_urns
)
for key in difference:
yield make_assertion_urn(key)

def get_node_urns_not_in(self, checkpoint: "DbtCheckpointState") -> Iterable[str]:
"""
Dbt node are mapped to DataHub dataset concept
"""
yield from CheckpointStateUtil.get_dataset_urns_not_in(
self.encoded_node_urns, checkpoint.encoded_node_urns
)

def add_node_urn(self, node_urn: str) -> None:
self.encoded_node_urns.append(
CheckpointStateUtil.get_dataset_lightweight_repr(node_urn)
)

def set_checkpoint_urn(self, urn: str, entity_type: str) -> None:
supported_entities_add_handlers: Dict[str, Callable[[str], None]] = {
"dataset": self.add_node_urn,
"assertion": self.add_assertion_urn,
}

if entity_type not in supported_entities_add_handlers:
logger.error(f"Cannot save unknown entity {entity_type} to checkpoint.")
return

supported_entities_add_handlers[entity_type](urn)
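
The lightweight representation keeps only the assertion key; a hedged round-trip sketch (the assertion URN value is made up):

from datahub.emitter.mce_builder import make_assertion_urn
from datahub.utilities.urns.urn import Urn

assertion_urn = "urn:li:assertion:1234abcd"  # hypothetical
key = Urn.create_from_string(assertion_urn).get_entity_id_as_string()  # -> "1234abcd"
assert make_assertion_urn(key) == assertion_urn  # compact form should round-trip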
metadata-ingestion/src/datahub/ingestion/source/state/sql_common_state.py
@@ -2,13 +2,9 @@

import pydantic

from datahub.emitter.mce_builder import (
container_urn_to_key,
dataset_urn_to_key,
make_container_urn,
make_dataset_urn,
)
from datahub.emitter.mce_builder import container_urn_to_key, make_container_urn
from datahub.ingestion.source.state.checkpoint import CheckpointStateBase
from datahub.utilities.checkpoint_state_util import CheckpointStateUtil


class BaseSQLAlchemyCheckpointState(CheckpointStateBase):
@@ -21,19 +17,12 @@ class BaseSQLAlchemyCheckpointState(CheckpointStateBase):
encoded_table_urns: List[str] = pydantic.Field(default_factory=list)
encoded_view_urns: List[str] = pydantic.Field(default_factory=list)
encoded_container_urns: List[str] = pydantic.Field(default_factory=list)

@staticmethod
def _get_separator() -> str:
# Unique small string not allowed in URNs.
return "||"
encoded_assertion_urns: List[str] = pydantic.Field(default_factory=list)

@staticmethod
def _get_lightweight_repr(dataset_urn: str) -> str:
"""Reduces the amount of text in the URNs for smaller state footprint."""
SEP = BaseSQLAlchemyCheckpointState._get_separator()
key = dataset_urn_to_key(dataset_urn)
assert key is not None
return f"{key.platform}{SEP}{key.name}{SEP}{key.origin}"
return CheckpointStateUtil.get_dataset_lightweight_repr(dataset_urn)

@staticmethod
def _get_container_lightweight_repr(container_urn: str) -> str:
@@ -42,36 +31,29 @@ def _get_container_lightweight_repr(container_urn: str) -> str:
assert key is not None
return f"{key.guid}"

@staticmethod
def _get_dataset_urns_not_in(
encoded_urns_1: List[str], encoded_urns_2: List[str]
) -> Iterable[str]:
difference = set(encoded_urns_1) - set(encoded_urns_2)
for encoded_urn in difference:
platform, name, env = encoded_urn.split(
BaseSQLAlchemyCheckpointState._get_separator()
)
yield make_dataset_urn(platform, name, env)

@staticmethod
def _get_container_urns_not_in(
encoded_urns_1: List[str], encoded_urns_2: List[str]
) -> Iterable[str]:
difference = set(encoded_urns_1) - set(encoded_urns_2)
difference = CheckpointStateUtil.get_encoded_urns_not_in(
encoded_urns_1, encoded_urns_2
)
for guid in difference:
yield make_container_urn(guid)

def get_table_urns_not_in(
self, checkpoint: "BaseSQLAlchemyCheckpointState"
) -> Iterable[str]:
yield from self._get_dataset_urns_not_in(
"""Tables are mapped to DataHub dataset concept."""
yield from CheckpointStateUtil.get_dataset_urns_not_in(
self.encoded_table_urns, checkpoint.encoded_table_urns
)

def get_view_urns_not_in(
self, checkpoint: "BaseSQLAlchemyCheckpointState"
) -> Iterable[str]:
yield from self._get_dataset_urns_not_in(
"""Views are mapped to DataHub dataset concept."""
yield from CheckpointStateUtil.get_dataset_urns_not_in(
self.encoded_view_urns, checkpoint.encoded_view_urns
)
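
Both checkpoint state classes now share the compact dataset encoding via CheckpointStateUtil. A hedged round-trip sketch, assuming the util keeps the "platform||name||env" encoding of the removed helper (the dataset URN is made up):

from datahub.utilities.checkpoint_state_util import CheckpointStateUtil

urn = "urn:li:dataset:(urn:li:dataPlatform:dbt,db.schema.my_model,PROD)"
compact = CheckpointStateUtil.get_dataset_lightweight_repr(urn)
restored = list(CheckpointStateUtil.get_dataset_urns_not_in([compact], []))
# restored should contain the original URN, re-expanded from the compact form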

metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py
@@ -232,7 +232,7 @@ def get_last_checkpoint(
):
return None

if JobId not in self.last_checkpoints:
if job_id not in self.last_checkpoints:
self.last_checkpoints[job_id] = self._get_last_checkpoint(
job_id, checkpoint_state_class
)
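
The one-character change above fixes a lookup against the wrong name: JobId is the type, job_id is the key, so the cached checkpoint was recomputed on every call. A standalone sketch of the pitfall (the JobId definition here is an assumption for illustration):

from typing import Dict, NewType

JobId = NewType("JobId", str)  # assumed str-based NewType, illustration only
last_checkpoints: Dict[JobId, str] = {}

def get_checkpoint(job_id: JobId) -> str:
    if JobId not in last_checkpoints:  # bug: tests the type object, always True
        last_checkpoints[job_id] = "recomputed checkpoint"
    return last_checkpoints[job_id]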
(diffs for the remaining changed files were not rendered on this page)
