providers/yandex: fix typing
got686-yandex committed Jul 24, 2024
1 parent 9ec9eb7 commit 8a5a19c
Showing 9 changed files with 77 additions and 30 deletions.
18 changes: 17 additions & 1 deletion airflow/providers/yandex/hooks/dataproc.py
@@ -16,8 +16,16 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from deprecated import deprecated

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.yandex.hooks.yandex import YandexCloudBaseHook

if TYPE_CHECKING:
from yandexcloud._wrappers.dataproc import Dataproc


class DataprocHook(YandexCloudBaseHook):
"""
@@ -29,7 +37,15 @@ class DataprocHook(YandexCloudBaseHook):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.cluster_id = None
self.client = self.sdk.wrappers.Dataproc(
self.dataproc_client: Dataproc = self.sdk.wrappers.Dataproc(
default_folder_id=self.default_folder_id,
default_public_ssh_key=self.default_public_ssh_key,
)

@property
@deprecated(
reason="`client` is deprecated and will be removed in the future. Use `dataproc_client` instead",
category=AirflowProviderDeprecationWarning,
)
def client(self):
return self.dataproc_client
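
The rename stays backward compatible: `client` survives as a read-only property that forwards to `dataproc_client` while warning. A minimal sketch of the same pattern, assuming the `deprecated` package is installed; the warning class and hook body are stand-ins, not the real Airflow objects:

import warnings

from deprecated import deprecated


class ProviderDeprecationWarning(DeprecationWarning):
    """Stand-in for AirflowProviderDeprecationWarning."""


class Hook:
    def __init__(self) -> None:
        self.dataproc_client = object()  # the real hook builds an SDK wrapper here

    @property
    @deprecated(
        reason="`client` is deprecated. Use `dataproc_client` instead",
        category=ProviderDeprecationWarning,
    )
    def client(self):
        return self.dataproc_client


hook = Hook()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert hook.client is hook.dataproc_client  # old name still works...
assert issubclass(caught[0].category, ProviderDeprecationWarning)  # ...but warns
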
10 changes: 8 additions & 2 deletions airflow/providers/yandex/hooks/yandex.py
@@ -24,6 +24,7 @@
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.hooks.base import BaseHook
from airflow.providers.yandex.utils.credentials import (
CredentialsType,
get_credentials,
get_service_account_id,
)
@@ -132,13 +133,18 @@ def __init__(
self.connection_id = yandex_conn_id or connection_id or default_conn_name
self.connection = self.get_connection(self.connection_id)
self.extras = self.connection.extra_dejson
self.credentials = get_credentials(
self.credentials: CredentialsType = get_credentials(
oauth_token=self._get_field("oauth"),
service_account_json=self._get_field("service_account_json"),
service_account_json_path=self._get_field("service_account_json_path"),
)
sdk_config = self._get_endpoint()
self.sdk = yandexcloud.SDK(user_agent=provider_user_agent(), **sdk_config, **self.credentials)
self.sdk = yandexcloud.SDK(
user_agent=provider_user_agent(),
token=self.credentials.get("token"),
service_account_key=self.credentials.get("service_account_key"),
endpoint=sdk_config["endpoint"],
)
self.default_folder_id = default_folder_id or self._get_field("folder_id")
self.default_public_ssh_key = default_public_ssh_key or self._get_field("public_ssh_key")
self.default_service_account_id = default_service_account_id or get_service_account_id(
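
Building the SDK with explicit keywords, rather than `**sdk_config, **self.credentials`, gives the type checker one concrete value per parameter. A toy illustration with a made-up `make_sdk` standing in for `yandexcloud.SDK` (the real signature may differ):

from typing import Any, TypedDict


class CredentialsType(TypedDict, total=False):
    token: str
    service_account_key: dict[str, str]


def make_sdk(
    user_agent: str | None = None,
    token: str | None = None,
    service_account_key: dict[str, str] | None = None,
    endpoint: str | None = None,
) -> dict[str, Any]:
    return {"token": token, "endpoint": endpoint}


creds: CredentialsType = {"token": "hypothetical-oauth-token"}
sdk_config: dict[str, Any] = {"endpoint": "api.cloud.yandex.net"}

# make_sdk(**sdk_config, **creds) type-checks poorly: dict[str, Any] says
# nothing about which keys exist. Explicit keywords are checked one by one:
sdk = make_sdk(
    user_agent="apache-airflow/yandex-provider",
    token=creds.get("token"),
    service_account_key=creds.get("service_account_key"),
    endpoint=sdk_config["endpoint"],
)
print(sdk["endpoint"])
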
5 changes: 2 additions & 3 deletions airflow/providers/yandex/hooks/yq.py
@@ -100,8 +100,7 @@ def compose_query_web_link(self, query_id: str):
return self.client.compose_query_web_link(query_id)

def _get_iam_token(self) -> str:
iam_token = self.credentials.get("token")
if iam_token is not None:
return iam_token
if "token" in self.credentials:
return self.credentials["token"]

return yc_auth.get_auth_token(service_account_key=self.credentials.get("service_account_key"))
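
The rewrite trades `.get()`, which is typed `str | None`, for an `in` membership test that mypy narrows: inside the branch, `self.credentials["token"]` is a plain `str`. A compact sketch of that narrowing:

from typing import TypedDict


class Creds(TypedDict, total=False):
    token: str


def iam_token(creds: Creds) -> str:
    if "token" in creds:
        return creds["token"]  # narrowed to str; no Optional leaks out
    return "token-from-service-account-key"  # placeholder fallback


print(iam_token({"token": "abc"}))  # -> abc
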
33 changes: 23 additions & 10 deletions airflow/providers/yandex/operators/dataproc.py
@@ -33,7 +33,7 @@ class InitializationAction:
"""Data for initialization action to be run at start of DataProc cluster."""

uri: str # Uri of the executable file
args: Sequence[str] # Arguments to the initialization action
args: Iterable[str] # Arguments to the initialization action
timeout: int # Execution timeout
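
Widening `args` from `Sequence[str]` to `Iterable[str]` accepts anything the wrapper only needs to iterate over, generators included. A quick illustration of the difference:

from collections.abc import Iterable


def render_args(args: Iterable[str]) -> str:
    # Iterable is the weaker contract: every Sequence is Iterable, but a
    # generator is Iterable without being a Sequence (no len() or indexing).
    return " ".join(args)


print(render_args(["--input", "s3://bucket/in"]))    # list: fine either way
print(render_args(f"--arg{i}" for i in range(2)))    # generator: Iterable only
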


@@ -143,6 +143,18 @@ def __init__(
**kwargs,
) -> None:
super().__init__(**kwargs)
if ssh_public_keys is None:
ssh_public_keys = []

if services is None:
services = []

if host_group_ids is None:
host_group_ids = []

if security_group_ids is None:
security_group_ids = []

self.folder_id = folder_id
self.yandex_conn_id = connection_id
self.cluster_name = cluster_name
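
Normalizing the `None` defaults to empty lists up front means the rest of the operator, and the SDK call it makes, sees concrete lists instead of `list | None`. It also sidesteps the classic mutable-default trap, sketched here with a hypothetical parameter:

def normalize(security_group_ids: list[str] | None = None) -> list[str]:
    # A literal [] default would be shared across every call, and a bare
    # None would leak into list-typed fields; the sentinel avoids both.
    if security_group_ids is None:
        security_group_ids = []  # fresh list per call, concrete list type
    return security_group_ids


print(normalize())                      # -> []
print(normalize(["enp-hypothetical"]))  # -> ['enp-hypothetical']
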
@@ -186,7 +198,7 @@ def execute(self, context: Context) -> dict:
self.hook = DataprocHook(
yandex_conn_id=self.yandex_conn_id,
)
operation_result = self.hook.client.create_cluster(
operation_result = self.hook.dataproc_client.create_cluster(
folder_id=self.folder_id,
cluster_name=self.cluster_name,
cluster_description=self.cluster_description,
@@ -221,15 +233,16 @@ def execute(self, context: Context) -> dict:
security_group_ids=self.security_group_ids,
log_group_id=self.log_group_id,
labels=self.labels,
initialization_actions=self.initialization_actions
and [
initialization_actions=[
self.hook.sdk.wrappers.InitializationAction(
uri=init_action.uri,
args=init_action.args,
timeout=init_action.timeout,
)
for init_action in self.initialization_actions
],
]
if self.initialization_actions
else [],
)
cluster_id = operation_result.response.id
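
The old `self.initialization_actions and [...]` expression evaluates to `None` whenever the attribute is `None`, so its inferred type is `list[...] | None` for a parameter expecting a list; the conditional form always yields a list. Side by side:

def old_style(xs: list[int] | None) -> list[str] | None:
    return xs and [str(x) for x in xs]  # None propagates through `and`


def new_style(xs: list[int] | None) -> list[str]:
    return [str(x) for x in xs] if xs else []  # always a list


print(old_style(None), new_style(None))  # -> None []
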

@@ -290,7 +303,7 @@ def __init__(self, *, connection_id: str | None = None, cluster_id: str | None =

def execute(self, context: Context) -> None:
hook = self._setup(context)
hook.client.delete_cluster(self.cluster_id)
hook.dataproc_client.delete_cluster(self.cluster_id)


class DataprocCreateHiveJobOperator(DataprocBaseOperator):
@@ -331,7 +344,7 @@ def __init__(

def execute(self, context: Context) -> None:
hook = self._setup(context)
hook.client.create_hive_job(
hook.dataproc_client.create_hive_job(
query=self.query,
query_file_uri=self.query_file_uri,
script_variables=self.script_variables,
@@ -387,7 +400,7 @@ def __init__(

def execute(self, context: Context) -> None:
hook = self._setup(context)
hook.client.create_mapreduce_job(
hook.dataproc_client.create_mapreduce_job(
main_class=self.main_class,
main_jar_file_uri=self.main_jar_file_uri,
jar_file_uris=self.jar_file_uris,
@@ -455,7 +468,7 @@ def __init__(

def execute(self, context: Context) -> None:
hook = self._setup(context)
hook.client.create_spark_job(
hook.dataproc_client.create_spark_job(
main_class=self.main_class,
main_jar_file_uri=self.main_jar_file_uri,
jar_file_uris=self.jar_file_uris,
@@ -526,7 +539,7 @@ def __init__(

def execute(self, context: Context) -> None:
hook = self._setup(context)
hook.client.create_pyspark_job(
hook.dataproc_client.create_pyspark_job(
main_python_file_uri=self.main_python_file_uri,
python_file_uris=self.python_file_uris,
jar_file_uris=self.jar_file_uris,
7 changes: 1 addition & 6 deletions airflow/providers/yandex/provider.yaml
@@ -52,12 +52,7 @@ versions:

dependencies:
- apache-airflow>=2.7.0
# The 0.289 and 0.290 versions have broken dataproc support
# See https://github.com/yandex-cloud/python-sdk/issues/103
# the 0.291.0 version of yandex provider introduced mypy typing
# that conflicts with the way yandex provider uses it and should be fixed
# See https://github.com/yandex-cloud/python-sdk/issues/106
- yandexcloud>=0.278.0,!=0.289.0,!=0.290.0,<0.292.0
- yandexcloud>=0.305.0
- yandex-query-client>=0.1.4

integrations:
2 changes: 1 addition & 1 deletion airflow/providers/yandex/secrets/lockbox.py
@@ -242,7 +242,7 @@ def _build_secret_name(self, prefix: str, key: str):
return f"{prefix}{self.sep}{key}"

def _get_secret_value(self, prefix: str, key: str) -> str | None:
secret: secret_pb.Secret = None
secret: secret_pb.Secret | None = None
for s in self._get_secrets():
if s.name == self._build_secret_name(prefix=prefix, key=key):
secret = s
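
The lockbox fix corrects a plainly wrong annotation: a find-first-match variable starts as `None`, so it must be declared `secret_pb.Secret | None`. A stand-in version of the lookup (a plain dataclass replaces `secret_pb.Secret`):

from dataclasses import dataclass


@dataclass
class Secret:
    name: str


def find_secret(secrets: list[Secret], wanted: str) -> Secret | None:
    secret: Secret | None = None  # starts as None, so the annotation must say so
    for s in secrets:
        if s.name == wanted:
            secret = s
            break
    return secret


print(find_secret([Secret("airflow/connections/yc")], "airflow/connections/yc"))
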
11 changes: 9 additions & 2 deletions airflow/providers/yandex/utils/credentials.py
@@ -18,16 +18,23 @@

import json
import logging
from typing import Any
from typing import TypedDict

log = logging.getLogger(__name__)


class CredentialsType(TypedDict, total=False):
"""Credentials dict description."""

token: str
service_account_key: dict[str, str]


def get_credentials(
oauth_token: str | None = None,
service_account_json: dict | str | None = None,
service_account_json_path: str | None = None,
) -> dict[str, Any]:
) -> CredentialsType:
"""
Return credentials JSON for Yandex Cloud SDK based on credentials.
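
Returning a `total=False` TypedDict instead of `dict[str, Any]` lets mypy check every construction site against the two known keys, catching typos that an untyped dict would let through. A sketch of the payoff, with a simplified stand-in for `get_credentials`:

from typing import TypedDict


class CredentialsType(TypedDict, total=False):
    token: str
    service_account_key: dict[str, str]


def get_credentials_sketch(oauth_token: str | None = None) -> CredentialsType:
    if oauth_token is not None:
        return {"token": oauth_token}
        # return {"tokn": oauth_token}  # mypy: extra key "tokn" for CredentialsType
    return {}


print(get_credentials_sketch("hypothetical-token"))
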
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -1340,7 +1340,7 @@
"deps": [
"apache-airflow>=2.7.0",
"yandex-query-client>=0.1.4",
"yandexcloud>=0.278.0,!=0.289.0,!=0.290.0,<0.292.0"
"yandexcloud>=0.305.0"
],
"devel-deps": [],
"plugins": [],
19 changes: 15 additions & 4 deletions tests/system/providers/yandex/example_yandexcloud.py
@@ -27,6 +27,7 @@
import yandex.cloud.dataproc.v1.job_service_pb2_grpc as job_service_grpc_pb
import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb
from google.protobuf.json_format import MessageToDict
from yandexcloud.operations import OperationError

from airflow import DAG
from airflow.decorators import task
@@ -61,7 +62,7 @@ def create_cluster_request(
bucket=YC_S3_BUCKET_NAME,
config_spec=cluster_service_pb.CreateClusterConfigSpec(
hadoop=cluster_pb.HadoopConfig(
services=("SPARK", "YARN"),
services=(cluster_pb.HadoopConfig.Service.SPARK, cluster_pb.HadoopConfig.Service.YARN),
ssh_public_keys=[ssh_public_key],
),
subclusters_spec=[
@@ -98,13 +99,13 @@ def create_cluster(
*,
dag: DAG | None = None,
ts_nodash: str | None = None,
) -> str:
) -> str | None:
hook = YandexCloudBaseHook(yandex_conn_id=yandex_conn_id)
folder_id = folder_id or hook.default_folder_id
if subnet_id is None:
network_id = network_id or hook.sdk.helpers.find_network_id(folder_id)
subnet_id = hook.sdk.helpers.find_subnet_id(folder_id=folder_id, zone_id=zone, network_id=network_id)
service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id()
service_account_id = service_account_id or hook.sdk.helpers.find_service_account_id(folder_id=folder_id)
ssh_public_key = ssh_public_key or hook.default_public_ssh_key

dag_id = dag and dag.dag_id or "dag"
@@ -126,6 +127,12 @@ def create_cluster(
operation_result = hook.sdk.wait_operation_and_get_result(
operation, response_type=cluster_pb.Cluster, meta_type=cluster_service_pb.CreateClusterMetadata
)
if isinstance(operation_result, OperationError):
raise ValueError("Cluster creation error")

if operation_result.response is None:
return None

return operation_result.response.id


@@ -149,7 +156,11 @@ def run_spark_job(
operation_result = hook.sdk.wait_operation_and_get_result(
operation, response_type=job_pb.Job, meta_type=job_service_pb.CreateJobMetadata
)
return MessageToDict(operation_result.response)

if isinstance(operation_result, OperationError):
raise ValueError("Run spark task error")

return MessageToDict(operation_result.response) if operation_result.response is not None else None


@task(trigger_rule="all_done")
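
The checks added above imply that `wait_operation_and_get_result` can hand back an `OperationError` object rather than raising, and that `.response` may be absent, so the example DAG now branches on both. A stand-in version of that control flow (the result types below are simplified stand-ins, not the SDK's):

from dataclasses import dataclass


class OperationError(Exception):
    """Stand-in for yandexcloud.operations.OperationError."""


@dataclass
class OperationResult:
    response: object | None = None


def cluster_id_from(result: OperationResult | OperationError) -> str | None:
    if isinstance(result, OperationError):
        raise ValueError("Cluster creation error") from result
    if result.response is None:
        return None
    return "c9q-hypothetical-id"  # the real code reads result.response.id


print(cluster_id_from(OperationResult(response=object())))  # -> c9q-hypothetical-id
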
