diff --git a/airflow/providers/yandex/hooks/dataproc.py b/airflow/providers/yandex/hooks/dataproc.py
index 9b1862205e4662..ea10e3e39f3077 100644
--- a/airflow/providers/yandex/hooks/dataproc.py
+++ b/airflow/providers/yandex/hooks/dataproc.py
@@ -16,8 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+from typing import TYPE_CHECKING
+
+from deprecated import deprecated
+
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook
 
+if TYPE_CHECKING:
+    from yandexcloud._wrappers.dataproc import Dataproc
+
 
 class DataprocHook(YandexCloudBaseHook):
     """
@@ -29,7 +37,15 @@ class DataprocHook(YandexCloudBaseHook):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.cluster_id = None
-        self.client = self.sdk.wrappers.Dataproc(
+        self.dataproc_client: Dataproc = self.sdk.wrappers.Dataproc(
             default_folder_id=self.default_folder_id,
             default_public_ssh_key=self.default_public_ssh_key,
         )
+
+    @property
+    @deprecated(
+        reason="`client` is deprecated and will be removed in the future. Use `dataproc_client` instead",
+        category=AirflowProviderDeprecationWarning,
+    )
+    def client(self):
+        return self.dataproc_client
diff --git a/airflow/providers/yandex/hooks/yandex.py b/airflow/providers/yandex/hooks/yandex.py
index 5ad7ce28957630..6681df9387fef5 100644
--- a/airflow/providers/yandex/hooks/yandex.py
+++ b/airflow/providers/yandex/hooks/yandex.py
@@ -24,6 +24,7 @@
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 from airflow.providers.yandex.utils.credentials import (
+    CredentialsType,
     get_credentials,
     get_service_account_id,
 )
@@ -132,13 +133,18 @@ def __init__(
         self.connection_id = yandex_conn_id or connection_id or default_conn_name
         self.connection = self.get_connection(self.connection_id)
         self.extras = self.connection.extra_dejson
-        self.credentials = get_credentials(
+        self.credentials: CredentialsType = get_credentials(
            oauth_token=self._get_field("oauth"),
            service_account_json=self._get_field("service_account_json"),
            service_account_json_path=self._get_field("service_account_json_path"),
        )
        sdk_config = self._get_endpoint()
-        self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(), **sdk_config, **self.credentials)
+        self.sdk = yandexcloud.SDK(
+            user_agent=provider_user_agent(),
+            token=self.credentials.get("token"),
+            service_account_key=self.credentials.get("service_account_key"),
+            endpoint=sdk_config["endpoint"],
+        )
        self.default_folder_id = default_folder_id or self._get_field("folder_id")
        self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key")
        self.default_service_account_id = default_service_account_id or get_service_account_id(
diff --git a/airflow/providers/yandex/hooks/yq.py b/airflow/providers/yandex/hooks/yq.py
index 5950c0d6b05074..011476b04386bf 100644
--- a/airflow/providers/yandex/hooks/yq.py
+++ b/airflow/providers/yandex/hooks/yq.py
@@ -100,8 +100,7 @@ def compose_query_web_link(self, query_id: str):
         return self.client.compose_query_web_link(query_id)
 
     def _get_iam_token(self) -> str:
-        iam_token = self.credentials.get("token")
-        if iam_token is not None:
-            return iam_token
+        if "token" in self.credentials:
+            return self.credentials["token"]
 
         return yc_auth.get_auth_token(service_account_key=self.credentials.get("service_account_key"))
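Note: the `client` -> `dataproc_client` rename above keeps the old name alive behind a deprecation warning by stacking `@property` on top of `@deprecated`. A minimal, self-contained sketch of that pattern, assuming only the `Deprecated` package; the names `OldNewDemo`, `value`, and `new_value` are illustrative stand-ins for the hook:

import warnings

from deprecated import deprecated  # the same `Deprecated` package the hook imports


class OldNewDemo:
    def __init__(self) -> None:
        self.new_value = 42  # the attribute under its new name

    @property
    @deprecated(reason="`value` is deprecated. Use `new_value` instead", category=DeprecationWarning)
    def value(self):
        # Delegates to the new attribute, so old call sites keep working.
        return self.new_value


demo = OldNewDemo()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert demo.value == 42  # the old name still resolves...
assert caught[0].category is DeprecationWarning  # ...but now emits a warning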
diff --git a/airflow/providers/yandex/operators/dataproc.py b/airflow/providers/yandex/operators/dataproc.py
index 2539314fcd58e1..f0a53f954ef118 100644
--- a/airflow/providers/yandex/operators/dataproc.py
+++ b/airflow/providers/yandex/operators/dataproc.py
@@ -33,7 +33,7 @@ class InitializationAction:
     """Data for initialization action to be run at start of DataProc cluster."""
 
     uri: str  # Uri of the executable file
-    args: Sequence[str]  # Arguments to the initialization action
+    args: Iterable[str]  # Arguments to the initialization action
     timeout: int  # Execution timeout
 
 
@@ -143,6 +143,18 @@ def __init__(
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        if ssh_public_keys is None:
+            ssh_public_keys = []
+
+        if services is None:
+            services = []
+
+        if host_group_ids is None:
+            host_group_ids = []
+
+        if security_group_ids is None:
+            security_group_ids = []
+
         self.folder_id = folder_id
         self.yandex_conn_id = connection_id
         self.cluster_name = cluster_name
@@ -186,7 +198,7 @@ def execute(self, context: Context) -> dict:
         self.hook = DataprocHook(
             yandex_conn_id=self.yandex_conn_id,
         )
-        operation_result = self.hook.client.create_cluster(
+        operation_result = self.hook.dataproc_client.create_cluster(
             folder_id=self.folder_id,
             cluster_name=self.cluster_name,
             cluster_description=self.cluster_description,
@@ -221,15 +233,16 @@ def execute(self, context: Context) -> dict:
             security_group_ids=self.security_group_ids,
             log_group_id=self.log_group_id,
             labels=self.labels,
-            initialization_actions=self.initialization_actions
-            and [
+            initialization_actions=[
                 self.hook.sdk.wrappers.InitializationAction(
                     uri=init_action.uri,
                     args=init_action.args,
                     timeout=init_action.timeout,
                 )
                 for init_action in self.initialization_actions
-            ],
+            ]
+            if self.initialization_actions
+            else [],
         )
 
         cluster_id = operation_result.response.id
@@ -290,7 +303,7 @@ def __init__(self, *, connection_id: str | None = None, cluster_id: str | None =
 
     def execute(self, context: Context) -> None:
         hook = self._setup(context)
-        hook.client.delete_cluster(self.cluster_id)
+        hook.dataproc_client.delete_cluster(self.cluster_id)
 
 
 class DataprocCreateHiveJobOperator(DataprocBaseOperator):
@@ -331,7 +344,7 @@ def __init__(
 
     def execute(self, context: Context) -> None:
         hook = self._setup(context)
-        hook.client.create_hive_job(
+        hook.dataproc_client.create_hive_job(
             query=self.query,
             query_file_uri=self.query_file_uri,
             script_variables=self.script_variables,
@@ -387,7 +400,7 @@ def __init__(
 
     def execute(self, context: Context) -> None:
         hook = self._setup(context)
-        hook.client.create_mapreduce_job(
+        hook.dataproc_client.create_mapreduce_job(
             main_class=self.main_class,
             main_jar_file_uri=self.main_jar_file_uri,
             jar_file_uris=self.jar_file_uris,
@@ -455,7 +468,7 @@ def __init__(
 
     def execute(self, context: Context) -> None:
         hook = self._setup(context)
-        hook.client.create_spark_job(
+        hook.dataproc_client.create_spark_job(
             main_class=self.main_class,
             main_jar_file_uri=self.main_jar_file_uri,
             jar_file_uris=self.jar_file_uris,
@@ -526,7 +539,7 @@ def __init__(
 
     def execute(self, context: Context) -> None:
         hook = self._setup(context)
-        hook.client.create_pyspark_job(
+        hook.dataproc_client.create_pyspark_job(
             main_python_file_uri=self.main_python_file_uri,
             python_file_uris=self.python_file_uris,
             jar_file_uris=self.jar_file_uris,
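Note: the `initialization_actions` change above is more than style. `and` returns its left operand whenever that operand is falsy, so the old `self.initialization_actions and [...]` could pass `None` (or `[]`) straight to the SDK; the conditional expression always yields a list. An illustrative sketch of the difference, with hypothetical `with_and`/`with_ternary` helpers rather than provider code:

def with_and(xs):
    # Old idiom: returns the falsy left operand unchanged.
    return xs and [x.upper() for x in xs]


def with_ternary(xs):
    # New idiom: guaranteed to return a list.
    return [x.upper() for x in xs] if xs else []


print(with_and(None))       # None -- the falsy operand leaks through
print(with_and(["a"]))      # ['A']
print(with_ternary(None))   # []   -- always a list
print(with_ternary(["a"]))  # ['A']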
diff --git a/airflow/providers/yandex/provider.yaml b/airflow/providers/yandex/provider.yaml
index 5ec7bd39393933..794a6b25b691d2 100644
--- a/airflow/providers/yandex/provider.yaml
+++ b/airflow/providers/yandex/provider.yaml
@@ -52,12 +52,7 @@ versions:
 
 dependencies:
   - apache-airflow>=2.7.0
-  # The 0.289 and 0.290 versions have broken dataproc support
-  # See https://github.com/yandex-cloud/python-sdk/issues/103
-  # the 0.291.0 version of yandex provider introduced mypy typing
-  # that conflicts with the way yandex provider uses it and should be fixed
-  # See https://github.com/yandex-cloud/python-sdk/issues/106
-  - yandexcloud>=0.278.0,!=0.289.0,!=0.290.0,<0.292.0
+  - yandexcloud>=0.305.0
   - yandex-query-client>=0.1.4
 
 integrations:
diff --git a/airflow/providers/yandex/secrets/lockbox.py b/airflow/providers/yandex/secrets/lockbox.py
index 68879b446eef0a..d65131ab2cb1a7 100644
--- a/airflow/providers/yandex/secrets/lockbox.py
+++ b/airflow/providers/yandex/secrets/lockbox.py
@@ -242,7 +242,7 @@ def _build_secret_name(self, prefix: str, key: str):
         return f"{prefix}{self.sep}{key}"
 
     def _get_secret_value(self, prefix: str, key: str) -> str | None:
-        secret: secret_pb.Secret = None
+        secret: secret_pb.Secret | None = None
         for s in self._get_secrets():
             if s.name == self._build_secret_name(prefix=prefix, key=key):
                 secret = s
diff --git a/airflow/providers/yandex/utils/credentials.py b/airflow/providers/yandex/utils/credentials.py
index f54e8bdfbedfce..015e30ceaae1d2 100644
--- a/airflow/providers/yandex/utils/credentials.py
+++ b/airflow/providers/yandex/utils/credentials.py
@@ -18,16 +18,23 @@
 
 import json
 import logging
-from typing import Any
+from typing import TypedDict
 
 log = logging.getLogger(__name__)
 
 
+class CredentialsType(TypedDict, total=False):
+    """Credentials dict description."""
+
+    token: str
+    service_account_key: dict[str, str]
+
+
 def get_credentials(
     oauth_token: str | None = None,
     service_account_json: dict | str | None = None,
     service_account_json_path: str | None = None,
-) -> dict[str, Any]:
+) -> CredentialsType:
     """
     Return credentials JSON for Yandex Cloud SDK based on credentials.
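Note: `CredentialsType` above is declared with `total=False`, so both keys are optional and callers must test membership (as the reworked `_get_iam_token` does) or fall back to `.get`. A minimal sketch of that contract under the same shape; the `describe` helper is hypothetical, not provider code:

from typing import TypedDict


class CredentialsType(TypedDict, total=False):
    """Same shape as the provider's TypedDict: every key is optional."""

    token: str
    service_account_key: dict[str, str]


def describe(creds: CredentialsType) -> str:
    if "token" in creds:
        return "token auth"  # after the membership test, creds["token"] is safe
    if creds.get("service_account_key") is not None:
        return "service account auth"
    return "anonymous"


print(describe({"token": "t1.example"}))  # token auth
print(describe({}))                       # anonymous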
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 86b9b2e15b8189..be8b8637f2aa6c 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -1340,7 +1340,7 @@
     "deps": [
       "apache-airflow>=2.7.0",
       "yandex-query-client>=0.1.4",
-      "yandexcloud>=0.278.0,!=0.289.0,!=0.290.0,<0.292.0"
+      "yandexcloud>=0.305.0"
     ],
     "devel-deps": [],
     "plugins": [],
diff --git a/tests/system/providers/yandex/example_yandexcloud.py b/tests/system/providers/yandex/example_yandexcloud.py
index 2751458ae984a0..ddebc46a3b50cc 100644
--- a/tests/system/providers/yandex/example_yandexcloud.py
+++ b/tests/system/providers/yandex/example_yandexcloud.py
@@ -27,6 +27,7 @@
 import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb
 import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb
 from google.protobuf.json_format import MessageToDict
+from yandexcloud.operations import OperationError
 
 from airflow import DAG
 from airflow.decorators import task
@@ -61,7 +62,7 @@ def create_cluster_request(
         bucket=YC_S3_BUCKET_NAME,
         config_spec=cluster_service_pb.CreateClusterConfigSpec(
             hadoop=cluster_pb.HadoopConfig(
-                services=("SPARK", "YARN"),
+                services=(cluster_pb.HadoopConfig.Service.SPARK, cluster_pb.HadoopConfig.Service.YARN),
                 ssh_public_keys=[ssh_public_key],
             ),
             subclusters_spec=[
@@ -98,13 +99,13 @@ def create_cluster(
     *,
     dag: DAG | None = None,
     ts_nodash: str | None = None,
-) -> str:
+) -> str | None:
     hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
     folder_id = folder_id or hook.default_folder_id
     if subnet_id is None:
         network_id = network_id or hook.sdk.helpers.find_network_id(folder_id)
         subnet_id = hook.sdk.helpers.find_subnet_id(folder_id=folder_id, zone_id=zone, network_id=network_id)
-    service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id()
+    service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id(folder_id=folder_id)
     ssh_public_key = ssh_public_key or hook.default_public_ssh_key
 
     dag_id = dag and dag.dag_id or "dag"
@@ -126,6 +127,12 @@ def create_cluster(
     operation_result = hook.sdk.wait_operation_and_get_result(
         operation, response_type=cluster_pb.Cluster, meta_type=cluster_service_pb.CreateClusterMetadata
     )
+    if isinstance(operation_result, OperationError):
+        raise ValueError("Cluster creation error")
+
+    if operation_result.response is None:
+        return None
+
     return operation_result.response.id
 
 
@@ -149,7 +156,11 @@ def run_spark_job(
     operation_result = hook.sdk.wait_operation_and_get_result(
         operation, response_type=job_pb.Job, meta_type=job_service_pb.CreateJobMetadata
     )
-    return MessageToDict(operation_result.response)
+
+    if isinstance(operation_result, OperationError):
+        raise ValueError("Run spark task error")
+
+    return MessageToDict(operation_result.response) if operation_result.response is not None else None
 
 
 @task(trigger_rule="all_done")
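Note: the system-test changes above hinge on `wait_operation_and_get_result` handing back either an `OperationError` or a result object whose `.response` may be `None`, so callers must check the type before touching `.response`. A self-contained sketch of that check-then-use pattern; the dataclasses here are stand-ins for the yandexcloud types, not the SDK itself:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class OperationResult:
    """Stand-in for the SDK's result wrapper."""

    response: object | None = None


@dataclass
class OperationError:
    """Stand-in for yandexcloud.operations.OperationError."""

    message: str = "operation failed"


def extract_cluster_id(operation_result: OperationResult | OperationError) -> str | None:
    if isinstance(operation_result, OperationError):
        raise ValueError("Cluster creation error")
    if operation_result.response is None:
        return None  # operation finished, but carried no payload
    return operation_result.response.id


class _Response:
    id = "dataproc-cluster-123"


print(extract_cluster_id(OperationResult(response=_Response())))  # dataproc-cluster-123
print(extract_cluster_id(OperationResult()))                      # None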