Merge branch 'main' into Andy/Unittest_Coverage_Issue
andyjianzhou authored Jul 10, 2024
2 parents 4eb1546 + aca140a commit 0a3c7a1
Showing 23 changed files with 270 additions and 92 deletions.
20 changes: 18 additions & 2 deletions airflow/executors/base_executor.py
@@ -23,6 +23,7 @@
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple

import pendulum
@@ -32,6 +33,7 @@
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.task_context_logger import TaskContextLogger
from airflow.utils.state import TaskInstanceState

PARALLELISM: int = conf.getint("core", "PARALLELISM")
@@ -284,8 +286,12 @@ def trigger_tasks(self, open_slots: int) -> None:
self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
continue
# Otherwise, we give up and remove the task from the queue.
self.log.error(
"could not queue task %s (still running after %d attempts)", key, attempt.total_tries
self.send_message_to_task_logs(
logging.ERROR,
"could not queue task %s (still running after %d attempts).",
key,
attempt.total_tries,
ti=ti,
)
del self.attempts[key]
del self.queued_tasks[key]
@@ -512,6 +518,16 @@ def send_callback(self, request: CallbackRequest) -> None:
raise ValueError("Callback sink is not ready.")
self.callback_sink.send(request)

@cached_property
def _task_context_logger(self) -> TaskContextLogger:
return TaskContextLogger(
component_name="Executor",
call_site_logger=self.log,
)

def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
self._task_context_logger._log(level, msg, *args, ti=ti)

@staticmethod
def get_cli_commands() -> list[GroupCommand]:
"""
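The new ``send_message_to_task_logs`` helper lets an executor surface a message in the affected task instance's own log file in addition to the executor log. A minimal sketch of how a custom executor might call it; the subclass and helper method below are hypothetical and not part of this commit:

import logging

from airflow.executors.base_executor import BaseExecutor


class MyExecutor(BaseExecutor):  # hypothetical executor subclass, for illustration only
    def _report_stuck_task(self, ti) -> None:
        # Routed through TaskContextLogger: emitted on the executor log via
        # call_site_logger and also appended to the task instance's log file.
        self.send_message_to_task_logs(
            logging.WARNING,
            "task %s appears stuck; it will be retried on the next heartbeat",
            ti.key,
            ti=ti,
        )
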
10 changes: 10 additions & 0 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -81,6 +81,16 @@ class AwsAuthManager(BaseAuthManager):
"""

def __init__(self, appbuilder: AirflowAppBuilder) -> None:
from packaging.version import Version

from airflow.version import version

# TODO: remove this if block when min_airflow_version is set to higher than 2.9.0
if Version(version) < Version("2.9"):
raise AirflowOptionalProviderFeatureException(
"``AwsAuthManager`` is compatible with Airflow versions >= 2.9."
)

super().__init__(appbuilder)
self._check_avp_schema_version()

37 changes: 22 additions & 15 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -23,6 +23,7 @@

from __future__ import annotations

import logging
import time
from collections import defaultdict, deque
from copy import deepcopy
@@ -347,7 +348,7 @@ def attempt_task_runs(self):
queue = ecs_task.queue
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
_failure_reasons = []
failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
self.pending_tasks.append(ecs_task)
continue
@@ -361,23 +362,21 @@
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_tasks.append(ecs_task)
raise
_failure_reasons.append(str(e))
failure_reasons.append(str(e))
except Exception as e:
# Failed to even get a response back from the Boto3 API or something else went
# wrong. For any possible failure we want to add the exception reasons to the
# failure list so that it is logged to the user and most importantly the task is
# added back to the pending list to be retried later.
_failure_reasons.append(str(e))
failure_reasons.append(str(e))
else:
# We got a response back, check if there were failures. If so, add them to the
# failures list so that it is logged to the user and most importantly the task
# is added back to the pending list to be retried later.
if run_task_response["failures"]:
_failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])

if _failure_reasons:
for reason in _failure_reasons:
failure_reasons[reason] += 1
if failure_reasons:
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
ecs_task.attempt_number += 1
@@ -386,14 +385,19 @@
)
self.pending_tasks.append(ecs_task)
else:
self.log.error(
"ECS task %s has failed a maximum of %s times. Marking as failed",
self.send_message_to_task_logs(
logging.ERROR,
"ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
task_key,
attempt_number,
", ".join(failure_reasons),
ti=task_key,
)
self.fail(task_key)
elif not run_task_response["tasks"]:
self.log.error("ECS RunTask Response: %s", run_task_response)
self.send_message_to_task_logs(
logging.ERROR, "ECS RunTask Response: %s", run_task_response, ti=task_key
)
raise EcsExecutorException(
"No failures and no ECS tasks provided in response. This should never happen."
)
@@ -407,11 +411,6 @@
# executor feature).
# TODO: remove when min airflow version >= 2.9.2
pass
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
dict(failure_reasons),
)

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
@@ -543,3 +542,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis

def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
# TODO: remove this method when min_airflow_version is set to higher than 2.10.0
try:
super().send_message_to_task_logs(level, msg, *args, ti=ti)
except AttributeError:
# ``send_message_to_task_logs`` is added in 2.10.0
self.log.error(msg, *args)
16 changes: 11 additions & 5 deletions airflow/providers/amazon/aws/sensors/s3.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import fnmatch
import inspect
import os
import re
from datetime import datetime, timedelta
@@ -57,13 +58,13 @@ class S3KeySensor(BaseSensorOperator):
refers to this bucket
:param wildcard_match: whether the bucket_key should be interpreted as a
Unix wildcard pattern
:param check_fn: Function that receives the list of the S3 objects,
:param check_fn: Function that receives the list of the S3 objects with the context values,
and returns a boolean:
- ``True``: the criteria is met
- ``False``: the criteria isn't met
**Example**: Wait for any S3 object size more than 1 megabyte ::
def check_fn(files: List) -> bool:
def check_fn(files: List, **kwargs) -> bool:
return any(f.get('Size', 0) > 1048576 for f in files)
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether to verify SSL certificates for S3 connection.
@@ -112,7 +113,7 @@ def __init__(
self.use_regex = use_regex
self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

def _check_key(self, key):
def _check_key(self, key, context: Context):
bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

@@ -167,15 +168,20 @@ def _check_key(self, key):
files = [metadata]

if self.check_fn is not None:
# For backwards compatibility, check if the function takes a context argument
signature = inspect.signature(self.check_fn)
if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
return self.check_fn(files, **context)
# Otherwise, just pass the files
return self.check_fn(files)

return True

def poke(self, context: Context):
if isinstance(self.bucket_key, str):
return self._check_key(self.bucket_key)
return self._check_key(self.bucket_key, context=context)
else:
return all(self._check_key(key) for key in self.bucket_key)
return all(self._check_key(key, context=context) for key in self.bucket_key)

def execute(self, context: Context) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
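With this change a ``check_fn`` that accepts ``**kwargs`` receives the task context alongside the list of matched S3 objects, while plain single-argument callables keep working unchanged. A minimal sketch of a context-aware check function; the ``min_size`` DAG param and bucket key are hypothetical:

from typing import Any


def size_check_fn(files: list, **kwargs: Any) -> bool:
    # "params" is taken from the task context forwarded by the sensor;
    # fall back to 1 MiB when the hypothetical "min_size" param is unset.
    min_size = kwargs.get("params", {}).get("min_size", 1_048_576)
    return any(f.get("Size", 0) > min_size for f in files)


# Hypothetical usage in a DAG:
# S3KeySensor(
#     task_id="wait_for_large_object",
#     bucket_key="s3://my-bucket/incoming/*.csv",
#     wildcard_match=True,
#     check_fn=size_check_fn,
# )
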
19 changes: 13 additions & 6 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -2955,6 +2955,7 @@ def execute(self, context: Any):

try:
self.log.info("Executing: %s'", self.configuration)
# Create a job
job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
except Conflict:
# If the job already exists retrieve it
@@ -2963,18 +2964,24 @@
location=self.location,
job_id=self.job_id,
)
if job.state in self.reattach_states:
# We are reattaching to a job
job._begin()
self._handle_job_error(job)
else:
# Same job configuration so we need force_rerun

if job.state not in self.reattach_states:
# Same job configuration, so we need force_rerun
raise AirflowException(
f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)

else:
# Job already reached state DONE
if job.state == "DONE":
raise AirflowException("Job is already in state DONE. Can not reattach to this job.")

# We are reattaching to a job
self.log.info("Reattaching to existing Job in state %s", job.state)
self._handle_job_error(job)

job_types = {
LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],
CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"],
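With the branch reordered as above, a Conflict on an existing job only fails the task when the job's state is not listed in ``reattach_states`` or the job is already DONE; otherwise the operator reattaches and waits on the running job. A usage sketch of opting into reattachment; the job id and query are illustrative only:

from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

# Hypothetical operator instance: reuse a deterministic job_id and, if a previous
# attempt left that job PENDING or RUNNING, reattach to it instead of raising.
run_query = BigQueryInsertJobOperator(
    task_id="run_query",
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    job_id="nightly_aggregation_job",  # illustrative fixed job id
    reattach_states={"PENDING", "RUNNING"},
    force_rerun=False,
)
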
7 changes: 1 addition & 6 deletions airflow/providers/openlineage/plugins/adapter.py
@@ -434,12 +434,7 @@ def _build_run(
namespace=conf.namespace(),
name=parent_job_name or job_name,
)
facets.update(
{
"parent": parent_run_facet,
"parentRun": parent_run_facet, # Keep sending this for the backward compatibility
}
)
facets.update({"parent": parent_run_facet})

if run_facets:
facets.update(run_facets)
1 change: 1 addition & 0 deletions airflow/providers/openlineage/utils/utils.py
@@ -231,6 +231,7 @@ class TaskInfo(InfoJsonEncodable):
"_is_teardown": "is_teardown",
}
includes = [
"deferrable",
"depends_on_past",
"downstream_task_ids",
"execution_timeout",
4 changes: 3 additions & 1 deletion airflow/providers/pgvector/provider.yaml
@@ -43,7 +43,9 @@ integrations:
dependencies:
- apache-airflow>=2.7.0
- apache-airflow-providers-postgres>=5.7.1
- pgvector>=0.2.3
# setting !=0.3.0 version due to https://github.com/pgvector/pgvector-python/issues/79
# observed in 0.3.0.
- pgvector!=0.3.0

hooks:
- integration-name: pgvector
6 changes: 6 additions & 0 deletions airflow/providers/slack/notifications/slack.py
@@ -64,6 +64,8 @@ def __init__(
proxy: str | None = None,
timeout: int | None = None,
retry_handlers: list[RetryHandler] | None = None,
unfurl_links: bool = True,
unfurl_media: bool = True,
):
super().__init__()
self.slack_conn_id = slack_conn_id
@@ -77,6 +79,8 @@ def __init__(
self.timeout = timeout
self.proxy = proxy
self.retry_handlers = retry_handlers
self.unfurl_links = unfurl_links
self.unfurl_media = unfurl_media

@cached_property
def hook(self) -> SlackHook:
@@ -98,6 +102,8 @@ def notify(self, context):
"icon_url": self.icon_url,
"attachments": json.dumps(self.attachments),
"blocks": json.dumps(self.blocks),
"unfurl_links": self.unfurl_links,
"unfurl_media": self.unfurl_media,
}
self.hook.call("chat.postMessage", json=api_call_params)

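Both new flags default to True, so existing notifications are unchanged; setting them to False suppresses Slack's link and media previews on the posted message. A minimal usage sketch; the connection id, channel, and message text are illustrative:

from airflow.providers.slack.notifications.slack import send_slack_notification

# Hypothetical on-failure callback that posts to Slack without unfurling previews.
notify_on_failure = send_slack_notification(
    slack_conn_id="slack_api_default",
    channel="#airflow-alerts",
    text="Task {{ ti.task_id }} in DAG {{ dag.dag_id }} failed.",
    unfurl_links=False,
    unfurl_media=False,
)
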
9 changes: 4 additions & 5 deletions airflow/utils/log/file_task_handler.py
@@ -47,7 +47,8 @@
from pendulum import DateTime

from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

@@ -265,7 +266,7 @@ def close(self):
@internal_api_call
@provide_session
def _render_filename_db_access(
*, ti, try_number: int, session=None
*, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
ti = _ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)
@@ -281,9 +282,7 @@ def _render_filename_db_access(
filename = render_template_to_string(jinja_tpl, context)
return dag_run, ti, str_tpl, filename

def _render_filename(
self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int
) -> str:
def _render_filename(self, ti: TaskInstance | TaskInstancePydantic, try_number: int) -> str:
"""Return the worker log filename."""
dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number)
if filename: