Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

refactor(ingest/stateful): remove get_last_state method #6794

Merged
merged 1 commit into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions metadata-ingestion/src/datahub/ingestion/api/committable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ def commit(self) -> None:
pass


StateKeyType = TypeVar("StateKeyType")
StateType = TypeVar("StateType")


class StatefulCommittable(
Committable,
Generic[StateKeyType, StateType],
Generic[StateType],
):
def __init__(
self, name: str, commit_policy: CommitPolicy, state_to_commit: StateType
Expand All @@ -37,7 +36,3 @@ def __init__(

def has_successfully_committed(self) -> bool:
return bool(not self.state_to_commit or self.committed)

@abstractmethod
def get_last_state(self, state_key: StateKeyType) -> StateType:
pass
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Type, TypeVar
from typing import Any, Dict, NewType, Type, TypeVar

import datahub.emitter.mce_builder as builder
from datahub.configuration.common import ConfigModel
Expand All @@ -13,13 +13,6 @@
CheckpointJobStatesMap = Dict[JobId, CheckpointJobStateType]


@dataclass
class JobStateKey:
pipeline_name: str
platform_instance_id: str
job_names: List[JobId]


class IngestionCheckpointingProviderConfig(ConfigModel):
pass

Expand All @@ -28,9 +21,7 @@ class IngestionCheckpointingProviderConfig(ConfigModel):


@dataclass()
class IngestionCheckpointingProviderBase(
StatefulCommittable[JobStateKey, CheckpointJobStatesMap]
):
class IngestionCheckpointingProviderBase(StatefulCommittable[CheckpointJobStatesMap]):
"""
The base class for all checkpointing state provider implementations.
"""
Expand All @@ -42,21 +33,15 @@ def __init__(
super().__init__(name, commit_policy, {})

@classmethod
@abstractmethod
def create(
cls: Type[_Self], config_dict: Dict[str, Any], ctx: PipelineContext, name: str
) -> "_Self":
raise NotImplementedError("Sub-classes must override this method.")

@abstractmethod
def get_last_state(
self,
state_key: JobStateKey,
) -> CheckpointJobStatesMap:
...
pass

@abstractmethod
def commit(self) -> None:
...
pass

@staticmethod
def get_data_job_urn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def create_from_checkpoint_aspect(
raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}")
except Exception as e:
logger.error(
"Failed to construct checkpoint class from checkpoint aspect.", e
f"Failed to construct checkpoint class from checkpoint aspect: {e}"
)
raise e
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
CheckpointJobStatesMap,
IngestionCheckpointingProviderBase,
IngestionCheckpointingProviderConfig,
JobId,
JobStateKey,
)
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.schema_classes import (
Expand Down Expand Up @@ -102,20 +100,6 @@ def get_latest_checkpoint(

return None

def get_last_state(
self,
state_key: JobStateKey,
) -> CheckpointJobStatesMap:
last_job_checkpoint_map: CheckpointJobStatesMap = {}
for job_name in state_key.job_names:
last_job_checkpoint = self.get_latest_checkpoint(
state_key.pipeline_name, state_key.platform_instance_id, job_name
)
if last_job_checkpoint is not None:
last_job_checkpoint_map[job_name] = last_job_checkpoint

return last_job_checkpoint_map

def commit(self) -> None:
if not self.state_to_commit:
logger.warning(f"No state available to commit for {self.name}")
Expand Down
8 changes: 5 additions & 3 deletions metadata-ingestion/tests/test_helpers/state_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from avrogen.dict_wrapper import DictWrapper

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.committable import StatefulCommittable
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
IngestionCheckpointingProviderBase,
)
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.run.pipeline import Pipeline

Expand All @@ -21,8 +23,8 @@ def validate_all_providers_have_committed_successfully(
provider_count: int = 0
for _, provider in pipeline.ctx.get_committables():
provider_count += 1
assert isinstance(provider, StatefulCommittable)
stateful_committable = cast(StatefulCommittable, provider)
assert isinstance(provider, IngestionCheckpointingProviderBase)
stateful_committable = cast(IngestionCheckpointingProviderBase, provider)
assert stateful_committable.has_successfully_committed()
assert stateful_committable.state_to_commit
assert provider_count == expected_providers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import (
CheckpointJobStatesMap,
CheckpointJobStateType,
IngestionCheckpointingProviderBase,
JobId,
JobStateKey,
)
from datahub.ingestion.source.state.checkpoint import Checkpoint
from datahub.ingestion.source.state.sql_common_state import (
Expand All @@ -31,11 +28,6 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase):
platform_instance_id: str = "test_platform_instance_1"
job_names: List[JobId] = [JobId("job1"), JobId("job2")]
run_id: str = "test_run"
job_state_key: JobStateKey = JobStateKey(
pipeline_name=pipeline_name,
platform_instance_id=platform_instance_id,
job_names=job_names,
)

def setUp(self) -> None:
self._setup_mock_graph()
Expand Down Expand Up @@ -64,7 +56,7 @@ def _setup_mock_graph(self) -> None:
# Tracking for emitted mcps.
self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {}

def _create_provider(self) -> IngestionCheckpointingProviderBase:
def _create_provider(self) -> DatahubIngestionCheckpointingProvider:
ctx: PipelineContext = PipelineContext(
run_id=self.run_id, pipeline_name=self.pipeline_name
)
Expand Down Expand Up @@ -153,26 +145,27 @@ def test_provider(self):

# 4. Get last committed state. This must match what has been committed earlier.
# NOTE: This will retrieve from in-memory self.mcps_emitted because of the monkey-patching.
last_state: Optional[CheckpointJobStatesMap] = self.provider.get_last_state(
self.job_state_key
job1_last_state = self.provider.get_latest_checkpoint(
self.pipeline_name, self.platform_instance_id, self.job_names[0]
)
job2_last_state = self.provider.get_latest_checkpoint(
self.pipeline_name, self.platform_instance_id, self.job_names[1]
)
assert last_state is not None
self.assertEqual(len(last_state), 2)

# 5. Validate individual job checkpoint state values that have been committed and retrieved
# against the original values.
self.assertIsNotNone(last_state[self.job_names[0]])
self.assertIsNotNone(job1_last_state)
job1_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
job_name=self.job_names[0],
checkpoint_aspect=last_state[self.job_names[0]],
checkpoint_aspect=job1_last_state,
state_class=type(job1_state_obj),
)
self.assertEqual(job1_last_checkpoint, job1_checkpoint)

self.assertIsNotNone(last_state[self.job_names[1]])
self.assertIsNotNone(job2_last_state)
job2_last_checkpoint = Checkpoint.create_from_checkpoint_aspect(
job_name=self.job_names[1],
checkpoint_aspect=last_state[self.job_names[1]],
checkpoint_aspect=job2_last_state,
state_class=type(job2_state_obj),
)
self.assertEqual(job2_last_checkpoint, job2_checkpoint)