diff --git a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
index fd594fa8408c..102cfec1e17a 100644
--- a/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/airflow_providers_bug_report.yml
@@ -51,6 +51,7 @@ body:
- elasticsearch
- exasol
- facebook
+ - flyte
- ftp
- github
- google
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 34b220405e6b..92f933b02f8c 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -620,7 +620,7 @@ apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, a
apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery,
cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api,
devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch,
-exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
+exasol, facebook, flyte, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql,
neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus,
diff --git a/INSTALL b/INSTALL
index 4354d5e33f55..5bcb8c462c6c 100644
--- a/INSTALL
+++ b/INSTALL
@@ -100,7 +100,7 @@ apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, a
apache.spark, apache.sqoop, apache.webhdfs, asana, async, atlas, aws, azure, cassandra, celery,
cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud, deprecated_api,
devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid, elasticsearch,
-exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
+exasol, facebook, flyte, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth, grpc,
hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql,
neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus,
diff --git a/airflow/providers/flyte/CHANGELOG.rst b/airflow/providers/flyte/CHANGELOG.rst
new file mode 100644
index 000000000000..cef7dda80708
--- /dev/null
+++ b/airflow/providers/flyte/CHANGELOG.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Changelog
+---------
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/flyte/__init__.py b/airflow/providers/flyte/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/airflow/providers/flyte/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/flyte/example_dags/__init__.py b/airflow/providers/flyte/example_dags/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/airflow/providers/flyte/example_dags/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/flyte/example_dags/example_flyte.py b/airflow/providers/flyte/example_dags/example_flyte.py
new file mode 100644
index 000000000000..ac0db5a0f827
--- /dev/null
+++ b/airflow/providers/flyte/example_dags/example_flyte.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the AirflowFlyteOperator."""
+
+from datetime import datetime, timedelta
+
+from airflow import DAG
+from airflow.providers.flyte.operators.flyte import AirflowFlyteOperator
+from airflow.providers.flyte.sensors.flyte import AirflowFlyteSensor
+
+with DAG(
+ dag_id="example_flyte_operator",
+ schedule_interval=None,
+ start_date=datetime(2021, 1, 1),
+ dagrun_timeout=timedelta(minutes=60),
+ tags=["example"],
+ catchup=False,
+) as dag:
+
+ # [START howto_operator_flyte_synchronous]
+ sync_predictions = AirflowFlyteOperator(
+ task_id="flyte_example_sync",
+ flyte_conn_id="flyte_conn_example",
+ project="flytesnacks",
+ domain="development",
+ launchplan_name="core.basic.lp.my_wf",
+ assumable_iam_role="default",
+ kubernetes_service_account="demo",
+ version="v1",
+ inputs={"val": 19},
+ timeout=timedelta(seconds=3600),
+ )
+ # [END howto_operator_flyte_synchronous]
+
+ # [START howto_operator_flyte_asynchronous]
+ async_predictions = AirflowFlyteOperator(
+ task_id="flyte_example_async",
+ flyte_conn_id="flyte_conn_example",
+ project="flytesnacks",
+ domain="development",
+ launchplan_name="core.basic.lp.my_wf",
+ max_parallelism=2,
+ raw_data_prefix="s3://flyte-demo/raw_data",
+ assumable_iam_role="default",
+ kubernetes_service_account="demo",
+ version="v1",
+ inputs={"val": 19},
+ timeout=timedelta(seconds=3600),
+ asynchronous=True,
+ )
+
+ predictions_sensor = AirflowFlyteSensor(
+ task_id="predictions_sensor",
+ execution_name=async_predictions.output,
+ project="flytesnacks",
+ domain="development",
+ flyte_conn_id="flyte_conn_example",
+ )
+ # [END howto_operator_flyte_asynchronous]
+
+ # Task dependency created via `XComArgs`:
+ async_predictions >> predictions_sensor
diff --git a/airflow/providers/flyte/hooks/__init__.py b/airflow/providers/flyte/hooks/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/airflow/providers/flyte/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/flyte/hooks/flyte.py b/airflow/providers/flyte/hooks/flyte.py
new file mode 100644
index 000000000000..ec8c965a602f
--- /dev/null
+++ b/airflow/providers/flyte/hooks/flyte.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import time
+from datetime import datetime, timedelta
+from typing import Any, Dict, Optional
+
+from flytekit.configuration import AuthType, Config, PlatformConfig
+from flytekit.exceptions.user import FlyteEntityNotExistException
+from flytekit.models.common import Annotations, AuthRole, Labels
+from flytekit.models.core import execution as core_execution_models
+from flytekit.models.core.identifier import WorkflowExecutionIdentifier
+from flytekit.remote.remote import FlyteRemote, Options
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+
+
+class AirflowFlyteHook(BaseHook):
+ """
+ Interact with the FlyteRemote API.
+
+ :param flyte_conn_id: Required. The name of the Flyte connection to get
+ the connection information for Flyte.
+ :param project: Optional. The project under consideration.
+ :param domain: Optional. The domain under consideration.
+ """
+
+ SUCCEEDED = core_execution_models.WorkflowExecutionPhase.SUCCEEDED
+ FAILED = core_execution_models.WorkflowExecutionPhase.FAILED
+ TIMED_OUT = core_execution_models.WorkflowExecutionPhase.TIMED_OUT
+ ABORTED = core_execution_models.WorkflowExecutionPhase.ABORTED
+
+ flyte_conn_id = "flyte_default"
+ conn_type = "flyte"
+
+ def __init__(
+ self, flyte_conn_id: str = flyte_conn_id, project: Optional[str] = None, domain: Optional[str] = None
+ ) -> None:
+ super().__init__()
+ self.flyte_conn_id = flyte_conn_id
+ self.flyte_conn = self.get_connection(self.flyte_conn_id)
+ self.project = project or self.flyte_conn.extra_dejson.get("project")
+ self.domain = domain or self.flyte_conn.extra_dejson.get("domain")
+
+ if not (self.project and self.domain):
+ raise AirflowException("Please provide a project and domain.")
+
+ def execution_id(self, execution_name: str) -> WorkflowExecutionIdentifier:
+ """Get the execution id."""
+ return WorkflowExecutionIdentifier(self.project, self.domain, execution_name)
+
+ def create_flyte_remote(self) -> FlyteRemote:
+ """Create a FlyteRemote object."""
+ remote = FlyteRemote(
+ config=Config(
+ platform=PlatformConfig(
+                    # Build "host:port" when both are set; otherwise fall back to the host
+                    # alone or to the default local endpoint. Use an f-string because the
+                    # connection port is an integer.
+                    endpoint=f"{self.flyte_conn.host}:{self.flyte_conn.port}"
+                    if (self.flyte_conn.host and self.flyte_conn.port)
+                    else (self.flyte_conn.host or "localhost:30081"),
+ insecure=self.flyte_conn.extra_dejson.get("insecure", False),
+ client_id=self.flyte_conn.login or None,
+ client_credentials_secret=self.flyte_conn.password or None,
+ command=self.flyte_conn.extra_dejson.get("command", None),
+ scopes=self.flyte_conn.extra_dejson.get("scopes", None),
+ auth_mode=AuthType(self.flyte_conn.extra_dejson.get("auth_mode", "standard")),
+ )
+ ),
+ )
+ return remote
+
+ def trigger_execution(
+ self,
+ execution_name: str,
+ launchplan_name: Optional[str] = None,
+ task_name: Optional[str] = None,
+ max_parallelism: Optional[int] = None,
+ raw_data_prefix: Optional[str] = None,
+ assumable_iam_role: Optional[str] = None,
+ kubernetes_service_account: Optional[str] = None,
+ labels: Optional[Dict[str, str]] = None,
+ annotations: Optional[Dict[str, str]] = None,
+ version: Optional[str] = None,
+ inputs: Dict[str, Any] = {},
+ ) -> None:
+ """
+ Trigger an execution.
+
+ :param execution_name: Required. The name of the execution to trigger.
+ :param launchplan_name: Optional. The name of the launchplan to trigger.
+ :param task_name: Optional. The name of the task to trigger.
+ :param max_parallelism: Optional. The maximum number of parallel executions to allow.
+ :param raw_data_prefix: Optional. The prefix to use for raw data.
+ :param assumable_iam_role: Optional. The assumable IAM role to use.
+ :param kubernetes_service_account: Optional. The kubernetes service account to use.
+ :param labels: Optional. The labels to use.
+ :param annotations: Optional. The annotations to use.
+ :param version: Optional. The version of the launchplan to trigger.
+ :param inputs: Optional. The inputs to the launchplan.
+ """
+        if (not (task_name or launchplan_name)) or (task_name and launchplan_name):
+            raise AirflowException("Exactly one of task_name or launchplan_name is required.")
+
+ remote = self.create_flyte_remote()
+ try:
+ if launchplan_name:
+ flyte_entity = remote.fetch_launch_plan(
+ name=launchplan_name, project=self.project, domain=self.domain, version=version
+ )
+ elif task_name:
+ flyte_entity = remote.fetch_task(
+ name=task_name, project=self.project, domain=self.domain, version=version
+ )
+ except FlyteEntityNotExistException as e:
+ raise AirflowException(f"Failed to fetch entity: {e}")
+
+ try:
+ remote.execute(
+ flyte_entity,
+ inputs=inputs,
+ project=self.project,
+ domain=self.domain,
+ execution_name=execution_name,
+ options=Options(
+ raw_data_prefix=raw_data_prefix,
+ max_parallelism=max_parallelism,
+ auth_role=AuthRole(
+ assumable_iam_role=assumable_iam_role,
+ kubernetes_service_account=kubernetes_service_account,
+ ),
+ labels=Labels(labels),
+ annotations=Annotations(annotations),
+ ),
+ )
+ except Exception as e:
+ raise AirflowException(f"Failed to trigger execution: {e}")
+
+    def execution_status(self, execution_name: str, remote: FlyteRemote):
+        """Return True if the execution succeeded, False if it is still in progress; raise if it failed, timed out, or was aborted."""
+        phase = remote.client.get_execution(self.execution_id(execution_name)).closure.phase
+
+ if phase == self.SUCCEEDED:
+ return True
+ elif phase == self.FAILED:
+ raise AirflowException(f"Execution {execution_name} failed")
+ elif phase == self.TIMED_OUT:
+            raise AirflowException(f"Execution {execution_name} timed out")
+ elif phase == self.ABORTED:
+ raise AirflowException(f"Execution {execution_name} aborted")
+ else:
+ return False
+
+ def wait_for_execution(
+ self,
+ execution_name: str,
+ timeout: Optional[timedelta] = None,
+ poll_interval: timedelta = timedelta(seconds=30),
+ ) -> None:
+ """
+ Helper method which polls an execution to check the status.
+
+        :param execution_name: Required. The name of the execution to check.
+ :param timeout: Optional. The timeout to wait for the execution to finish.
+ :param poll_interval: Optional. The interval between checks to poll the execution.
+ """
+ remote = self.create_flyte_remote()
+
+ time_to_give_up = datetime.max if timeout is None else datetime.utcnow() + timeout
+
+        # Poll until the execution reaches a terminal phase or the deadline passes.
+        while datetime.utcnow() < time_to_give_up:
+            time.sleep(poll_interval.total_seconds())
+
+            if self.execution_status(execution_name, remote):
+                return
+
+        raise AirflowException(f"Execution {execution_name} timed out")
+
+ def terminate(
+ self,
+ execution_name: str,
+ cause: str,
+ ) -> None:
+ """
+ Terminate an execution.
+
+        :param execution_name: Required. The name of the execution to terminate.
+ :param cause: Required. The cause of the termination.
+ """
+ remote = self.create_flyte_remote()
+ execution_id = self.execution_id(execution_name)
+ remote.client.terminate_execution(id=execution_id, cause=cause)
diff --git a/airflow/providers/flyte/operators/__init__.py b/airflow/providers/flyte/operators/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/airflow/providers/flyte/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/flyte/operators/flyte.py b/airflow/providers/flyte/operators/flyte.py
new file mode 100644
index 000000000000..88978626c677
--- /dev/null
+++ b/airflow/providers/flyte/operators/flyte.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AirflowFlyteOperator(BaseOperator):
+ """
+ Launch Flyte executions from within Airflow.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AirflowFlyteOperator`
+
+ :param flyte_conn_id: Required. The connection to Flyte setup, containing metadata.
+ :param project: Optional. The project to connect to.
+ :param domain: Optional. The domain to connect to.
+ :param launchplan_name: Optional. The name of the launchplan to trigger.
+ :param task_name: Optional. The name of the task to trigger.
+ :param max_parallelism: Optional. The maximum number of parallel executions to allow.
+ :param raw_data_prefix: Optional. The prefix to use for raw data.
+ :param assumable_iam_role: Optional. The IAM role to assume.
+ :param kubernetes_service_account: Optional. The Kubernetes service account to use.
+ :param labels: Optional. Custom labels to be applied to the execution resource.
+ :param annotations: Optional. Custom annotations to be applied to the execution resource.
+ :param version: Optional. The version of the launchplan/task to trigger.
+ :param inputs: Optional. The inputs to the launchplan/task.
+ :param timeout: Optional. The timeout to wait for the execution to finish.
+ :param poll_interval: Optional. The interval between checks to poll the execution.
+ :param asynchronous: Optional. Whether to wait for the execution to finish or not.
+ """
+
+ template_fields: Sequence[str] = ("flyte_conn_id",) # mypy fix
+
+ def __init__(
+ self,
+ flyte_conn_id: str,
+ project: Optional[str] = None,
+ domain: Optional[str] = None,
+ launchplan_name: Optional[str] = None,
+ task_name: Optional[str] = None,
+ max_parallelism: Optional[int] = None,
+ raw_data_prefix: Optional[str] = None,
+ assumable_iam_role: Optional[str] = None,
+ kubernetes_service_account: Optional[str] = None,
+ labels: Dict[str, str] = {},
+ annotations: Dict[str, str] = {},
+ version: Optional[str] = None,
+ inputs: Dict[str, Any] = {},
+ timeout: Optional[timedelta] = None,
+ poll_interval: timedelta = timedelta(seconds=30),
+ asynchronous: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.flyte_conn_id = flyte_conn_id
+ self.project = project
+ self.domain = domain
+ self.launchplan_name = launchplan_name
+ self.task_name = task_name
+ self.max_parallelism = max_parallelism
+ self.raw_data_prefix = raw_data_prefix
+ self.assumable_iam_role = assumable_iam_role
+ self.kubernetes_service_account = kubernetes_service_account
+ self.labels = labels
+ self.annotations = annotations
+ self.version = version
+ self.inputs = inputs
+ self.timeout = timeout
+ self.poll_interval = poll_interval
+ self.asynchronous = asynchronous
+ self.execution_name: str = ""
+
+        if (not (self.task_name or self.launchplan_name)) or (self.task_name and self.launchplan_name):
+            raise AirflowException("Exactly one of task_name or launchplan_name is required.")
+
+ def execute(self, context: "Context") -> str:
+        """Trigger an execution and, unless ``asynchronous`` is set, wait for it to finish."""
+
+        # Create a deterministic execution name (at most 20 characters) from the task id and run id.
+ task_id = re.sub(r"[\W_]+", "", context["task"].task_id)[:5]
+ self.execution_name = task_id + re.sub(
+ r"[\W_]+",
+ "",
+ context["dag_run"].run_id.split("__")[-1].lower(),
+ )[: (20 - len(task_id))]
+
+ hook = AirflowFlyteHook(flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain)
+ hook.trigger_execution(
+ launchplan_name=self.launchplan_name,
+ task_name=self.task_name,
+ max_parallelism=self.max_parallelism,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ labels=self.labels,
+ annotations=self.annotations,
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ self.log.info("Execution %s submitted", self.execution_name)
+
+ if not self.asynchronous:
+ self.log.info("Waiting for execution %s to complete", self.execution_name)
+ hook.wait_for_execution(
+ execution_name=self.execution_name,
+ timeout=self.timeout,
+ poll_interval=self.poll_interval,
+ )
+ self.log.info("Execution %s completed", self.execution_name)
+
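+        # BaseOperator pushes the return value to XCom, so downstream tasks can read the execution name.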
+ return self.execution_name
+
+ def on_kill(self) -> None:
+ """Kill the execution."""
+ if self.execution_name:
+            self.log.info("Killing execution %s", self.execution_name)
+ hook = AirflowFlyteHook(
+ flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain
+ )
+ hook.terminate(
+ execution_name=self.execution_name,
+ cause="Killed by Airflow",
+ )
diff --git a/airflow/providers/flyte/provider.yaml b/airflow/providers/flyte/provider.yaml
new file mode 100644
index 000000000000..4deff0810944
--- /dev/null
+++ b/airflow/providers/flyte/provider.yaml
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-flyte
+name: Flyte
+description: |
+    `Flyte <https://docs.flyte.org/>`__
+versions:
+ - 1.0.0
+
+additional-dependencies:
+ - apache-airflow>=2.1.0
+
+integrations:
+ - integration-name: Flyte
+ external-doc-url: https://docs.flyte.org/
+ logo: /integration-logos/flyte/Flyte.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-flyte/operators/flyte.rst
+ tags: [service]
+
+operators:
+ - integration-name: Flyte
+ python-modules:
+ - airflow.providers.flyte.operators.flyte
+
+hooks:
+ - integration-name: Flyte
+ python-modules:
+ - airflow.providers.flyte.hooks.flyte
+
+sensors:
+ - integration-name: Flyte
+ python-modules:
+ - airflow.providers.flyte.sensors.flyte
+
+connection-types:
+ - hook-class-name: airflow.providers.flyte.hooks.flyte.AirflowFlyteHook
+ connection-type: flyte
diff --git a/airflow/providers/flyte/sensors/__init__.py b/airflow/providers/flyte/sensors/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/airflow/providers/flyte/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/flyte/sensors/flyte.py b/airflow/providers/flyte/sensors/flyte.py
new file mode 100644
index 000000000000..0e5d068133ea
--- /dev/null
+++ b/airflow/providers/flyte/sensors/flyte.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import TYPE_CHECKING, Optional, Sequence
+
+from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AirflowFlyteSensor(BaseSensorOperator):
+ """
+ Check for the status of a Flyte execution.
+
+ :param execution_name: Required. The name of the execution to check.
+ :param project: Optional. The project to connect to.
+ :param domain: Optional. The domain to connect to.
+ :param flyte_conn_id: Required. The name of the Flyte connection to
+ get the connection information for Flyte.
+ """
+
+ template_fields: Sequence[str] = ("execution_name",) # mypy fix
+
+ def __init__(
+ self,
+ execution_name: str,
+ project: Optional[str] = None,
+ domain: Optional[str] = None,
+ flyte_conn_id: str = "flyte_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.execution_name = execution_name
+ self.project = project
+ self.domain = domain
+ self.flyte_conn_id = flyte_conn_id
+
+ def poke(self, context: "Context") -> bool:
+ """Check for the status of a Flyte execution."""
+ hook = AirflowFlyteHook(flyte_conn_id=self.flyte_conn_id, project=self.project, domain=self.domain)
+ remote = hook.create_flyte_remote()
+
+ if hook.execution_status(self.execution_name, remote):
+ return True
+
+ self.log.info("Waiting for execution %s to complete", self.execution_name)
+ return False
diff --git a/docs/apache-airflow-providers-flyte/commits.rst b/docs/apache-airflow-providers-flyte/commits.rst
new file mode 100644
index 000000000000..518db1548855
--- /dev/null
+++ b/docs/apache-airflow-providers-flyte/commits.rst
@@ -0,0 +1,26 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Package apache-airflow-providers-flyte
+--------------------------------------
+
+`Flyte <https://docs.flyte.org/>`__
+
+
+This is a detailed commit list of changes for the ``flyte`` provider package.
+For a high-level changelog, see :doc:`package information including changelog <index>`.
diff --git a/docs/apache-airflow-providers-flyte/connections.rst b/docs/apache-airflow-providers-flyte/connections.rst
new file mode 100644
index 000000000000..f9eadca1a697
--- /dev/null
+++ b/docs/apache-airflow-providers-flyte/connections.rst
@@ -0,0 +1,45 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Flyte Connection
+================
+
+The Flyte connection enables connecting to Flyte through FlyteRemote.
+
+Configuring the Connection
+--------------------------
+
+Host (optional)
+ The FlyteAdmin host. Defaults to localhost.
+
+Port (optional)
+ The FlyteAdmin port. Defaults to 30081.
+
+Login (optional)
+ The ``client_id`` used to authenticate, if any.
+
+Password (optional)
+ The ``client_credentials_secret`` used to authenticate, if any.
+
+Extra (optional)
+  Specify the ``extra`` parameter as a JSON dictionary to provide additional parameters.
+
+  * ``project``: The default project to connect to.
+  * ``domain``: The default domain to connect to.
+  * ``insecure``: Whether to connect without SSL. Set to ``true`` to disable SSL.
+  * ``command``: The command to execute to return a token using an external process.
+  * ``scopes``: List of scopes to request.
+  * ``auth_mode``: The OAuth mode to use. Defaults to the PKCE flow.
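+
+For example, the connection can be created programmatically for local testing (a minimal
+sketch with illustrative values; connections are usually configured through the Airflow
+UI, CLI, or environment variables):
+
+.. code-block:: python
+
+    import json
+
+    from airflow.models import Connection
+
+    # Host/port point at FlyteAdmin; extra carries the default project and domain.
+    flyte_conn = Connection(
+        conn_id="flyte_conn_example",
+        conn_type="flyte",
+        host="localhost",
+        port=30081,
+        extra=json.dumps({"project": "flytesnacks", "domain": "development", "insecure": True}),
+    )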
diff --git a/docs/apache-airflow-providers-flyte/index.rst b/docs/apache-airflow-providers-flyte/index.rst
new file mode 100644
index 000000000000..072a85a8094d
--- /dev/null
+++ b/docs/apache-airflow-providers-flyte/index.rst
@@ -0,0 +1,83 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+``apache-airflow-providers-flyte``
+==================================
+
+Content
+-------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Guides
+
+    Operators <operators/flyte>
+    Connection types <connections>
+
+.. toctree::
+ :maxdepth: 1
+ :caption: References
+
+ Python API <_api/airflow/providers/flyte/index>
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Resources
+
+    Example DAGs <https://github.com/apache/airflow/tree/main/airflow/providers/flyte/example_dags>
+    PyPI Repository <https://pypi.org/project/apache-airflow-providers-flyte/>
+    Installing from sources <installing-providers-from-sources>
+
+.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME!
+
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Commits
+
+    Detailed list of commits <commits>
+
+
+Package apache-airflow-providers-flyte
+--------------------------------------
+
+`Flyte <https://docs.flyte.org/>`__
+
+
+Release: 1.0.0
+
+Provider package
+----------------
+
+This is a provider package for the ``flyte`` provider. All classes for this provider package
+are in the ``airflow.providers.flyte`` Python package.
+
+Installation
+------------
+
+You can install this package on top of an existing Airflow 2.1+ installation via
+``pip install apache-airflow-providers-flyte``
+
+PIP requirements
+----------------
+
+================================= ==================
+PIP package Version required
+================================= ==================
+``apache-airflow`` ``>=2.1.0``
+``flytekit`` ``>=0.32.0b0``
+================================= ==================
diff --git a/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst b/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst
new file mode 100644
index 000000000000..1c90205d15b3
--- /dev/null
+++ b/docs/apache-airflow-providers-flyte/installing-providers-from-sources.rst
@@ -0,0 +1,18 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. include:: ../installing-providers-from-sources.rst
diff --git a/docs/apache-airflow-providers-flyte/operators/flyte.rst b/docs/apache-airflow-providers-flyte/operators/flyte.rst
new file mode 100644
index 000000000000..292996e0455f
--- /dev/null
+++ b/docs/apache-airflow-providers-flyte/operators/flyte.rst
@@ -0,0 +1,50 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:AirflowFlyteOperator:
+
+AirflowFlyteOperator
+====================
+
+Use the :class:`~airflow.providers.flyte.operators.flyte.AirflowFlyteOperator` to
+trigger a task/workflow in Flyte.
+
+Using the Operator
+^^^^^^^^^^^^^^^^^^
+
+The AirflowFlyteOperator requires a ``flyte_conn_id`` to fetch the connection parameters
+needed to instantiate ``FlyteRemote``. You must also provide either a ``launchplan_name``
+(to trigger a workflow) or a ``task_name`` (to trigger a task). A number of other parameters
+are optional, such as ``project``, ``domain``, ``max_parallelism``, ``raw_data_prefix``,
+``assumable_iam_role``, ``kubernetes_service_account``, ``version``, ``inputs``, ``timeout``, and ``poll_interval``.
+
+By default, executions are triggered synchronously and the operator waits for them to finish.
+Set the ``asynchronous`` parameter to ``True`` to trigger executions asynchronously instead.
+
+An example where the execution is triggered synchronously:
+
+.. exampleinclude:: /../../airflow/providers/flyte/example_dags/example_flyte.py
+ :language: python
+ :start-after: [START howto_operator_flyte_synchronous]
+ :end-before: [END howto_operator_flyte_synchronous]
+
+An example where the execution is triggered asynchronously:
+
+.. exampleinclude:: /../../airflow/providers/flyte/example_dags/example_flyte.py
+ :language: python
+ :start-after: [START howto_operator_flyte_asynchronous]
+ :end-before: [END howto_operator_flyte_asynchronous]
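+
+The operator returns the Flyte execution name, which Airflow pushes to XCom. In the
+asynchronous example above, the ``AirflowFlyteSensor`` consumes it through
+``async_predictions.output`` and waits for the execution to complete.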
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index ad4fa0152221..57efd4b5f2bd 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -174,6 +174,8 @@ Those are extras that add dependencies needed for integration with external serv
+---------------------+-----------------------------------------------------+-----------------------------------------------------+
| facebook | ``pip install 'apache-airflow[facebook]'`` | Facebook Social |
+---------------------+-----------------------------------------------------+-----------------------------------------------------+
+| flyte | ``pip install 'apache-airflow[flyte]'`` | Flyte hooks and operators |
++---------------------+-----------------------------------------------------+-----------------------------------------------------+
| google | ``pip install 'apache-airflow[google]'`` | Google Cloud |
+---------------------+-----------------------------------------------------+-----------------------------------------------------+
| hashicorp | ``pip install 'apache-airflow[hashicorp]'`` | Hashicorp Services (Vault) |
diff --git a/docs/integration-logos/flyte/Flyte.png b/docs/integration-logos/flyte/Flyte.png
new file mode 100644
index 000000000000..49cdbbbc3419
Binary files /dev/null and b/docs/integration-logos/flyte/Flyte.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index b6242d9e0369..0a5e55d7cb92 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -157,6 +157,7 @@ Firehose
Firestore
Flink
FluentD
+Flyte
Formaturas
Fundera
GCS
@@ -810,6 +811,7 @@ filetype
findall
firstname
fluentd
+flyte
fmt
fn
fo
diff --git a/setup.py b/setup.py
index 3ce9ae73114e..bbe22723c037 100644
--- a/setup.py
+++ b/setup.py
@@ -300,6 +300,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
flask_appbuilder_authlib = [
'authlib',
]
+flyte = ["flytekit>=0.32.0b0"]
github = [
'pygithub',
]
@@ -683,6 +684,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'elasticsearch': elasticsearch,
'exasol': exasol,
'facebook': facebook,
+ 'flyte': flyte,
'ftp': [],
'github': github,
'google': google,
diff --git a/tests/providers/flyte/__init__.py b/tests/providers/flyte/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/providers/flyte/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/flyte/hooks/__init__.py b/tests/providers/flyte/hooks/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/providers/flyte/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/flyte/hooks/test_flyte.py b/tests/providers/flyte/hooks/test_flyte.py
new file mode 100644
index 000000000000..8643ac1d6aa0
--- /dev/null
+++ b/tests/providers/flyte/hooks/test_flyte.py
@@ -0,0 +1,368 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import timedelta
+from unittest import mock
+
+import pytest
+from flytekit.configuration import Config, PlatformConfig
+from flytekit.exceptions.user import FlyteEntityNotExistException, FlyteValueException
+from flytekit.models.core import execution as core_execution_models
+from flytekit.remote import FlyteRemote
+
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook
+
+
+class TestAirflowFlyteHook(unittest.TestCase):
+
+ flyte_conn_id = "flyte_default"
+ execution_name = "flyte20220330t133856"
+ conn_type = "flyte"
+ host = "localhost"
+ port = "30081"
+ extra = {"project": "flytesnacks", "domain": "development"}
+ launchplan_name = "core.basic.hello_world.my_wf"
+ task_name = "core.basic.hello_world.say_hello"
+ raw_data_prefix = "s3://flyte-demo/raw_data"
+ assumable_iam_role = "arn:aws:iam::123456789012:role/example-role"
+ kubernetes_service_account = "default"
+ version = "v1"
+ inputs = {"name": "hello world"}
+ timeout = timedelta(seconds=3600)
+
+ @classmethod
+ def get_mock_connection(cls):
+ return Connection(
+ conn_id=cls.flyte_conn_id, conn_type=cls.conn_type, host=cls.host, port=cls.port, extra=cls.extra
+ )
+
+ @classmethod
+ def create_remote(cls):
+ return FlyteRemote(
+ config=Config(
+ platform=PlatformConfig(endpoint=":".join([cls.host, cls.port]), insecure=True),
+ )
+ )
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ def test_trigger_execution_success(self, mock_create_flyte_remote, mock_get_connection):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+ mock_remote.fetch_launch_plan = mock.MagicMock()
+
+ mock_remote.execute = mock.MagicMock()
+
+ test_hook.trigger_execution(
+ launchplan_name=self.launchplan_name,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ mock_create_flyte_remote.assert_called_once()
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ def test_trigger_task_execution_success(self, mock_create_flyte_remote, mock_get_connection):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+ mock_remote.fetch_task = mock.MagicMock()
+
+ mock_remote.execute = mock.MagicMock()
+
+ test_hook.trigger_execution(
+ task_name=self.task_name,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ mock_create_flyte_remote.assert_called_once()
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ def test_trigger_execution_failed_to_fetch(self, mock_create_flyte_remote, mock_get_connection):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+ mock_remote.fetch_launch_plan = mock.MagicMock(side_effect=FlyteEntityNotExistException)
+
+ with pytest.raises(AirflowException):
+ test_hook.trigger_execution(
+ launchplan_name=self.launchplan_name,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ mock_create_flyte_remote.assert_called_once()
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ def test_trigger_execution_failed_to_trigger(self, mock_create_flyte_remote, mock_get_connection):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+ mock_remote.fetch_launch_plan = mock.MagicMock()
+ mock_remote.execute = mock.MagicMock(side_effect=FlyteValueException)
+
+ with pytest.raises(AirflowException):
+ test_hook.trigger_execution(
+ launchplan_name=self.launchplan_name,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ mock_create_flyte_remote.assert_called_once()
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_succeeded(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(side_effect=[test_hook.SUCCEEDED])
+ type(mock_get_execution().closure).phase = mock_phase
+
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3)
+ )
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_called_once_with(self.execution_name)
+ assert mock_phase.call_count == 1
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_failed(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(
+ side_effect=[core_execution_models.WorkflowExecutionPhase.RUNNING, test_hook.FAILED]
+ )
+ type(mock_get_execution().closure).phase = mock_phase
+
+ with pytest.raises(AirflowException):
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3)
+ )
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2)
+ assert mock_phase.call_count == 2
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_queued_succeeded(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(
+ side_effect=[core_execution_models.WorkflowExecutionPhase.QUEUED, test_hook.SUCCEEDED]
+ )
+ type(mock_get_execution().closure).phase = mock_phase
+
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3)
+ )
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2)
+ assert mock_phase.call_count == 2
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_timedout(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(
+ side_effect=[
+ core_execution_models.WorkflowExecutionPhase.QUEUED,
+ core_execution_models.WorkflowExecutionPhase.RUNNING,
+ test_hook.TIMED_OUT,
+ ]
+ )
+ type(mock_get_execution().closure).phase = mock_phase
+
+ with pytest.raises(AirflowException):
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=3)
+ )
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 3)
+ assert mock_phase.call_count == 3
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_timeout(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(
+ side_effect=[
+ core_execution_models.WorkflowExecutionPhase.QUEUED,
+ core_execution_models.WorkflowExecutionPhase.RUNNING,
+ core_execution_models.WorkflowExecutionPhase.RUNNING,
+ ]
+ )
+ type(mock_get_execution().closure).phase = mock_phase
+
+ with pytest.raises(AirflowException):
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name,
+ timeout=timedelta(seconds=1),
+ poll_interval=timedelta(seconds=3),
+ )
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_called()
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_wait_for_execution_aborted(
+ self, mock_execution_id, mock_create_flyte_remote, mock_get_connection
+ ):
+ mock_connection = self.get_mock_connection()
+ mock_get_connection.return_value = mock_connection
+
+ test_hook = AirflowFlyteHook(self.flyte_conn_id)
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(
+ side_effect=[
+ core_execution_models.WorkflowExecutionPhase.RUNNING,
+ core_execution_models.WorkflowExecutionPhase.ABORTED,
+ ]
+ )
+ type(mock_get_execution().closure).phase = mock_phase
+
+ with pytest.raises(AirflowException):
+ test_hook.wait_for_execution(
+ execution_name=self.execution_name,
+ timeout=self.timeout,
+ poll_interval=timedelta(seconds=3),
+ )
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 2)
+ assert mock_phase.call_count == 2
diff --git a/tests/providers/flyte/operators/__init__.py b/tests/providers/flyte/operators/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/providers/flyte/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/flyte/operators/test_flyte.py b/tests/providers/flyte/operators/test_flyte.py
new file mode 100644
index 000000000000..72f44b05cb13
--- /dev/null
+++ b/tests/providers/flyte/operators/test_flyte.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from datetime import timedelta
+from unittest import mock
+
+from airflow.models import Connection
+from airflow.models.dagrun import DagRun
+from airflow.providers.flyte.operators.flyte import AirflowFlyteOperator
+
+
+class TestAirflowFlyteOperator(unittest.TestCase):
+
+ task_id = "test_flyte_operator"
+ flyte_conn_id = "flyte_default"
+ run_id = "manual__2022-03-30T13:55:08.715694+00:00"
+ conn_type = "flyte"
+ host = "localhost"
+ port = "30081"
+ project = "flytesnacks"
+ domain = "development"
+ launchplan_name = "core.basic.hello_world.my_wf"
+ raw_data_prefix = "s3://flyte-demo/raw_data"
+ assumable_iam_role = "arn:aws:iam::123456789012:role/example-role"
+ kubernetes_service_account = "default"
+ labels = {"key1": "value1"}
+ version = "v1"
+ inputs = {"name": "hello world"}
+ timeout = timedelta(seconds=3600)
+ execution_name = "testf20220330t135508"
+
+ @classmethod
+ def get_connection(cls):
+ return Connection(
+ conn_id=cls.flyte_conn_id,
+ conn_type=cls.conn_type,
+ host=cls.host,
+ port=cls.port,
+ extra={"project": cls.project, "domain": cls.domain},
+ )
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.trigger_execution")
+ @mock.patch(
+ "airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.wait_for_execution",
+ return_value=None,
+ )
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ def test_execute(self, mock_get_connection, mock_wait_for_execution, mock_trigger_execution):
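+        """Execute triggers the launch plan, waits for it, and returns the execution name."""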
+ mock_get_connection.return_value = self.get_connection()
+
+ operator = AirflowFlyteOperator(
+ task_id=self.task_id,
+ flyte_conn_id=self.flyte_conn_id,
+ project=self.project,
+ domain=self.domain,
+ launchplan_name=self.launchplan_name,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ labels=self.labels,
+ version=self.version,
+ inputs=self.inputs,
+ timeout=self.timeout,
+ )
+ result = operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator})
+
+ assert result == self.execution_name
+ mock_get_connection.assert_called_once_with(self.flyte_conn_id)
+ mock_trigger_execution.assert_called_once_with(
+ launchplan_name=self.launchplan_name,
+ task_name=None,
+ max_parallelism=None,
+ raw_data_prefix=self.raw_data_prefix,
+ assumable_iam_role=self.assumable_iam_role,
+ kubernetes_service_account=self.kubernetes_service_account,
+ labels=self.labels,
+ annotations={},
+ version=self.version,
+ inputs=self.inputs,
+ execution_name=self.execution_name,
+ )
+ mock_wait_for_execution.assert_called_once_with(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=30)
+ )
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.trigger_execution", return_value=None)
+ @mock.patch(
+ "airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.wait_for_execution",
+ return_value=None,
+ )
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.terminate")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ def test_on_kill_success(
+ self, mock_get_connection, mock_terminate, mock_wait_for_execution, mock_trigger_execution
+ ):
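+        """on_kill terminates the triggered execution after execute() has run."""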
+ mock_get_connection.return_value = self.get_connection()
+
+ operator = AirflowFlyteOperator(
+ task_id=self.task_id,
+ flyte_conn_id=self.flyte_conn_id,
+ project=self.project,
+ domain=self.domain,
+ launchplan_name=self.launchplan_name,
+ inputs=self.inputs,
+ timeout=self.timeout,
+ )
+ operator.execute({"dag_run": DagRun(run_id=self.run_id), "task": operator})
+ operator.on_kill()
+
+        mock_get_connection.assert_has_calls([mock.call(self.flyte_conn_id)] * 2)
+ mock_trigger_execution.assert_called()
+ mock_wait_for_execution.assert_called_once_with(
+ execution_name=self.execution_name, timeout=self.timeout, poll_interval=timedelta(seconds=30)
+ )
+ mock_terminate.assert_called_once_with(execution_name=self.execution_name, cause="Killed by Airflow")
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.terminate")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ def test_on_kill_noop(self, mock_get_connection, mock_terminate):
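+        """on_kill is a no-op when execute() has not run."""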
+ mock_get_connection.return_value = self.get_connection()
+
+ operator = AirflowFlyteOperator(
+ task_id=self.task_id,
+ flyte_conn_id=self.flyte_conn_id,
+ project=self.project,
+ domain=self.domain,
+ launchplan_name=self.launchplan_name,
+ inputs=self.inputs,
+ )
+ operator.on_kill()
+
+ mock_get_connection.assert_not_called()
+ mock_terminate.assert_not_called()
diff --git a/tests/providers/flyte/sensors/__init__.py b/tests/providers/flyte/sensors/__init__.py
new file mode 100644
index 000000000000..13a83393a912
--- /dev/null
+++ b/tests/providers/flyte/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/flyte/sensors/test_flyte.py b/tests/providers/flyte/sensors/test_flyte.py
new file mode 100644
index 000000000000..ebe7746912bb
--- /dev/null
+++ b/tests/providers/flyte/sensors/test_flyte.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest import mock
+
+import pytest
+from flytekit.configuration import Config, PlatformConfig
+from flytekit.models.core import execution as core_execution_models
+from flytekit.remote import FlyteRemote
+
+from airflow import AirflowException
+from airflow.models import Connection
+from airflow.providers.flyte.hooks.flyte import AirflowFlyteHook
+from airflow.providers.flyte.sensors.flyte import AirflowFlyteSensor
+
+
+class TestAirflowFlyteSensor(unittest.TestCase):
+
+ task_id = "test_flyte_sensor"
+ flyte_conn_id = "flyte_default"
+ conn_type = "flyte"
+ host = "localhost"
+ port = "30081"
+ project = "flytesnacks"
+ domain = "development"
+ execution_name = "testf20220330t135508"
+
+ @classmethod
+ def get_connection(cls):
+ return Connection(
+ conn_id=cls.flyte_conn_id,
+ conn_type=cls.conn_type,
+ host=cls.host,
+ port=cls.port,
+ extra={"project": cls.project, "domain": cls.domain},
+ )
+
+ @classmethod
+ def create_remote(cls):
+ return FlyteRemote(
+ config=Config(
+ platform=PlatformConfig(endpoint=":".join([cls.host, cls.port]), insecure=True),
+ )
+ )
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_poke_done(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection):
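+        """poke returns True once the execution has SUCCEEDED."""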
+ mock_get_connection.return_value = self.get_connection()
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(return_value=AirflowFlyteHook.SUCCEEDED)
+ type(mock_get_execution().closure).phase = mock_phase
+
+ sensor = AirflowFlyteSensor(
+ task_id=self.task_id,
+ execution_name=self.execution_name,
+ project=self.project,
+ domain=self.domain,
+ flyte_conn_id=self.flyte_conn_id,
+ )
+
+ return_value = sensor.poke({})
+
+ assert return_value
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_called_with(self.execution_name)
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_poke_failed(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection):
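+        """poke raises an AirflowException for the ABORTED, FAILED, and TIMED_OUT phases."""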
+ mock_get_connection.return_value = self.get_connection()
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ sensor = AirflowFlyteSensor(
+ task_id=self.task_id,
+ execution_name=self.execution_name,
+ project=self.project,
+ domain=self.domain,
+ flyte_conn_id=self.flyte_conn_id,
+ )
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ for phase in [AirflowFlyteHook.ABORTED, AirflowFlyteHook.FAILED, AirflowFlyteHook.TIMED_OUT]:
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(return_value=phase)
+ type(mock_get_execution().closure).phase = mock_phase
+
+ with pytest.raises(AirflowException):
+ sensor.poke({})
+
+        mock_create_flyte_remote.assert_has_calls([mock.call()] * 3)
+        mock_execution_id.assert_has_calls([mock.call(self.execution_name)] * 3)
+
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.get_connection")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.create_flyte_remote")
+ @mock.patch("airflow.providers.flyte.hooks.flyte.AirflowFlyteHook.execution_id")
+ def test_poke_running(self, mock_execution_id, mock_create_flyte_remote, mock_get_connection):
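+        """poke returns False while the execution is still RUNNING."""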
+ mock_get_connection.return_value = self.get_connection()
+
+ mock_remote = self.create_remote()
+ mock_create_flyte_remote.return_value = mock_remote
+
+ execution_id = mock.MagicMock()
+ mock_execution_id.return_value = execution_id
+
+ mock_get_execution = mock.MagicMock()
+ mock_remote.client.get_execution = mock_get_execution
+ mock_phase = mock.PropertyMock(return_value=core_execution_models.WorkflowExecutionPhase.RUNNING)
+ type(mock_get_execution().closure).phase = mock_phase
+
+ sensor = AirflowFlyteSensor(
+ task_id=self.task_id,
+ execution_name=self.execution_name,
+ project=self.project,
+ domain=self.domain,
+ flyte_conn_id=self.flyte_conn_id,
+ )
+
+ return_value = sensor.poke({})
+ assert not return_value
+
+ mock_create_flyte_remote.assert_called_once()
+ mock_execution_id.assert_called_with(self.execution_name)