From aea5f6019a4ab1f7c2025eee01b1ff2669e5659a Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Fri, 16 Dec 2022 17:31:15 -0500 Subject: [PATCH] feat(ingest): remove `get_last_state` method from stateful ingestion --- .../src/datahub/ingestion/api/committable.py | 7 +---- ...gestion_job_checkpointing_provider_base.py | 25 ++++------------- .../ingestion/source/state/checkpoint.py | 2 +- ...atahub_ingestion_checkpointing_provider.py | 16 ----------- .../tests/test_helpers/state_helpers.py | 8 +++--- ...atahub_ingestion_checkpointing_provider.py | 27 +++++++------------ 6 files changed, 22 insertions(+), 63 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/api/committable.py b/metadata-ingestion/src/datahub/ingestion/api/committable.py index b9fa3d92e48c6..cc7d74469f2b3 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/committable.py +++ b/metadata-ingestion/src/datahub/ingestion/api/committable.py @@ -20,13 +20,12 @@ def commit(self) -> None: pass -StateKeyType = TypeVar("StateKeyType") StateType = TypeVar("StateType") class StatefulCommittable( Committable, - Generic[StateKeyType, StateType], + Generic[StateType], ): def __init__( self, name: str, commit_policy: CommitPolicy, state_to_commit: StateType @@ -37,7 +36,3 @@ def __init__( def has_successfully_committed(self) -> bool: return bool(not self.state_to_commit or self.committed) - - @abstractmethod - def get_last_state(self, state_key: StateKeyType) -> StateType: - pass diff --git a/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_checkpointing_provider_base.py b/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_checkpointing_provider_base.py index aaf90bfcdd5f0..ceb776089d24e 100644 --- a/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_checkpointing_provider_base.py +++ b/metadata-ingestion/src/datahub/ingestion/api/ingestion_job_checkpointing_provider_base.py @@ -1,6 +1,6 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, NewType, Type, TypeVar +from typing import Any, Dict, NewType, Type, TypeVar import datahub.emitter.mce_builder as builder from datahub.configuration.common import ConfigModel @@ -13,13 +13,6 @@ CheckpointJobStatesMap = Dict[JobId, CheckpointJobStateType] -@dataclass -class JobStateKey: - pipeline_name: str - platform_instance_id: str - job_names: List[JobId] - - class IngestionCheckpointingProviderConfig(ConfigModel): pass @@ -28,9 +21,7 @@ class IngestionCheckpointingProviderConfig(ConfigModel): @dataclass() -class IngestionCheckpointingProviderBase( - StatefulCommittable[JobStateKey, CheckpointJobStatesMap] -): +class IngestionCheckpointingProviderBase(StatefulCommittable[CheckpointJobStatesMap]): """ The base class for all checkpointing state provider implementations. """ @@ -42,21 +33,15 @@ def __init__( super().__init__(name, commit_policy, {}) @classmethod + @abstractmethod def create( cls: Type[_Self], config_dict: Dict[str, Any], ctx: PipelineContext, name: str ) -> "_Self": - raise NotImplementedError("Sub-classes must override this method.") - - @abstractmethod - def get_last_state( - self, - state_key: JobStateKey, - ) -> CheckpointJobStatesMap: - ... + pass @abstractmethod def commit(self) -> None: - ... + pass @staticmethod def get_data_job_urn( diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py index 0c91f431cc387..837e3e68fb2cd 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/checkpoint.py @@ -132,7 +132,7 @@ def create_from_checkpoint_aspect( raise ValueError(f"Unknown serde: {checkpoint_aspect.state.serde}") except Exception as e: logger.error( - "Failed to construct checkpoint class from checkpoint aspect.", e + f"Failed to construct checkpoint class from checkpoint aspect: {e}" ) raise e else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py index 7ef67888a21b5..63bd1a958a1ce 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state_provider/datahub_ingestion_checkpointing_provider.py @@ -6,11 +6,9 @@ from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import ( - CheckpointJobStatesMap, IngestionCheckpointingProviderBase, IngestionCheckpointingProviderConfig, JobId, - JobStateKey, ) from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph from datahub.metadata.schema_classes import ( @@ -102,20 +100,6 @@ def get_latest_checkpoint( return None - def get_last_state( - self, - state_key: JobStateKey, - ) -> CheckpointJobStatesMap: - last_job_checkpoint_map: CheckpointJobStatesMap = {} - for job_name in state_key.job_names: - last_job_checkpoint = self.get_latest_checkpoint( - state_key.pipeline_name, state_key.platform_instance_id, job_name - ) - if last_job_checkpoint is not None: - last_job_checkpoint_map[job_name] = last_job_checkpoint - - return last_job_checkpoint_map - def commit(self) -> None: if not self.state_to_commit: logger.warning(f"No state available to commit for {self.name}") diff --git a/metadata-ingestion/tests/test_helpers/state_helpers.py b/metadata-ingestion/tests/test_helpers/state_helpers.py index 00abfcd3dc541..e4a5ffc00749a 100644 --- a/metadata-ingestion/tests/test_helpers/state_helpers.py +++ b/metadata-ingestion/tests/test_helpers/state_helpers.py @@ -6,7 +6,9 @@ from avrogen.dict_wrapper import DictWrapper from datahub.emitter.mcp import MetadataChangeProposalWrapper -from datahub.ingestion.api.committable import StatefulCommittable +from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import ( + IngestionCheckpointingProviderBase, +) from datahub.ingestion.graph.client import DataHubGraph from datahub.ingestion.run.pipeline import Pipeline @@ -21,8 +23,8 @@ def validate_all_providers_have_committed_successfully( provider_count: int = 0 for _, provider in pipeline.ctx.get_committables(): provider_count += 1 - assert isinstance(provider, StatefulCommittable) - stateful_committable = cast(StatefulCommittable, provider) + assert isinstance(provider, IngestionCheckpointingProviderBase) + stateful_committable = cast(IngestionCheckpointingProviderBase, provider) assert stateful_committable.has_successfully_committed() assert stateful_committable.state_to_commit assert provider_count == expected_providers diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py index 3b190ae3d0721..b45ac932db66f 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py @@ -8,11 +8,8 @@ from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import ( - CheckpointJobStatesMap, CheckpointJobStateType, - IngestionCheckpointingProviderBase, JobId, - JobStateKey, ) from datahub.ingestion.source.state.checkpoint import Checkpoint from datahub.ingestion.source.state.sql_common_state import ( @@ -31,11 +28,6 @@ class TestDatahubIngestionCheckpointProvider(unittest.TestCase): platform_instance_id: str = "test_platform_instance_1" job_names: List[JobId] = [JobId("job1"), JobId("job2")] run_id: str = "test_run" - job_state_key: JobStateKey = JobStateKey( - pipeline_name=pipeline_name, - platform_instance_id=platform_instance_id, - job_names=job_names, - ) def setUp(self) -> None: self._setup_mock_graph() @@ -64,7 +56,7 @@ def _setup_mock_graph(self) -> None: # Tracking for emitted mcps. self.mcps_emitted: Dict[str, MetadataChangeProposalWrapper] = {} - def _create_provider(self) -> IngestionCheckpointingProviderBase: + def _create_provider(self) -> DatahubIngestionCheckpointingProvider: ctx: PipelineContext = PipelineContext( run_id=self.run_id, pipeline_name=self.pipeline_name ) @@ -153,26 +145,27 @@ def test_provider(self): # 4. Get last committed state. This must match what has been committed earlier. # NOTE: This will retrieve from in-memory self.mcps_emitted because of the monkey-patching. - last_state: Optional[CheckpointJobStatesMap] = self.provider.get_last_state( - self.job_state_key + job1_last_state = self.provider.get_latest_checkpoint( + self.pipeline_name, self.platform_instance_id, self.job_names[0] + ) + job2_last_state = self.provider.get_latest_checkpoint( + self.pipeline_name, self.platform_instance_id, self.job_names[1] ) - assert last_state is not None - self.assertEqual(len(last_state), 2) # 5. Validate individual job checkpoint state values that have been committed and retrieved # against the original values. - self.assertIsNotNone(last_state[self.job_names[0]]) + self.assertIsNotNone(job1_last_state) job1_last_checkpoint = Checkpoint.create_from_checkpoint_aspect( job_name=self.job_names[0], - checkpoint_aspect=last_state[self.job_names[0]], + checkpoint_aspect=job1_last_state, state_class=type(job1_state_obj), ) self.assertEqual(job1_last_checkpoint, job1_checkpoint) - self.assertIsNotNone(last_state[self.job_names[1]]) + self.assertIsNotNone(job2_last_state) job2_last_checkpoint = Checkpoint.create_from_checkpoint_aspect( job_name=self.job_names[1], - checkpoint_aspect=last_state[self.job_names[1]], + checkpoint_aspect=job2_last_state, state_class=type(job2_state_obj), ) self.assertEqual(job2_last_checkpoint, job2_checkpoint)