diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml index fd594fa8408c..102cfec1e17a 100644 --- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml +++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml @@ -51,6 +51,7 @@ body: - elasticsearch - exasol - facebook + - flyte - ftp - github - google diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 34b220405e6b..92f933b02f8c 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -620,7 +620,7 @@ apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, a apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch, -exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, +exasol, facebook, flyte, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus, diff --git a/INSTALL b/INSTALL index 4354d5e33f55..5bcb8c462c6c 100644 --- a/INSTALL +++ b/INSTALL @@ -100,7 +100,7 @@ apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, a apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch, -exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, +exasol, facebook, flyte, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus, diff --git a/airflow/providers/flyte/CHANGELOG.rst b/airflow/providers/flyte/CHANGELOG.rst new file mode 100644 index 000000000000..cef7dda80708 --- /dev/null +++ b/airflow/providers/flyte/CHANGELOG.rst @@ -0,0 +1,25 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.0 +..... + +Initial version of the provider. 
diff --git a/airflow/providers/flyte/__init__.py b/airflow/providers/flyte/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/flyte/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/flyte/example_dags/__init__.py b/airflow/providers/flyte/example_dags/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/flyte/example_dags/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/flyte/example_dags/example_flyte.py b/airflow/providers/flyte/example_dags/example_flyte.py new file mode 100644 index 000000000000..ac0db5a0f827 --- /dev/null +++ b/airflow/providers/flyte/example_dags/example_flyte.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Example DAG demonstrating the usage of the AirflowFlyteOperator.""" + +from datetime import datetime, timedelta + +from airflow import DAG +from airflow.providers.flyte.operators.flyte import AirflowFlyteOperator +from airflow.providers.flyte.sensors.flyte import AirflowFlyteSensor + +with DAG( + dag_id="example_flyte_operator", + schedule_interval=None, + start_date=datetime(2021, 1, 1), + dagrun_timeout=timedelta(minutes=60), + tags=["example"], + catchup=False, +) as dag: + + # [START howto_operator_flyte_synchronous] + sync_predictions = AirflowFlyteOperator( + task_id="flyte_example_sync", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + assumable_iam_role="default", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + timeout=timedelta(seconds=3600), + ) + # [END howto_operator_flyte_synchronous] + + # [START howto_operator_flyte_asynchronous] + async_predictions = AirflowFlyteOperator( + task_id="flyte_example_async", + flyte_conn_id="flyte_conn_example", + project="flytesnacks", + domain="development", + launchplan_name="core.basic.lp.my_wf", + max_parallelism=2, + raw_data_prefix="s3://flyte-demo/raw_data", + assumable_iam_role="default", + kubernetes_service_account="demo", + version="v1", + inputs={"val": 19}, + timeout=timedelta(seconds=3600), + asynchronous=True, + ) + + predictions_sensor = AirflowFlyteSensor( + task_id="predictions_sensor", + execution_name=async_predictions.output, + project="flytesnacks", + domain="development", + flyte_conn_id="flyte_conn_example", + ) + # [END howto_operator_flyte_asynchronous] + + # Task dependency created via `XComArgs`: + async_predictions >> predictions_sensor diff --git a/airflow/providers/flyte/hooks/__init__.py b/airflow/providers/flyte/hooks/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/flyte/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/flyte/hooks/flyte.py b/airflow/providers/flyte/hooks/flyte.py new file mode 100644 index 000000000000..ec8c965a602f --- /dev/null +++ b/airflow/providers/flyte/hooks/flyte.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +from datetime import datetime, timedelta +from typing import Any, Dict, Optional + +from flytekit.configuration import AuthType, Config, PlatformConfig +from flytekit.exceptions.user import FlyteEntityNotExistException +from flytekit.models.common import Annotations, AuthRole, Labels +from flytekit.models.core import execution as core_execution_models +from flytekit.models.core.identifier import WorkflowExecutionIdentifier +from flytekit.remote.remote import FlyteRemote, Options + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class AirflowFlyteHook(BaseHook): + """ + Interact with the FlyteRemote API. + + :param flyte_conn_id: Required. The ID of the Flyte connection that holds + the connection information for Flyte. + :param project: Optional. The Flyte project to use. + :param domain: Optional. The Flyte domain to use. + """ + + SUCCEEDED = core_execution_models.WorkflowExecutionPhase.SUCCEEDED + FAILED = core_execution_models.WorkflowExecutionPhase.FAILED + TIMED_OUT = core_execution_models.WorkflowExecutionPhase.TIMED_OUT + ABORTED = core_execution_models.WorkflowExecutionPhase.ABORTED + + flyte_conn_id = "flyte_default" + conn_type = "flyte" + + def __init__( + self, flyte_conn_id: str = flyte_conn_id, project: Optional[str] = None, domain: Optional[str] = None + ) -> None: + super().__init__() + self.flyte_conn_id = flyte_conn_id + self.flyte_conn = self.get_connection(self.flyte_conn_id) + self.project = project or self.flyte_conn.extra_dejson.get("project") + self.domain = domain or self.flyte_conn.extra_dejson.get("domain") + + if not (self.project and self.domain): + raise AirflowException("Please provide a project and domain.") + + def execution_id(self, execution_name: str) -> WorkflowExecutionIdentifier: + """Get the execution id.""" + return WorkflowExecutionIdentifier(self.project, self.domain, execution_name) + + def create_flyte_remote(self) -> FlyteRemote: + """Create a FlyteRemote object.""" + remote = FlyteRemote( + config=Config( + platform=PlatformConfig( + # build "host:port"; the connection port may be stored as an integer + endpoint=f"{self.flyte_conn.host}:{self.flyte_conn.port}" + if (self.flyte_conn.host and self.flyte_conn.port) + else (self.flyte_conn.host or "localhost:30081"), + insecure=self.flyte_conn.extra_dejson.get("insecure", False), + client_id=self.flyte_conn.login or None, + client_credentials_secret=self.flyte_conn.password or None, + command=self.flyte_conn.extra_dejson.get("command", None), + scopes=self.flyte_conn.extra_dejson.get("scopes", None), + auth_mode=AuthType(self.flyte_conn.extra_dejson.get("auth_mode", "standard")), + ) + ), + ) + return remote + + def trigger_execution( + self, + execution_name: str, + launchplan_name: Optional[str] = None, + task_name: Optional[str] = None, + max_parallelism: Optional[int] = None, + raw_data_prefix: Optional[str] = None, + assumable_iam_role: Optional[str] = None, + kubernetes_service_account: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + annotations: Optional[Dict[str, str]] = None, + version: Optional[str] = None, + inputs: Dict[str, Any] = {}, + ) -> None:
+ """ + Trigger an execution. + + :param execution_name: Required. The name of the execution to trigger. + :param launchplan_name: Optional. The name of the launchplan to trigger. + :param task_name: Optional. The name of the task to trigger. + :param max_parallelism: Optional. The maximum number of parallel executions to allow. + :param raw_data_prefix: Optional. The prefix to use for raw data. + :param assumable_iam_role: Optional. The assumable IAM role to use. + :param kubernetes_service_account: Optional. The Kubernetes service account to use. + :param labels: Optional. The labels to attach to the execution. + :param annotations: Optional. The annotations to attach to the execution. + :param version: Optional. The version of the launchplan to trigger. + :param inputs: Optional. The inputs to the launchplan. + """ + if (not (task_name or launchplan_name)) or (task_name and launchplan_name): + raise AirflowException("Exactly one of task_name or launchplan_name must be provided.") + + remote = self.create_flyte_remote() + try: + if launchplan_name: + flyte_entity = remote.fetch_launch_plan( + name=launchplan_name, project=self.project, domain=self.domain, version=version + ) + elif task_name: + flyte_entity = remote.fetch_task( + name=task_name, project=self.project, domain=self.domain, version=version + ) + except FlyteEntityNotExistException as e: + raise AirflowException(f"Failed to fetch entity: {e}") + + try: + remote.execute( + flyte_entity, + inputs=inputs, + project=self.project, + domain=self.domain, + execution_name=execution_name, + options=Options( + raw_data_prefix=raw_data_prefix, + max_parallelism=max_parallelism, + auth_role=AuthRole( + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, + ), + labels=Labels(labels), + annotations=Annotations(annotations), + ), + ) + except Exception as e: + raise AirflowException(f"Failed to trigger execution: {e}") + + def execution_status(self, execution_name: str, remote: FlyteRemote): + phase = remote.client.get_execution(self.execution_id(execution_name)).closure.phase + + if phase == self.SUCCEEDED: + return True + elif phase == self.FAILED: + raise AirflowException(f"Execution {execution_name} failed") + elif phase == self.TIMED_OUT: + raise AirflowException(f"Execution {execution_name} timed out") + elif phase == self.ABORTED: + raise AirflowException(f"Execution {execution_name} aborted") + else: + return False + + def wait_for_execution( + self, + execution_name: str, + timeout: Optional[timedelta] = None, + poll_interval: timedelta = timedelta(seconds=30), + ) -> None: + """ + Helper method that polls an execution until it reaches a terminal state. + + :param execution_name: Required. The name of the execution to check. + :param timeout: Optional. The timeout to wait for the execution to finish. + :param poll_interval: Optional. The interval between checks to poll the execution. + """ + remote = self.create_flyte_remote() + + time_to_give_up = datetime.max if timeout is None else datetime.utcnow() + timeout + + while datetime.utcnow() < time_to_give_up: + time.sleep(poll_interval.total_seconds()) + + if self.execution_status(execution_name, remote): + return + + raise AirflowException(f"Execution {execution_name} timed out") + + def terminate( + self, + execution_name: str, + cause: str, + ) -> None: + """ + Terminate an execution. + + :param execution_name: Required. The name of the execution to terminate. + :param cause: Required. The cause of the termination.
+ """ + remote = self.create_flyte_remote() + execution_id = self.execution_id(execution_name) + remote.client.terminate_execution(id=execution_id, cause=cause) diff --git a/airflow/providers/flyte/operators/__init__.py b/airflow/providers/flyte/operators/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/flyte/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/flyte/operators/flyte.py b/airflow/providers/flyte/operators/flyte.py new file mode 100644 index 000000000000..88978626c677 --- /dev/null +++ b/airflow/providers/flyte/operators/flyte.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AirflowFlyteOperator(BaseOperator): + """ + Launch Flyte executions from within Airflow. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AirflowFlyteOperator` + + :param flyte_conn_id: Required. The connection to Flyte setup, containing metadata. + :param project: Optional. The project to connect to. + :param domain: Optional. The domain to connect to. + :param launchplan_name: Optional. The name of the launchplan to trigger. + :param task_name: Optional. The name of the task to trigger. + :param max_parallelism: Optional. The maximum number of parallel executions to allow. + :param raw_data_prefix: Optional. The prefix to use for raw data. + :param assumable_iam_role: Optional. The IAM role to assume. + :param kubernetes_service_account: Optional. The Kubernetes service account to use. + :param labels: Optional. Custom labels to be applied to the execution resource. 
+ :param annotations: Optional. Custom annotations to be applied to the execution resource. + :param version: Optional. The version of the launchplan/task to trigger. + :param inputs: Optional. The inputs to the launchplan/task. + :param timeout: Optional. The timeout to wait for the execution to finish. + :param poll_interval: Optional. The interval between checks to poll the execution. + :param asynchronous: Optional. Whether to return immediately after triggering (``True``) + or wait for the execution to finish (``False``, the default). + """ + + template_fields: Sequence[str] = ("flyte_conn_id",) # mypy fix + + def __init__( + self, + flyte_conn_id: str, + project: Optional[str] = None, + domain: Optional[str] = None, + launchplan_name: Optional[str] = None, + task_name: Optional[str] = None, + max_parallelism: Optional[int] = None, + raw_data_prefix: Optional[str] = None, + assumable_iam_role: Optional[str] = None, + kubernetes_service_account: Optional[str] = None, + labels: Dict[str, str] = {}, + annotations: Dict[str, str] = {}, + version: Optional[str] = None, + inputs: Dict[str, Any] = {}, + timeout: Optional[timedelta] = None, + poll_interval: timedelta = timedelta(seconds=30), + asynchronous: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.flyte_conn_id = flyte_conn_id + self.project = project + self.domain = domain + self.launchplan_name = launchplan_name + self.task_name = task_name + self.max_parallelism = max_parallelism + self.raw_data_prefix = raw_data_prefix + self.assumable_iam_role = assumable_iam_role + self.kubernetes_service_account = kubernetes_service_account + self.labels = labels + self.annotations = annotations + self.version = version + self.inputs = inputs + self.timeout = timeout + self.poll_interval = poll_interval + self.asynchronous = asynchronous + self.execution_name: str = "" + + if (not (self.task_name or self.launchplan_name)) or (self.task_name and self.launchplan_name): + raise AirflowException("Exactly one of task_name or launchplan_name must be provided.") + + def execute(self, context: "Context") -> str: + """Trigger an execution and, unless ``asynchronous`` is set, wait for it to finish.""" + + # create a deterministic execution name + task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:5] + self.execution_name = task_id + re.sub( + r"[\W_]+", + "", + context["dag_run"].run_id.split("__")[-1].lower(), + )[: (20 - len(task_id))] + + hook = AirflowFlyteHook(flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain) + hook.trigger_execution( + launchplan_name=self.launchplan_name, + task_name=self.task_name, + max_parallelism=self.max_parallelism, + raw_data_prefix=self.raw_data_prefix, + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + labels=self.labels, + annotations=self.annotations, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + self.log.info("Execution %s submitted", self.execution_name) + + if not self.asynchronous: + self.log.info("Waiting for execution %s to complete", self.execution_name) + hook.wait_for_execution( + execution_name=self.execution_name, + timeout=self.timeout, + poll_interval=self.poll_interval, + ) + self.log.info("Execution %s completed", self.execution_name) + + return self.execution_name + + def on_kill(self) -> None: + """Kill the execution.""" + if self.execution_name: + self.log.info("Killing execution %s", self.execution_name) + hook = AirflowFlyteHook( + flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain + ) + hook.terminate( + execution_name=self.execution_name,
+ cause="Killed by Airflow", + ) diff --git a/airflow/providers/flyte/provider.yaml b/airflow/providers/flyte/provider.yaml new file mode 100644 index 000000000000..4deff0810944 --- /dev/null +++ b/airflow/providers/flyte/provider.yaml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-flyte +name: Flyte +description: | + `Flyte `__ +versions: + - 1.0.0 + +additional-dependencies: + - apache-airflow>=2.1.0 + +integrations: + - integration-name: Flyte + external-doc-url: https://docs.flyte.org/ + logo: /integration-logos/flyte/Flyte.png + how-to-guide: + - /docs/apache-airflow-providers-flyte/operators/flyte.rst + tags: [service] + +operators: + - integration-name: Flyte + python-modules: + - airflow.providers.flyte.operators.flyte + +hooks: + - integration-name: Flyte + python-modules: + - airflow.providers.flyte.hooks.flyte + +sensors: + - integration-name: Flyte + python-modules: + - airflow.providers.flyte.sensors.flyte + +connection-types: + - hook-class-name: airflow.providers.flyte.hooks.flyte.AirflowFlyteHook + connection-type: flyte diff --git a/airflow/providers/flyte/sensors/__init__.py b/airflow/providers/flyte/sensors/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/airflow/providers/flyte/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/flyte/sensors/flyte.py b/airflow/providers/flyte/sensors/flyte.py new file mode 100644 index 000000000000..0e5d068133ea --- /dev/null +++ b/airflow/providers/flyte/sensors/flyte.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import TYPE_CHECKING, Optional, Sequence + +from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class AirflowFlyteSensor(BaseSensorOperator): + """ + Check for the status of a Flyte execution. + + :param execution_name: Required. The name of the execution to check. + :param project: Optional. The project to connect to. + :param domain: Optional. The domain to connect to. + :param flyte_conn_id: Required. The name of the Flyte connection to + get the connection information for Flyte. + """ + + template_fields: Sequence[str] = ("execution_name",) # mypy fix + + def __init__( + self, + execution_name: str, + project: Optional[str] = None, + domain: Optional[str] = None, + flyte_conn_id: str = "flyte_default", + **kwargs, + ): + super().__init__(**kwargs) + self.execution_name = execution_name + self.project = project + self.domain = domain + self.flyte_conn_id = flyte_conn_id + + def poke(self, context: "Context") -> bool: + """Check for the status of a Flyte execution.""" + hook = AirflowFlyteHook(flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain) + remote = hook.create_flyte_remote() + + if hook.execution_status(self.execution_name, remote): + return True + + self.log.info("Waiting for execution %s to complete", self.execution_name) + return False diff --git a/docs/apache-airflow-providers-flyte/commits.rst b/docs/apache-airflow-providers-flyte/commits.rst new file mode 100644 index 000000000000..518db1548855 --- /dev/null +++ b/docs/apache-airflow-providers-flyte/commits.rst @@ -0,0 +1,26 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Package apache-airflow-providers-flyte +-------------------------------------- + +`Flyte `__ + + +This is detailed commit list of changes for versions provider package: ``flyte``. +For high-level changelog, see :doc:`package information including changelog `. diff --git a/docs/apache-airflow-providers-flyte/connections.rst b/docs/apache-airflow-providers-flyte/connections.rst new file mode 100644 index 000000000000..f9eadca1a697 --- /dev/null +++ b/docs/apache-airflow-providers-flyte/connections.rst @@ -0,0 +1,45 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. 
See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Flyte Connection +================ + +The Flyte connection enables connecting to Flyte through ``FlyteRemote``. + +Configuring the Connection +-------------------------- + +Host (optional) + The FlyteAdmin host. Defaults to localhost. + +Port (optional) + The FlyteAdmin port. Defaults to 30081. + +Login (optional) + The ``client_id`` to use for authentication. + +Password (optional) + The ``client_credentials_secret`` to use for authentication. + +Extra (optional) + Specify the ``extra`` parameter as a JSON dictionary to provide additional parameters. + * ``project``: The default project to connect to. + * ``domain``: The default domain to connect to. + * ``insecure``: Whether to use an insecure (non-SSL) connection. Defaults to ``False``. + * ``command``: The command to execute to return a token using an external process. + * ``scopes``: List of scopes to request. + * ``auth_mode``: The OAuth mode to use. Defaults to the PKCE flow. diff --git a/docs/apache-airflow-providers-flyte/index.rst b/docs/apache-airflow-providers-flyte/index.rst new file mode 100644 index 000000000000..072a85a8094d --- /dev/null +++ b/docs/apache-airflow-providers-flyte/index.rst @@ -0,0 +1,83 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-flyte`` +================================== + +Content +------- + +.. toctree:: + :maxdepth: 1 + :caption: Guides + + Operators + Connection types + +.. toctree:: + :maxdepth: 1 + :caption: References + + Python API <_api/airflow/providers/flyte/index> + +.. toctree:: + :maxdepth: 1 + :caption: Resources + + Example DAGs + PyPI Repository + Installing from sources + +.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + + +.. toctree:: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits + + +Package apache-airflow-providers-flyte +-------------------------------------- + +`Flyte `__ + + +Release: 1.0.0 + +Provider package +---------------- + +This is a provider package for the ``flyte`` provider. All classes for this provider package +are in the ``airflow.providers.flyte`` python package.
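
As a concrete illustration of the connection fields documented above, the snippet below is a minimal sketch (not part of the patch) that builds an equivalent ``Connection`` object in Python. It mirrors the fixture used in ``tests/providers/flyte/hooks/test_flyte.py``; the ``extra`` keys shown are the ones read by ``AirflowFlyteHook``::

    from airflow.models import Connection

    # Mirrors the test fixture; host/port point at a local FlyteAdmin.
    flyte_conn = Connection(
        conn_id="flyte_default",  # referenced by flyte_conn_id in the operator and sensor
        conn_type="flyte",
        host="localhost",         # FlyteAdmin host
        port="30081",             # FlyteAdmin port (the test fixture passes it as a string)
        login=None,               # client_id, if OAuth client credentials are used
        password=None,            # client_credentials_secret, if OAuth client credentials are used
        # insecure, command, scopes and auth_mode may also be supplied here.
        extra={"project": "flytesnacks", "domain": "development"},
    )
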
+ +Installation +------------ + +You can install this package on top of an existing Airflow 2.1+ installation via +``pip install apache-airflow-providers-flyte`` + +PIP requirements +---------------- + +================================= ================== +PIP package Version required +================================= ================== +``apache-airflow`` ``>=2.1.0`` +``flytekit`` ``>=0.32.0b0`` +================================= ================== diff --git a/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst b/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst new file mode 100644 index 000000000000..1c90205d15b3 --- /dev/null +++ b/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst @@ -0,0 +1,18 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. include:: ../installing-providers-from-sources.rst diff --git a/docs/apache-airflow-providers-flyte/operators/flyte.rst b/docs/apache-airflow-providers-flyte/operators/flyte.rst new file mode 100644 index 000000000000..292996e0455f --- /dev/null +++ b/docs/apache-airflow-providers-flyte/operators/flyte.rst @@ -0,0 +1,50 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:AirflowFlyteOperator: + +AirflowFlyteOperator +==================== + +Use the :class:`~airflow.providers.flyte.operators.flyte.AirflowFlyteOperator` to +trigger a task or workflow in Flyte. + +Using the Operator +^^^^^^^^^^^^^^^^^^ + +The AirflowFlyteOperator requires a ``flyte_conn_id`` to fetch the connection parameters +needed to instantiate ``FlyteRemote``. You must also pass either ``launchplan_name`` (to trigger a +workflow) or ``task_name`` (to trigger a task). A handful of other values are optional, such as ``project``, ``domain``, ``max_parallelism``, +``raw_data_prefix``, ``assumable_iam_role``, ``kubernetes_service_account``, ``version``, ``inputs``, ``timeout``, and ``poll_interval``. + +By default, executions are triggered synchronously on Flyte.
You can set the ``asynchronous`` parameter to +``True`` to trigger the executions asynchronously. + +An example where the execution is triggered synchronously: + +.. exampleinclude:: /../../airflow/providers/flyte/example_dags/example_flyte.py + :language: python + :start-after: [START howto_operator_flyte_synchronous] + :end-before: [END howto_operator_flyte_synchronous] + +An example where the execution is triggered asynchronously: + +.. exampleinclude:: /../../airflow/providers/flyte/example_dags/example_flyte.py + :language: python + :start-after: [START howto_operator_flyte_asynchronous] + :end-before: [END howto_operator_flyte_asynchronous] diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst index ad4fa0152221..57efd4b5f2bd 100644 --- a/docs/apache-airflow/extra-packages-ref.rst +++ b/docs/apache-airflow/extra-packages-ref.rst @@ -174,6 +174,8 @@ Those are extras that add dependencies needed for integration with external serv +---------------------+-----------------------------------------------------+-----------------------------------------------------+ | facebook | ``pip install 'apache-airflow[facebook]'`` | Facebook Social | +---------------------+-----------------------------------------------------+-----------------------------------------------------+ +| flyte | ``pip install 'apache-airflow[flyte]'`` | Flyte hooks and operators | ++---------------------+-----------------------------------------------------+-----------------------------------------------------+ | google | ``pip install 'apache-airflow[google]'`` | Google Cloud | +---------------------+-----------------------------------------------------+-----------------------------------------------------+ | hashicorp | ``pip install 'apache-airflow[hashicorp]'`` | Hashicorp Services (Vault) | diff --git a/docs/integration-logos/flyte/Flyte.png b/docs/integration-logos/flyte/Flyte.png new file mode 100644 index 000000000000..49cdbbbc3419 Binary files /dev/null and b/docs/integration-logos/flyte/Flyte.png differ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b6242d9e0369..0a5e55d7cb92 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -157,6 +157,7 @@ Firehose Firestore Flink FluentD +Flyte Formaturas Fundera GCS @@ -810,6 +811,7 @@ filetype findall firstname fluentd +flyte fmt fn fo diff --git a/setup.py b/setup.py index 3ce9ae73114e..bbe22723c037 100644 --- a/setup.py +++ b/setup.py @@ -300,6 +300,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version flask_appbuilder_authlib = [ 'authlib', ] +flyte = ["flytekit>=0.32.0b0"] github = [ 'pygithub', ] @@ -683,6 +684,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'elasticsearch': elasticsearch, 'exasol': exasol, 'facebook': facebook, + 'flyte': flyte, 'ftp': [], 'github': github, 'google': google, diff --git a/tests/providers/flyte/__init__.py b/tests/providers/flyte/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/providers/flyte/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/flyte/hooks/__init__.py b/tests/providers/flyte/hooks/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/providers/flyte/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/flyte/hooks/test_flyte.py b/tests/providers/flyte/hooks/test_flyte.py new file mode 100644 index 000000000000..8643ac1d6aa0 --- /dev/null +++ b/tests/providers/flyte/hooks/test_flyte.py @@ -0,0 +1,368 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest +from datetime import timedelta +from unittest import mock + +import pytest +from flytekit.configuration import Config, PlatformConfig +from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteValueException +from flytekit.models.core import execution as core_execution_models +from flytekit.remote import FlyteRemote + +from airflow.exceptions import AirflowException +from airflow.models import Connection +from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook + + +class TestAirflowFlyteHook(unittest.TestCase): + + flyte_conn_id = "flyte_default" + execution_name = "flyte20220330t133856" + conn_type = "flyte" + host = "localhost" + port = "30081" + extra = {"project": "flytesnacks", "domain": "development"} + launchplan_name = "core.basic.hello_world.my_wf" + task_name = "core.basic.hello_world.say_hello" + raw_data_prefix = "s3://flyte-demo/raw_data" + assumable_iam_role = "arn:aws:iam::123456789012:role/example-role" + kubernetes_service_account = "default" + version = "v1" + inputs = {"name": "hello world"} + timeout = timedelta(seconds=3600) + + @classmethod + def get_mock_connection(cls): + return Connection( + conn_id=cls.flyte_conn_id, conn_type=cls.conn_type, host=cls.host, port=cls.port, extra=cls.extra + ) + + @classmethod + def create_remote(cls): + return FlyteRemote( + config=Config( + platform=PlatformConfig(endpoint=":".join([cls.host, cls.port]), insecure=True), + ) + ) + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + def test_trigger_execution_success(self, mock_create_flyte_remote, mock_get_connection): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock() + + mock_remote.execute = mock.MagicMock() + + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_data_prefix=self.raw_data_prefix, + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + mock_create_flyte_remote.assert_called_once() + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + def test_trigger_task_execution_success(self, mock_create_flyte_remote, mock_get_connection): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_task = mock.MagicMock() + + mock_remote.execute = mock.MagicMock() + + test_hook.trigger_execution( + task_name=self.task_name, + raw_data_prefix=self.raw_data_prefix, + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + mock_create_flyte_remote.assert_called_once() + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + def 
test_trigger_execution_failed_to_fetch(self, mock_create_flyte_remote, mock_get_connection): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock(side_effect=FlyteEntityNotExistException) + + with pytest.raises(AirflowException): + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_data_prefix=self.raw_data_prefix, + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + mock_create_flyte_remote.assert_called_once() + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + def test_trigger_execution_failed_to_trigger(self, mock_create_flyte_remote, mock_get_connection): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + mock_remote.fetch_launch_plan = mock.MagicMock() + mock_remote.execute = mock.MagicMock(side_effect=FlyteValueException) + + with pytest.raises(AirflowException): + test_hook.trigger_execution( + launchplan_name=self.launchplan_name, + raw_data_prefix=self.raw_data_prefix, + assumable_iam_role=self.assumable_iam_role, + kubernetes_service_account=self.kubernetes_service_account, + version=self.version, + inputs=self.inputs, + execution_name=self.execution_name, + ) + mock_create_flyte_remote.assert_called_once() + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_succeeded( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock(side_effect=[test_hook.SUCCEEDED]) + type(mock_get_execution().closure).phase = mock_phase + + test_hook.wait_for_execution( + execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3) + ) + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_called_once_with(self.execution_name) + assert mock_phase.call_count == 1 + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_failed( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = 
mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock( + side_effect=[core_execution_models.WorkflowExecutionPhase.RUNNING, test_hook.FAILED] + ) + type(mock_get_execution().closure).phase = mock_phase + + with pytest.raises(AirflowException): + test_hook.wait_for_execution( + execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3) + ) + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2) + assert mock_phase.call_count == 2 + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_queued_succeeded( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock( + side_effect=[core_execution_models.WorkflowExecutionPhase.QUEUED, test_hook.SUCCEEDED] + ) + type(mock_get_execution().closure).phase = mock_phase + + test_hook.wait_for_execution( + execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3) + ) + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2) + assert mock_phase.call_count == 2 + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_timedout( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock( + side_effect=[ + core_execution_models.WorkflowExecutionPhase.QUEUED, + core_execution_models.WorkflowExecutionPhase.RUNNING, + test_hook.TIMED_OUT, + ] + ) + type(mock_get_execution().closure).phase = mock_phase + + with pytest.raises(AirflowException): + test_hook.wait_for_execution( + execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3) + ) + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 3) + assert mock_phase.call_count == 3 + + 
@mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_timeout( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock( + side_effect=[ + core_execution_models.WorkflowExecutionPhase.QUEUED, + core_execution_models.WorkflowExecutionPhase.RUNNING, + core_execution_models.WorkflowExecutionPhase.RUNNING, + ] + ) + type(mock_get_execution().closure).phase = mock_phase + + with pytest.raises(AirflowException): + test_hook.wait_for_execution( + execution_name=self.execution_name, + timeout=timedelta(seconds=1), + poll_interval=timedelta(seconds=3), + ) + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_called() + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_wait_for_execution_aborted( + self, mock_execution_id, mock_create_flyte_remote, mock_get_connection + ): + mock_connection = self.get_mock_connection() + mock_get_connection.return_value = mock_connection + + test_hook = AirflowFlyteHook(self.flyte_conn_id) + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock( + side_effect=[ + core_execution_models.WorkflowExecutionPhase.RUNNING, + core_execution_models.WorkflowExecutionPhase.ABORTED, + ] + ) + type(mock_get_execution().closure).phase = mock_phase + + with pytest.raises(AirflowException): + test_hook.wait_for_execution( + execution_name=self.execution_name, + timeout=self.timeout, + poll_interval=timedelta(seconds=3), + ) + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2) + assert mock_phase.call_count == 2 diff --git a/tests/providers/flyte/operators/__init__.py b/tests/providers/flyte/operators/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/providers/flyte/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/flyte/operators/test_flyte.py b/tests/providers/flyte/operators/test_flyte.py
new file mode 100644
index 000000000000..72f44b05cb13
--- /dev/null
+++ b/tests/providers/flyte/operators/test_flyte.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import timedelta
+from unittest import mock
+
+from airflow.models import Connection
+from airflow.models.dagrun import DagRun
+from airflow.providers.flyte.operators.flyte import AirflowFlyteOperator
+
+
+class TestAirflowFlyteOperator(unittest.TestCase):
+
+    task_id = "test_flyte_operator"
+    flyte_conn_id = "flyte_default"
+    run_id = "manual__2022-03-30T13:55:08.715694+00:00"
+    conn_type = "flyte"
+    host = "localhost"
+    port = "30081"
+    project = "flytesnacks"
+    domain = "development"
+    launchplan_name = "core.basic.hello_world.my_wf"
+    raw_data_prefix = "s3://flyte-demo/raw_data"
+    assumable_iam_role = "arn:aws:iam::123456789012:role/example-role"
+    kubernetes_service_account = "default"
+    labels = {"key1": "value1"}
+    version = "v1"
+    inputs = {"name": "hello world"}
+    timeout = timedelta(seconds=3600)
+    execution_name = "testf20220330t135508"
+
+    @classmethod
+    def get_connection(cls):
+        return Connection(
+            conn_id=cls.flyte_conn_id,
+            conn_type=cls.conn_type,
+            host=cls.host,
+            port=cls.port,
+            extra={"project": cls.project, "domain": cls.domain},
+        )
+
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.trigger_execution")
+    @mock.patch(
+        "airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.wait_for_execution",
+        return_value=None,
+    )
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+    def test_execute(self, mock_get_connection, mock_wait_for_execution, mock_trigger_execution):
+        mock_get_connection.return_value = self.get_connection()
+
+        operator = AirflowFlyteOperator(
+            task_id=self.task_id,
+            flyte_conn_id=self.flyte_conn_id,
+            project=self.project,
+            domain=self.domain,
+            launchplan_name=self.launchplan_name,
+            raw_data_prefix=self.raw_data_prefix,
+            assumable_iam_role=self.assumable_iam_role,
+            kubernetes_service_account=self.kubernetes_service_account,
+            labels=self.labels,
+            version=self.version,
+            inputs=self.inputs,
+            timeout=self.timeout,
+        )
+        result = operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator})
+
+        assert result == self.execution_name
+        mock_get_connection.assert_called_once_with(self.flyte_conn_id)
+        mock_trigger_execution.assert_called_once_with(
+            launchplan_name=self.launchplan_name,
+            task_name=None,
+            max_parallelism=None,
+            raw_data_prefix=self.raw_data_prefix,
+            assumable_iam_role=self.assumable_iam_role,
+            kubernetes_service_account=self.kubernetes_service_account,
+            labels=self.labels,
+            annotations={},
+            version=self.version,
+            inputs=self.inputs,
+            execution_name=self.execution_name,
+        )
+        mock_wait_for_execution.assert_called_once_with(
+            execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=30)
+        )
+
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.trigger_execution", return_value=None)
+    @mock.patch(
+        "airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.wait_for_execution",
+        return_value=None,
+    )
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.terminate")
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+    def test_on_kill_success(
+        self, mock_get_connection, mock_terminate, mock_wait_for_execution, mock_trigger_execution
+    ):
+        mock_get_connection.return_value = self.get_connection()
+
+        operator = AirflowFlyteOperator(
+            task_id=self.task_id,
+            flyte_conn_id=self.flyte_conn_id,
+            project=self.project,
+            domain=self.domain,
+            launchplan_name=self.launchplan_name,
+            inputs=self.inputs,
+            timeout=self.timeout,
+        )
+        operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator})
+        operator.on_kill()
+
+        mock_get_connection.assert_has_calls([mock.call(self.flyte_conn_id)] * 2)
+        mock_trigger_execution.assert_called()
+        mock_wait_for_execution.assert_called_once_with(
+            execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=30)
+        )
+        mock_terminate.assert_called_once_with(execution_name=self.execution_name, cause="Killed by Airflow")
+
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.terminate")
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+    def test_on_kill_noop(self, mock_get_connection, mock_terminate):
+        mock_get_connection.return_value = self.get_connection()
+
+        operator = AirflowFlyteOperator(
+            task_id=self.task_id,
+            flyte_conn_id=self.flyte_conn_id,
+            project=self.project,
+            domain=self.domain,
+            launchplan_name=self.launchplan_name,
+            inputs=self.inputs,
+        )
+        operator.on_kill()
+
+        mock_get_connection.assert_not_called()
+        mock_terminate.assert_not_called()
diff --git a/tests/providers/flyte/sensors/__init__.py b/tests/providers/flyte/sensors/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/providers/flyte/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
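For orientation, a minimal sketch of how the AirflowFlyteOperator exercised by these tests might appear in a DAG; the operator arguments reuse the connection id, launch plan, and inputs from the test constants above, while the DAG id and schedule are illustrative assumptions rather than values taken from the provider:

    from datetime import datetime, timedelta

    from airflow import DAG
    from airflow.providers.flyte.operators.flyte import AirflowFlyteOperator

    # Hypothetical DAG wiring; only the operator arguments mirror the tests above.
    with DAG(
        dag_id="flyte_operator_sketch",
        start_date=datetime(2022, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        trigger_hello_world = AirflowFlyteOperator(
            task_id="trigger_hello_world",
            flyte_conn_id="flyte_default",
            project="flytesnacks",
            domain="development",
            launchplan_name="core.basic.hello_world.my_wf",
            inputs={"name": "hello world"},
            timeout=timedelta(seconds=3600),
        )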
diff --git a/tests/providers/flyte/sensors/test_flyte.py b/tests/providers/flyte/sensors/test_flyte.py
new file mode 100644
index 000000000000..ebe7746912bb
--- /dev/null
+++ b/tests/providers/flyte/sensors/test_flyte.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+import pytest
+from flytekit.configuration import Config, PlatformConfig
+from flytekit.models.core import execution as core_execution_models
+from flytekit.remote import FlyteRemote
+
+from airflow import AirflowException
+from airflow.models import Connection
+from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook
+from airflow.providers.flyte.sensors.flyte import AirflowFlyteSensor
+
+
+class TestAirflowFlyteSensor(unittest.TestCase):
+
+    task_id = "test_flyte_sensor"
+    flyte_conn_id = "flyte_default"
+    conn_type = "flyte"
+    host = "localhost"
+    port = "30081"
+    project = "flytesnacks"
+    domain = "development"
+    execution_name = "testf20220330t135508"
+
+    @classmethod
+    def get_connection(cls):
+        return Connection(
+            conn_id=cls.flyte_conn_id,
+            conn_type=cls.conn_type,
+            host=cls.host,
+            port=cls.port,
+            extra={"project": cls.project, "domain": cls.domain},
+        )
+
+    @classmethod
+    def create_remote(cls):
+        return FlyteRemote(
+            config=Config(
+                platform=PlatformConfig(endpoint=":".join([cls.host, cls.port]), insecure=True),
+            )
+        )
+
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+    def test_poke_done(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection):
+        mock_get_connection.return_value = self.get_connection()
+
+        mock_remote = self.create_remote()
+        mock_create_flyte_remote.return_value = mock_remote
+
+        execution_id = mock.MagicMock()
+        mock_execution_id.return_value = execution_id
+
+        mock_get_execution = mock.MagicMock()
+        mock_remote.client.get_execution = mock_get_execution
+        mock_phase = mock.PropertyMock(return_value=AirflowFlyteHook.SUCCEEDED)
+        type(mock_get_execution().closure).phase = mock_phase
+
+        sensor = AirflowFlyteSensor(
+            task_id=self.task_id,
+            execution_name=self.execution_name,
+            project=self.project,
+            domain=self.domain,
+            flyte_conn_id=self.flyte_conn_id,
+        )
+
+        return_value = sensor.poke({})
+
+        assert return_value
+        mock_create_flyte_remote.assert_called_once()
+        mock_execution_id.assert_called_with(self.execution_name)
+
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+    @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
@mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_poke_failed(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection): + mock_get_connection.return_value = self.get_connection() + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + sensor = AirflowFlyteSensor( + task_id=self.task_id, + execution_name=self.execution_name, + project=self.project, + domain=self.domain, + flyte_conn_id=self.flyte_conn_id, + ) + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + for phase in [AirflowFlyteHook.ABORTED, AirflowFlyteHook.FAILED, AirflowFlyteHook.TIMED_OUT]: + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock(return_value=phase) + type(mock_get_execution().closure).phase = mock_phase + + with pytest.raises(AirflowException): + sensor.poke({}) + + mock_create_flyte_remote.has_calls([mock.call()] * 3) + mock_execution_id.has_calls([mock.call(self.execution_name)] * 3) + + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote") + @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id") + def test_poke_running(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection): + mock_get_connection.return_value = self.get_connection() + + mock_remote = self.create_remote() + mock_create_flyte_remote.return_value = mock_remote + + execution_id = mock.MagicMock() + mock_execution_id.return_value = execution_id + + mock_get_execution = mock.MagicMock() + mock_remote.client.get_execution = mock_get_execution + mock_phase = mock.PropertyMock(return_value=core_execution_models.WorkflowExecutionPhase.RUNNING) + type(mock_get_execution().closure).phase = mock_phase + + sensor = AirflowFlyteSensor( + task_id=self.task_id, + execution_name=self.execution_name, + project=self.project, + domain=self.domain, + flyte_conn_id=self.flyte_conn_id, + ) + + return_value = sensor.poke({}) + assert not return_value + + mock_create_flyte_remote.assert_called_once() + mock_execution_id.assert_called_with(self.execution_name)