Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into pr_teradata_release…
Browse files Browse the repository at this point in the history
…_1.0.0
  • Loading branch information
satish-chinthanippu committed Feb 9, 2024
2 parents e859a1d + 00ed467 commit f56bede
Show file tree
Hide file tree
Showing 33 changed files with 436 additions and 139 deletions.
5 changes: 0 additions & 5 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
).scalar_one()
dag_run.state = state
if state == DagRunState.RUNNING:
dag_run.start_date = timezone.utcnow()
dag_run.end_date = None
else:
dag_run.end_date = timezone.utcnow()
session.merge(dag_run)


Expand Down
4 changes: 2 additions & 2 deletions airflow/auth/managers/utils/fab.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@


def get_fab_action_from_method_map():
"""Returns the map associating a method to a FAB action."""
"""Return the map associating a method to a FAB action."""
return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME


def get_method_from_fab_action_map():
"""Returns the map associating a FAB action to a method."""
"""Return the map associating a FAB action to a method."""
return {
**{v: k for k, v in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()},
ACTION_CAN_ACCESS_MENU: "GET",
Expand Down
2 changes: 1 addition & 1 deletion airflow/decorators/bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bash_task(
python_callable: Callable | None = None,
**kwargs,
) -> TaskDecorator:
"""Wraps a function into a BashOperator.
"""Wrap a function into a BashOperator.
Accepts kwargs for operator kwargs. Can be reused in a single DAG. This function is only used
during type checking or auto-completion.
Expand Down
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self):
self.fail_fast = conf.getboolean("debug", "fail_fast")

def execute_async(self, *args, **kwargs) -> None:
"""The method is replaced by custom trigger_task implementation."""
"""Replace the method with a custom trigger_task implementation."""

def sync(self) -> None:
task_succeeded = True
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,7 @@ def defer(
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
"""This method is called when a deferred task is resumed."""
"""Call this method when a deferred task is resumed."""
# __fail__ is a special signal value for next_method that indicates
# this task was scheduled specifically to fail.
if next_method == "__fail__":
Expand Down
70 changes: 67 additions & 3 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,75 @@ def get_state(self):
return self._state

def set_state(self, state: DagRunState) -> None:
"""Change the state of the DagRun.
Changes to attributes are implemented in accordance with the following table
(rows represent old states, columns represent new states):
.. list-table:: State transition matrix
:header-rows: 1
:stub-columns: 1
* -
- QUEUED
- RUNNING
- SUCCESS
- FAILED
* - None
- queued_at = timezone.utcnow()
- if empty: start_date = timezone.utcnow()
end_date = None
- end_date = timezone.utcnow()
- end_date = timezone.utcnow()
* - QUEUED
- queued_at = timezone.utcnow()
- if empty: start_date = timezone.utcnow()
end_date = None
- end_date = timezone.utcnow()
- end_date = timezone.utcnow()
* - RUNNING
- queued_at = timezone.utcnow()
start_date = None
end_date = None
-
- end_date = timezone.utcnow()
- end_date = timezone.utcnow()
* - SUCCESS
- queued_at = timezone.utcnow()
start_date = None
end_date = None
- start_date = timezone.utcnow()
end_date = None
-
-
* - FAILED
- queued_at = timezone.utcnow()
start_date = None
end_date = None
- start_date = timezone.utcnow()
end_date = None
-
-
"""
if state not in State.dag_states:
raise ValueError(f"invalid DagRun state: {state}")
if self._state != state:
if state == DagRunState.QUEUED:
self.queued_at = timezone.utcnow()
self.start_date = None
self.end_date = None
if state == DagRunState.RUNNING:
if self._state in State.finished_dr_states:
self.start_date = timezone.utcnow()
else:
self.start_date = self.start_date or timezone.utcnow()
self.end_date = None
if self._state in State.unfinished_dr_states or self._state is None:
if state in State.finished_dr_states:
self.end_date = timezone.utcnow()
self._state = state
self.end_date = timezone.utcnow() if self._state in State.finished_dr_states else None
else:
if state == DagRunState.QUEUED:
self.queued_at = timezone.utcnow()

Expand Down Expand Up @@ -504,7 +568,7 @@ def get_task_instances(
session: Session = NEW_SESSION,
) -> list[TI]:
"""
Returns the task instances for this dag run.
Return the task instances for this dag run.
Redirect to DagRun.fetch_task_instances method.
Keep this method because it is widely used across the code.
Expand Down Expand Up @@ -547,7 +611,7 @@ def fetch_task_instance(
map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
"""
Returns the task instance specified by task_id for this dag run.
Return the task instance specified by task_id for this dag run.
:param dag_id: the DAG id
:param dag_run_id: the DAG run id
Expand Down
12 changes: 6 additions & 6 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def _refresh_from_db(
*, task_instance: TaskInstance | TaskInstancePydantic, session: Session, lock_for_update: bool = False
) -> None:
"""
Refreshes the task instance from the database based on the primary key.
Refresh the task instance from the database based on the primary key.
:param task_instance: the task instance
:param session: SQLAlchemy ORM Session
Expand Down Expand Up @@ -531,7 +531,7 @@ def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None

def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
"""
Returns task instance tags.
Return task instance tags.
:param task_instance: the task instance
Expand Down Expand Up @@ -943,7 +943,7 @@ def _get_previous_dagrun(
session: Session | None = None,
) -> DagRun | None:
"""
The DagRun that ran before this task instance's DagRun.
Return the DagRun that ran prior to this task instance's DagRun.
:param task_instance: the task instance
:param state: If passed, it only takes into account instances of a specific state.
Expand Down Expand Up @@ -983,7 +983,7 @@ def _get_previous_execution_date(
session: Session,
) -> pendulum.DateTime | None:
"""
The execution date from property previous_ti_success.
Get execution date from property previous_ti_success.
:param task_instance: the task instance
:param session: SQLAlchemy ORM Session
Expand Down Expand Up @@ -1178,7 +1178,7 @@ def _get_previous_ti(
state: DagRunState | None = None,
) -> TaskInstance | TaskInstancePydantic | None:
"""
The task instance for the task that ran before this task instance.
Get task instance for the task that ran before this task instance.
:param task_instance: the task instance
:param state: If passed, it only takes into account instances of a specific state.
Expand Down Expand Up @@ -1436,7 +1436,7 @@ def try_number(self):
@try_number.expression
def try_number(cls):
"""
This is what will be used by SQLAlchemy when filtering on try_number.
Return the expression to be used by SQLAlchemy when filtering on try_number.
This is required because the override in the get_try_number function causes
try_number values to be off by one when listing tasks in the UI.
Expand Down
4 changes: 2 additions & 2 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def _prepare_venv(self, venv_path: Path) -> None:
)

def _calculate_cache_hash(self) -> tuple[str, str]:
"""Helper to generate the hash of the cache folder to use.
"""Generate the hash of the cache folder to use.
The following factors are used as input for the hash:
- (sorted) list of requirements
Expand All @@ -666,7 +666,7 @@ def _calculate_cache_hash(self) -> tuple[str, str]:
return requirements_hash[:8], hash_text

def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path:
"""Helper to ensure a valid virtual environment is set up and will create inplace."""
"""Ensure a valid virtual environment is set up and will create inplace."""
cache_hash, hash_data = self._calculate_cache_hash()
venv_path = venv_cache_path / f"venv-{cache_hash}"
self.log.info("Python virtual environment will be cached in %s", venv_path)
Expand Down
6 changes: 3 additions & 3 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,14 @@ class AirflowPlugin:

@classmethod
def validate(cls):
"""Validates that plugin has a name."""
"""Validate that the plugin has a name."""
if not cls.name:
raise AirflowPluginException("Your plugin needs a name.")

@classmethod
def on_load(cls, *args, **kwargs):
"""
Executed when the plugin is loaded; This method is only called once during runtime.
Execute when the plugin is loaded; this method is only called once during runtime.
:param args: If future arguments are passed in on call.
:param kwargs: If future arguments are passed in on call.
Expand Down Expand Up @@ -296,7 +296,7 @@ def load_providers_plugins():


def make_module(name: str, objects: list[Any]):
"""Creates new module."""
"""Create new module."""
if not objects:
return None
log.debug("Creating module %s", name)
Expand Down
22 changes: 14 additions & 8 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State

if TYPE_CHECKING:
Expand Down Expand Up @@ -408,8 +409,8 @@ def _run_task(
The command and executor config will be placed in the container-override
section of the JSON request before calling Boto3's "run_task" function.
"""
run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
boto_run_task = self.ecs.run_task(**run_task_api)
run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config)
boto_run_task = self.ecs.run_task(**run_task_kwargs)
run_task_response = BotoRunTaskSchema().load(boto_run_task)
return run_task_response

Expand All @@ -421,17 +422,17 @@ def _run_task_kwargs(
One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
"""
run_task_api = deepcopy(self.run_task_kwargs)
container_override = self.get_container(run_task_api["overrides"]["containerOverrides"])
run_task_kwargs = deepcopy(self.run_task_kwargs)
run_task_kwargs = merge_dicts(run_task_kwargs, exec_config)
container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"])
container_override["command"] = cmd
container_override.update(exec_config)

# Inject the env variable to configure logging for containerized execution environment
if "environment" not in container_override:
container_override["environment"] = []
container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"})

return run_task_api
return run_task_kwargs

def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
"""Save the task to be executed in the next sync by inserting the commands into a queue."""
Expand Down Expand Up @@ -484,6 +485,11 @@ def _load_run_kwargs(self) -> dict:
def get_container(self, container_list):
"""Search task list for core Airflow container."""
for container in container_list:
if container["name"] == self.container_name:
return container
try:
if container["name"] == self.container_name:
return container
except KeyError:
raise EcsExecutorException(
'container "name" must be provided in "containerOverrides" configuration'
)
raise KeyError(f"No such container found by container name: {self.container_name}")
2 changes: 1 addition & 1 deletion airflow/providers/fab/auth_manager/decorators/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def decorated(*args, **kwargs):

def _has_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]:
"""
Factory for decorator that checks current user's permissions against required permissions.
Check current user's permissions against required permissions.
This decorator is only kept for backward compatible reasons. The decorator
``airflow.www.auth.has_access``, which redirects to this decorator, is widely used in user plugins.
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _get_fab_resource_types(dag_access_entity: DagAccessEntity) -> tuple[str, ..

def _resource_name_for_dag(self, dag_id: str) -> str:
"""
Returns the FAB resource name for a DAG id.
Return the FAB resource name for a DAG id.
:param dag_id: the DAG id
Expand Down
Loading

0 comments on commit f56bede

Please sign in to comment.