runner.py
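
"""
Custom runner for kedro-sagemaker.

Runs the pipeline sequentially and backs any pipeline inputs that are
missing from the catalog with S3-based cloudpickle datasets, so that
intermediate data can flow between SageMaker Pipelines steps.
"""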

import logging
import os
from typing import Any, Dict, Optional

from kedro.io import AbstractDataSet, DataCatalog
from kedro.pipeline import Pipeline
from kedro.runner import SequentialRunner
from pluggy import PluginManager
from pydantic import BaseModel

from kedro_sagemaker.constants import (
    KEDRO_SAGEMAKER_EXECUTION_ARN,
    KEDRO_SAGEMAKER_RUNNER_CONFIG,
)
from kedro_sagemaker.datasets import (
    CloudpickleDataset,
    DistributedCloudpickleDataset,
)
from kedro_sagemaker.utils import is_distributed_environment

logger = logging.getLogger(__name__)


class KedroSageMakerRunnerConfig(BaseModel):
    """Runner configuration injected via the KEDRO_SAGEMAKER_RUNNER_CONFIG env var."""

    bucket: str
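
# Example payload (the env var name comes from the constants module; the
# bucket value below is purely illustrative):
#   KEDRO_SAGEMAKER_RUNNER_CONFIG='{"bucket": "my-sagemaker-bucket"}'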


class SageMakerPipelinesRunner(SequentialRunner):
    """Sequential runner that plugs S3-backed default datasets into the catalog."""

    @classmethod
    def runner_name(cls) -> str:
        # Fully-qualified class name of this runner.
        return f"{cls.__module__}.{cls.__qualname__}"

    def __init__(self, is_async: bool = False):
        super().__init__(is_async)
        # The runner config is passed to the container as a JSON string.
        self.runner_config_raw = os.environ.get(KEDRO_SAGEMAKER_RUNNER_CONFIG)
        self.runner_config = KedroSageMakerRunnerConfig.parse_raw(
            self.runner_config_raw
        )
        # Use the last segment of the SageMaker execution ARN as the run id;
        # outside of SageMaker this falls back to "local".
        self.run_id = os.getenv(KEDRO_SAGEMAKER_EXECUTION_ARN, "local").split(":")[-1]

    def run(
        self,
        pipeline: Pipeline,
        catalog: DataCatalog,
        hook_manager: Optional[PluginManager] = None,
        session_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        # Back every pipeline input that is missing from the catalog with a
        # default S3 dataset, then delegate to the sequential runner.
        unsatisfied = pipeline.inputs() - set(catalog.list())
        if unsatisfied:
            catalog = catalog.shallow_copy()
            for ds_name in unsatisfied:
                catalog.add(ds_name, self.create_default_data_set(ds_name))
        return super().run(pipeline, catalog, hook_manager, session_id)

    def create_default_data_set(self, ds_name: str) -> AbstractDataSet:
        # TODO: handle credentials better (probably with built-in Kedro
        #  credentials via ConfigLoader, but it's not available here...)
        dataset_cls = CloudpickleDataset
        if is_distributed_environment():
            logger.info("Using distributed dataset class as a default")
            dataset_cls = DistributedCloudpickleDataset
        return dataset_cls(
            bucket=self.runner_config.bucket,
            dataset_name=ds_name,
            run_id=self.run_id,
        )
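

# A minimal local smoke test, not part of the plugin itself. It assumes that
# KEDRO_SAGEMAKER_RUNNER_CONFIG holds the name of the env var carrying the
# JSON config; the bucket name below is purely illustrative.
if __name__ == "__main__":
    os.environ[KEDRO_SAGEMAKER_RUNNER_CONFIG] = '{"bucket": "my-sagemaker-bucket"}'
    runner = SageMakerPipelinesRunner()
    # Outside SageMaker the execution ARN is unset, so run_id is "local".
    print(runner.runner_name(), runner.run_id)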