From 0f8dfebdd6cd30e604d2180998f976d0c7b62277 Mon Sep 17 00:00:00 2001 From: Sudipto Baral Date: Fri, 9 Feb 2024 06:54:23 -0500 Subject: [PATCH 1/8] fix: update hyperlink to the new documentation section for local virtualenv setup. (#37272) Signed-off-by: sudipto baral --- contributing-docs/06_development_environments.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/contributing-docs/06_development_environments.rst b/contributing-docs/06_development_environments.rst index e442ed735a1f12..c467d6932855c1 100644 --- a/contributing-docs/06_development_environments.rst +++ b/contributing-docs/06_development_environments.rst @@ -32,7 +32,7 @@ in `07_local_virtualenv.rst <07_local_virtualenv.rst>`__. Benefits: - Packages are installed locally. No container environment is required. -- You can benefit from local debugging within your IDE. You can follow the `Contributors quick start `__ +- You can benefit from local debugging within your IDE. You can follow the `Local and remote debugging in IDE <07_local_virtualenv.rst#local-and-remote-debugging-in-ide>`__ to set up your local virtualenv and connect your IDE with the environment. - With the virtualenv in your IDE, you can benefit from auto completion and running tests directly from the IDE. From 9f4f208b5da38bc2e82db682c636ec4fcf7ad617 Mon Sep 17 00:00:00 2001 From: Aleksey Kirilishin <54231417+avkirilishin@users.noreply.github.com> Date: Fri, 9 Feb 2024 15:53:04 +0300 Subject: [PATCH 2/8] Fix the bug that affected the DAG end date. (#36144) --- airflow/api/common/mark_tasks.py | 5 -- airflow/models/dagrun.py | 66 ++++++++++++++++++- .../endpoints/test_dag_run_endpoint.py | 4 +- .../client/test_local_client.py | 4 +- .../common/test_mark_tasks.py | 49 ++++++++++---- tests/models/test_cleartasks.py | 5 +- 6 files changed, 107 insertions(+), 26 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 3cc6dfdfd715b6..a175a61e207ea6 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -366,11 +366,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id) ).scalar_one() dag_run.state = state - if state == DagRunState.RUNNING: - dag_run.start_date = timezone.utcnow() - dag_run.end_date = None - else: - dag_run.end_date = timezone.utcnow() session.merge(dag_run) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 6b4fdbd0cf39c2..aba2ce3fbb861a 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -269,11 +269,75 @@ def get_state(self): return self._state def set_state(self, state: DagRunState) -> None: + """Change the state of the DagRan. + + Changes to attributes are implemented in accordance with the following table + (rows represent old states, columns represent new states): + + .. 
list-table:: State transition matrix + :header-rows: 1 + :stub-columns: 1 + + * - + - QUEUED + - RUNNING + - SUCCESS + - FAILED + * - None + - queued_at = timezone.utcnow() + - if empty: start_date = timezone.utcnow() + end_date = None + - end_date = timezone.utcnow() + - end_date = timezone.utcnow() + * - QUEUED + - queued_at = timezone.utcnow() + - if empty: start_date = timezone.utcnow() + end_date = None + - end_date = timezone.utcnow() + - end_date = timezone.utcnow() + * - RUNNING + - queued_at = timezone.utcnow() + start_date = None + end_date = None + - + - end_date = timezone.utcnow() + - end_date = timezone.utcnow() + * - SUCCESS + - queued_at = timezone.utcnow() + start_date = None + end_date = None + - start_date = timezone.utcnow() + end_date = None + - + - + * - FAILED + - queued_at = timezone.utcnow() + start_date = None + end_date = None + - start_date = timezone.utcnow() + end_date = None + - + - + + """ if state not in State.dag_states: raise ValueError(f"invalid DagRun state: {state}") if self._state != state: + if state == DagRunState.QUEUED: + self.queued_at = timezone.utcnow() + self.start_date = None + self.end_date = None + if state == DagRunState.RUNNING: + if self._state in State.finished_dr_states: + self.start_date = timezone.utcnow() + else: + self.start_date = self.start_date or timezone.utcnow() + self.end_date = None + if self._state in State.unfinished_dr_states or self._state is None: + if state in State.finished_dr_states: + self.end_date = timezone.utcnow() self._state = state - self.end_date = timezone.utcnow() if self._state in State.finished_dr_states else None + else: if state == DagRunState.QUEUED: self.queued_at = timezone.utcnow() diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 045b5392f5f0ba..9f3d0666bd0626 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -1600,11 +1600,11 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): "conf": {}, "dag_id": dag_id, "dag_run_id": dag_run_id, - "end_date": dr.end_date.isoformat(), + "end_date": dr.end_date.isoformat() if state != State.QUEUED else None, "execution_date": dr.execution_date.isoformat(), "external_trigger": False, "logical_date": dr.execution_date.isoformat(), - "start_date": dr.start_date.isoformat(), + "start_date": dr.start_date.isoformat() if state != State.QUEUED else None, "state": state, "data_interval_start": dr.data_interval_start.isoformat(), "data_interval_end": dr.data_interval_end.isoformat(), diff --git a/tests/api_experimental/client/test_local_client.py b/tests/api_experimental/client/test_local_client.py index b02a5a5c422972..91a81a0cafaee3 100644 --- a/tests/api_experimental/client/test_local_client.py +++ b/tests/api_experimental/client/test_local_client.py @@ -135,13 +135,11 @@ def test_trigger_dag(self, mock): # test output queued_at = pendulum.now() - started_at = pendulum.now() mock.return_value = DagRun( dag_id=test_dag_id, run_id=run_id, queued_at=queued_at, execution_date=EXECDATE, - start_date=started_at, external_trigger=True, state=DagRunState.QUEUED, conf={}, @@ -159,7 +157,7 @@ def test_trigger_dag(self, mock): "last_scheduling_decision": None, "logical_date": EXECDATE, "run_type": DagRunType.MANUAL, - "start_date": started_at, + "start_date": None, "state": DagRunState.QUEUED, } dag_run = self.client.trigger_dag(dag_id=test_dag_id) diff --git 
a/tests/api_experimental/common/test_mark_tasks.py b/tests/api_experimental/common/test_mark_tasks.py index 47c10fa1853956..9b28136bba2797 100644 --- a/tests/api_experimental/common/test_mark_tasks.py +++ b/tests/api_experimental/common/test_mark_tasks.py @@ -555,20 +555,28 @@ def _verify_dag_run_state(self, dag, date, state): assert dr.get_state() == state @provide_session - def _verify_dag_run_dates(self, dag, date, state, middle_time, session=None): + def _verify_dag_run_dates(self, dag, date, state, middle_time=None, old_end_date=None, session=None): # When target state is RUNNING, we should set start_date, # otherwise we should set end_date. DR = DagRun dr = session.query(DR).filter(DR.dag_id == dag.dag_id, DR.execution_date == date).one() if state == State.RUNNING: # Since the DAG is running, the start_date must be updated after creation - assert dr.start_date > middle_time + if middle_time: + assert dr.start_date > middle_time # If the dag is still running, we don't have an end date assert dr.end_date is None else: - # If the dag is not running, there must be an end time - assert dr.start_date < middle_time - assert dr.end_date > middle_time + # If the dag is not running, there must be an end time, + # and the end time must not be changed if it has already been set. + if dr.start_date and middle_time: + assert dr.start_date < middle_time + if dr.end_date: + if old_end_date: + assert dr.end_date == old_end_date + else: + if middle_time: + assert dr.end_date > middle_time def test_set_running_dag_run_to_success(self): date = self.execution_dates[0] @@ -599,30 +607,42 @@ def test_set_running_dag_run_to_failed(self): assert dr.get_task_instance("run_after_loop").state == State.FAILED self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time) - @pytest.mark.parametrize( - "dag_run_alter_function, new_state", - [(set_dag_run_state_to_running, State.RUNNING), (set_dag_run_state_to_queued, State.QUEUED)], - ) - def test_set_running_dag_run_to_activate_state(self, dag_run_alter_function: Callable, new_state: State): + def test_set_running_dag_run_to_running_state(self): + date = self.execution_dates[0] # type: ignore + dr = self._create_test_dag_run(State.RUNNING, date) + self._set_default_task_instance_states(dr) + + altered = set_dag_run_state_to_running(dag=self.dag1, run_id=dr.run_id, commit=True) # type: ignore + + # None of the tasks should be altered, only the dag itself + assert len(altered) == 0 + new_state = State.RUNNING + self._verify_dag_run_state(self.dag1, date, new_state) # type: ignore + self._verify_task_instance_states_remain_default(dr) + self._verify_dag_run_dates(self.dag1, date, new_state) # type: ignore + + def test_set_running_dag_run_to_queued_state(self): date = self.execution_dates[0] # type: ignore dr = self._create_test_dag_run(State.RUNNING, date) middle_time = timezone.utcnow() self._set_default_task_instance_states(dr) - altered = dag_run_alter_function(dag=self.dag1, run_id=dr.run_id, commit=True) # type: ignore + altered = set_dag_run_state_to_queued(dag=self.dag1, run_id=dr.run_id, commit=True) # type: ignore # None of the tasks should be altered, only the dag itself assert len(altered) == 0 + new_state = State.QUEUED self._verify_dag_run_state(self.dag1, date, new_state) # type: ignore self._verify_task_instance_states_remain_default(dr) self._verify_dag_run_dates(self.dag1, date, new_state, middle_time) # type: ignore @pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED]) - def test_set_success_dag_run_to_success(self, 
completed_state): + def test_set_completed_dag_run_to_success(self, completed_state): date = self.execution_dates[0] dr = self._create_test_dag_run(completed_state, date) middle_time = timezone.utcnow() self._set_default_task_instance_states(dr) + old_end_date = dr.end_date altered = set_dag_run_state_to_success(dag=self.dag1, run_id=dr.run_id, commit=True) @@ -631,13 +651,14 @@ def test_set_success_dag_run_to_success(self, completed_state): assert len(altered) == expected self._verify_dag_run_state(self.dag1, date, State.SUCCESS) self._verify_task_instance_states(self.dag1, date, State.SUCCESS) - self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time) + self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time, old_end_date) @pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED]) def test_set_completed_dag_run_to_failed(self, completed_state): date = self.execution_dates[0] dr = self._create_test_dag_run(completed_state, date) middle_time = timezone.utcnow() + old_end_date = dr.end_date self._set_default_task_instance_states(dr) altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True) @@ -646,7 +667,7 @@ def test_set_completed_dag_run_to_failed(self, completed_state): assert len(altered) == expected self._verify_dag_run_state(self.dag1, date, State.FAILED) assert dr.get_task_instance("run_after_loop").state == State.FAILED - self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time) + self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time, old_end_date) @pytest.mark.parametrize( "dag_run_alter_function,new_state", diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index ed0232926aede3..bce9dc4668a1d4 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -210,7 +210,10 @@ def test_clear_task_instances_on_running_dr(self, state, dag_maker): session.refresh(dr) assert dr.state == state - assert dr.start_date + if state == DagRunState.QUEUED: + assert dr.start_date is None + if state == DagRunState.RUNNING: + assert dr.start_date assert dr.last_scheduling_decision == DEFAULT_DATE @pytest.mark.parametrize( From 7835fd2659335d3acf830ce2e70dc19bfc5b2a84 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Fri, 9 Feb 2024 14:59:33 +0100 Subject: [PATCH 3/8] The fix-ownership command missed --rm flag and left dangling containers (#37277) Fixes: #37269 --- dev/breeze/src/airflow_breeze/utils/docker_command_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py index 5a4e569e42a7bb..35d8e56d809ad4 100644 --- a/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py +++ b/dev/breeze/src/airflow_breeze/utils/docker_command_utils.py @@ -567,6 +567,7 @@ def fix_ownership_using_docker(quiet: bool = False): f"DOCKER_IS_ROOTLESS={is_docker_rootless()}", "-e", f"VERBOSE_COMMANDS={str(not quiet).lower()}", + "--rm", "-t", OWNERSHIP_CLEANUP_DOCKER_TAG, "/opt/airflow/scripts/in_container/run_fix_ownership.py", From ab9e2e166eb363d8316b24b8548d401faa7d517b Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 9 Feb 2024 20:15:31 +0530 Subject: [PATCH 4/8] fix: D401 lint issues in airflow core (#37274) --- airflow/auth/managers/utils/fab.py | 4 ++-- airflow/decorators/bash.py | 2 +- airflow/executors/debug_executor.py | 2 +- airflow/models/baseoperator.py | 2 +- airflow/models/dagrun.py | 4 ++-- airflow/models/taskinstance.py | 12 
+++++------ airflow/operators/python.py | 4 ++-- airflow/plugins_manager.py | 6 +++--- airflow/providers_manager.py | 26 ++++++++++++------------ airflow/serialization/serde.py | 6 +++--- airflow/utils/file.py | 2 +- airflow/utils/log/task_context_logger.py | 2 +- airflow/utils/sqlalchemy.py | 2 ++ airflow/www/auth.py | 4 ++-- airflow/www/blueprints.py | 2 +- airflow/www/views.py | 2 +- pyproject.toml | 16 --------------- 17 files changed, 42 insertions(+), 56 deletions(-) diff --git a/airflow/auth/managers/utils/fab.py b/airflow/auth/managers/utils/fab.py index 316e5ecff1658d..22b572e07f5053 100644 --- a/airflow/auth/managers/utils/fab.py +++ b/airflow/auth/managers/utils/fab.py @@ -40,12 +40,12 @@ def get_fab_action_from_method_map(): - """Returns the map associating a method to a FAB action.""" + """Return the map associating a method to a FAB action.""" return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME def get_method_from_fab_action_map(): - """Returns the map associating a FAB action to a method.""" + """Return the map associating a FAB action to a method.""" return { **{v: k for k, v in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()}, ACTION_CAN_ACCESS_MENU: "GET", diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py index 70011c30790346..36fc646370ee73 100644 --- a/airflow/decorators/bash.py +++ b/airflow/decorators/bash.py @@ -84,7 +84,7 @@ def bash_task( python_callable: Callable | None = None, **kwargs, ) -> TaskDecorator: - """Wraps a function into a BashOperator. + """Wrap a function into a BashOperator. Accepts kwargs for operator kwargs. Can be reused in a single DAG. This function is only used only used during type checking or auto-completion. diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py index 750b0ba20b033e..9b376cdb010228 100644 --- a/airflow/executors/debug_executor.py +++ b/airflow/executors/debug_executor.py @@ -61,7 +61,7 @@ def __init__(self): self.fail_fast = conf.getboolean("debug", "fail_fast") def execute_async(self, *args, **kwargs) -> None: - """The method is replaced by custom trigger_task implementation.""" + """Replace the method with a custom trigger_task implementation.""" def sync(self) -> None: task_succeeded = True diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ce55b24350b494..e2406776d86c0d 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1602,7 +1602,7 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): - """This method is called when a deferred task is resumed.""" + """Call this method when a deferred task is resumed.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. if next_method == "__fail__": diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index aba2ce3fbb861a..f9126dd6313bf5 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -568,7 +568,7 @@ def get_task_instances( session: Session = NEW_SESSION, ) -> list[TI]: """ - Returns the task instances for this dag run. + Return the task instances for this dag run. Redirect to DagRun.fetch_task_instances method. Keep this method because it is widely used across the code. 
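The docstring rewrites in this patch (and in patches 5 and 8 later in the series) all apply a single rule: ruff's pydocstyle check D401 requires the first line of a docstring to be in the imperative mood ("Return", "Check") rather than the indicative ("Returns", "Checks"). A minimal, self-contained sketch of the before/after pattern; the functions below are hypothetical and not part of this diff:

.. code-block:: python

    from __future__ import annotations


    def dag_id_count_indicative(dag_ids: list[str]) -> int:
        """Returns the number of DAG ids in the list."""  # flagged by ruff rule D401
        return len(dag_ids)


    def dag_id_count(dag_ids: list[str]) -> int:
        """Return the number of DAG ids in the list."""  # imperative mood, passes D401
        return len(dag_ids)

Running ``ruff check --select D401`` over a module containing only the first form reports a "first line of docstring should be in imperative mood" violation; once every summary line is rewritten this way, the per-file ignores removed from ``pyproject.toml`` at the end of this patch are no longer needed.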
@@ -611,7 +611,7 @@ def fetch_task_instance( map_index: int = -1, ) -> TI | TaskInstancePydantic | None: """ - Returns the task instance specified by task_id for this dag run. + Return the task instance specified by task_id for this dag run. :param dag_id: the DAG id :param dag_run_id: the DAG run id diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 38b787cbe8fadb..01a84fc8834d94 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -461,7 +461,7 @@ def _refresh_from_db( *, task_instance: TaskInstance | TaskInstancePydantic, session: Session, lock_for_update: bool = False ) -> None: """ - Refreshes the task instance from the database based on the primary key. + Refresh the task instance from the database based on the primary key. :param task_instance: the task instance :param session: SQLAlchemy ORM Session @@ -531,7 +531,7 @@ def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]: """ - Returns task instance tags. + Return task instance tags. :param task_instance: the task instance @@ -943,7 +943,7 @@ def _get_previous_dagrun( session: Session | None = None, ) -> DagRun | None: """ - The DagRun that ran before this task instance's DagRun. + Return the DagRun that ran prior to this task instance's DagRun. :param task_instance: the task instance :param state: If passed, it only take into account instances of a specific state. @@ -983,7 +983,7 @@ def _get_previous_execution_date( session: Session, ) -> pendulum.DateTime | None: """ - The execution date from property previous_ti_success. + Get execution date from property previous_ti_success. :param task_instance: the task instance :param session: SQLAlchemy ORM Session @@ -1178,7 +1178,7 @@ def _get_previous_ti( state: DagRunState | None = None, ) -> TaskInstance | TaskInstancePydantic | None: """ - The task instance for the task that ran before this task instance. + Get task instance for the task that ran before this task instance. :param task_instance: the task instance :param state: If passed, it only take into account instances of a specific state. @@ -1436,7 +1436,7 @@ def try_number(self): @try_number.expression def try_number(cls): """ - This is what will be used by SQLAlchemy when filtering on try_number. + Return the expression to be used by SQLAlchemy when filtering on try_number. This is required because the override in the get_try_number function causes try_number values to be off by one when listing tasks in the UI. diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 1b1453cc5ed50f..0f005f43b266aa 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -640,7 +640,7 @@ def _prepare_venv(self, venv_path: Path) -> None: ) def _calculate_cache_hash(self) -> tuple[str, str]: - """Helper to generate the hash of the cache folder to use. + """Generate the hash of the cache folder to use. 
The following factors are used as input for the hash: - (sorted) list of requirements @@ -666,7 +666,7 @@ def _calculate_cache_hash(self) -> tuple[str, str]: return requirements_hash[:8], hash_text def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path: - """Helper to ensure a valid virtual environment is set up and will create inplace.""" + """Ensure a valid virtual environment is set up and will create inplace.""" cache_hash, hash_data = self._calculate_cache_hash() venv_path = venv_cache_path / f"venv-{cache_hash}" self.log.info("Python virtual environment will be cached in %s", venv_path) diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py index 143e3af5707bc5..6514409ef493de 100644 --- a/airflow/plugins_manager.py +++ b/airflow/plugins_manager.py @@ -171,14 +171,14 @@ class AirflowPlugin: @classmethod def validate(cls): - """Validates that plugin has a name.""" + """Validate if plugin has a name.""" if not cls.name: raise AirflowPluginException("Your plugin needs a name.") @classmethod def on_load(cls, *args, **kwargs): """ - Executed when the plugin is loaded; This method is only called once during runtime. + Execute when the plugin is loaded; This method is only called once during runtime. :param args: If future arguments are passed in on call. :param kwargs: If future arguments are passed in on call. @@ -296,7 +296,7 @@ def load_providers_plugins(): def make_module(name: str, objects: list[Any]): - """Creates new module.""" + """Create new module.""" if not objects: return None log.debug("Creating module %s", name) diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 075473796bc8fa..1f1fe397b9785e 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -146,7 +146,7 @@ def _read_schema_from_resources_or_local_file(filename: str) -> dict: def _create_provider_info_schema_validator(): - """Creates JSON schema validator from the provider_info.schema.json.""" + """Create JSON schema validator from the provider_info.schema.json.""" import jsonschema schema = _read_schema_from_resources_or_local_file("provider_info.schema.json") @@ -156,7 +156,7 @@ def _create_provider_info_schema_validator(): def _create_customized_form_field_behaviours_schema_validator(): - """Creates JSON schema validator from the customized_form_field_behaviours.schema.json.""" + """Create JSON schema validator from the customized_form_field_behaviours.schema.json.""" import jsonschema schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json") @@ -305,7 +305,7 @@ def _correctness_check( provider_package: str, class_name: str, provider_info: ProviderInfo ) -> type[BaseHook] | None: """ - Performs coherence check on provider classes. + Perform coherence check on provider classes. For apache-airflow providers - it checks if it starts with appropriate package. For all providers it tries to import the provider - checking that there are no exceptions during importing. 
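The coherence check described above reduces to resolving a dotted class path with importlib and treating any exception during import as a failed provider, with an extra package-prefix check for apache-airflow packages. A rough standalone sketch of that idea follows; the helper name and return contract are illustrative assumptions, not the actual ProvidersManager API:

.. code-block:: python

    from __future__ import annotations

    import importlib
    import logging

    log = logging.getLogger(__name__)


    def import_provider_class(provider_package: str, class_name: str) -> type | None:
        """Return the imported class, or None if the provider class cannot be loaded."""
        if provider_package.startswith("apache-airflow"):
            # e.g. "apache-airflow-providers-amazon" -> "airflow.providers.amazon"
            expected_prefix = provider_package[len("apache-") :].replace("-", ".")
            if not class_name.startswith(expected_prefix):
                log.warning("%s is not declared under the %s package", class_name, expected_prefix)
                return None
        module_path, _, attr_name = class_name.rpartition(".")
        try:
            return getattr(importlib.import_module(module_path), attr_name)
        except Exception:
            log.exception("Importing %s from provider %s failed", class_name, provider_package)
            return None

For example, ``import_provider_class("apache-airflow-providers-amazon", "airflow.providers.amazon.aws.hooks.s3.S3Hook")`` returns the hook class when the Amazon provider is installed and None otherwise.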
@@ -408,7 +408,7 @@ def initialization_stack_trace() -> str | None: return ProvidersManager._initialization_stack_trace def __init__(self): - """Initializes the manager.""" + """Initialize the manager.""" super().__init__() ProvidersManager._initialized = True ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe())) @@ -445,7 +445,7 @@ def __init__(self): self._init_airflow_core_hooks() def _init_airflow_core_hooks(self): - """Initializes the hooks dict with default hooks from Airflow core.""" + """Initialize the hooks dict with default hooks from Airflow core.""" core_dummy_hooks = { "generic": "Generic", "email": "Email", @@ -563,7 +563,7 @@ def initialize_providers_configuration(self): def _initialize_providers_configuration(self): """ - Internal method to initialize providers configuration information. + Initialize providers configuration information. Should be used if we do not want to trigger caching for ``initialize_providers_configuration`` method. In some cases we might want to make sure that the configuration is initialized, but we do not want @@ -626,7 +626,7 @@ def _discover_all_providers_from_packages(self) -> None: def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: """ - Finds all built-in airflow providers if airflow is run from the local sources. + Find all built-in airflow providers if airflow is run from the local sources. It finds `provider.yaml` files for all such providers and registers the providers using those. @@ -654,7 +654,7 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None: def _add_provider_info_from_local_source_files_on_path(self, path) -> None: """ - Finds all the provider.yaml files in the directory specified. + Find all the provider.yaml files in the directory specified. :param path: path where to look for provider.yaml files """ @@ -672,7 +672,7 @@ def _add_provider_info_from_local_source_files_on_path(self, path) -> None: def _add_provider_info_from_local_source_file(self, path, package_name) -> None: """ - Parses found provider.yaml file and adds found provider to the dictionary. + Parse found provider.yaml file and adds found provider to the dictionary. 
:param path: full file path of the provider.yaml file :param package_name: name of the package @@ -1069,7 +1069,7 @@ def _add_customized_fields(self, package_name: str, hook_class: type, customized ) def _discover_auth_managers(self) -> None: - """Retrieves all auth managers defined in the providers.""" + """Retrieve all auth managers defined in the providers.""" for provider_package, provider in self._provider_dict.items(): if provider.data.get("auth-managers"): for auth_manager_class_name in provider.data["auth-managers"]: @@ -1077,7 +1077,7 @@ def _discover_auth_managers(self) -> None: self._auth_manager_class_name_set.add(auth_manager_class_name) def _discover_notifications(self) -> None: - """Retrieves all notifications defined in the providers.""" + """Retrieve all notifications defined in the providers.""" for provider_package, provider in self._provider_dict.items(): if provider.data.get("notifications"): for notification_class_name in provider.data["notifications"]: @@ -1085,7 +1085,7 @@ def _discover_notifications(self) -> None: self._notification_info_set.add(notification_class_name) def _discover_extra_links(self) -> None: - """Retrieves all extra links defined in the providers.""" + """Retrieve all extra links defined in the providers.""" for provider_package, provider in self._provider_dict.items(): if provider.data.get("extra-links"): for extra_link_class_name in provider.data["extra-links"]: @@ -1149,7 +1149,7 @@ def _discover_plugins(self) -> None: @provider_info_cache("triggers") def initialize_providers_triggers(self): - """Initialization of providers triggers.""" + """Initialize providers triggers.""" self.initialize_providers_list() for provider_package, provider in self._provider_dict.items(): for trigger in provider.data.get("triggers", []): diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py index a214acc9a6677d..fd7eb33af72846 100644 --- a/airflow/serialization/serde.py +++ b/airflow/serialization/serde.py @@ -288,20 +288,20 @@ def _convert(old: dict) -> dict: def _match(classname: str) -> bool: - """Checks if the given classname matches a path pattern either using glob format or regexp format.""" + """Check if the given classname matches a path pattern either using glob format or regexp format.""" return _match_glob(classname) or _match_regexp(classname) @functools.lru_cache(maxsize=None) def _match_glob(classname: str): - """Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax.""" + """Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax.""" patterns = _get_patterns() return any(fnmatch(classname, p.pattern) for p in patterns) @functools.lru_cache(maxsize=None) def _match_regexp(classname: str): - """Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp.""" + """Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp.""" patterns = _get_regexp_patterns() return any(p.match(classname) is not None for p in patterns) diff --git a/airflow/utils/file.py b/airflow/utils/file.py index 7e15eeb2f8d72c..c66358a10aebf8 100644 --- a/airflow/utils/file.py +++ b/airflow/utils/file.py @@ -385,7 +385,7 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]: def get_unique_dag_module_name(file_path: str) -> str: - """Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}.""" + """Return a 
unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}.""" if isinstance(file_path, str): path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest() org_mod_name = Path(file_path).stem diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py index 84ed207e3ae9cf..46e8cf8cee3823 100644 --- a/airflow/utils/log/task_context_logger.py +++ b/airflow/utils/log/task_context_logger.py @@ -65,7 +65,7 @@ def _should_enable(self) -> bool: @staticmethod def _get_task_handler() -> FileTaskHandler | None: - """Returns the task handler that supports task context logging.""" + """Return the task handler that supports task context logging.""" handlers = [ handler for handler in logging.getLogger("airflow.task").handlers diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 2dc495811ae7b2..6ce0d00207ae00 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -271,6 +271,8 @@ def process(value): def compare_values(self, x, y): """ + Compare x and y using self.comparator if available. Else, use __eq__. + The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects. If the installed library version has changed since the object was originally pickled, diff --git a/airflow/www/auth.py b/airflow/www/auth.py index a34621d56c416b..39c8444f993c2c 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -88,7 +88,7 @@ def has_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable def has_access_with_pk(f): """ - This decorator is used to check permissions on views. + Check permissions on views. The implementation is very similar from https://github.com/dpgaspar/Flask-AppBuilder/blob/c6fecdc551629e15467fde5d06b4437379d90592/flask_appbuilder/security/decorators.py#L134 @@ -345,5 +345,5 @@ def decorated(*args, **kwargs): def has_access_view(access_view: AccessView = AccessView.WEBSITE) -> Callable[[T], T]: - """Decorator that checks current user's permissions to access the website.""" + """Check current user's permissions to access the website.""" return _has_access_no_details(lambda: get_auth_manager().is_authorized_view(access_view=access_view)) diff --git a/airflow/www/blueprints.py b/airflow/www/blueprints.py index 0312a9ffa71639..fda6b65397b66b 100644 --- a/airflow/www/blueprints.py +++ b/airflow/www/blueprints.py @@ -24,5 +24,5 @@ @routes.route("/") def index(): - """Main Airflow page.""" + """Return main Airflow page.""" return redirect(url_for("Airflow.index")) diff --git a/airflow/www/views.py b/airflow/www/views.py index 6e6caeba3c4072..16253078905596 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -3994,7 +3994,7 @@ def delete(self, pk): @expose("/action_post", methods=["POST"]) def action_post(self): """ - Action method to handle multiple records selected from a list view. + Handle multiple records selected from a list view. 
Same implementation as https://github.com/dpgaspar/Flask-AppBuilder/blob/2c5763371b81cd679d88b9971ba5d1fc4d71d54b/flask_appbuilder/views.py#L677 diff --git a/pyproject.toml b/pyproject.toml index 3e56623f6f6ead..94d629db758968 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1350,22 +1350,6 @@ combine-as-imports = true "tests/providers/elasticsearch/log/elasticmock/utilities/__init__.py" = ["E402"] # All the modules which do not follow D401 yet, please remove as soon as it becomes compatible -"airflow/auth/managers/utils/fab.py" = ["D401"] -"airflow/decorators/bash.py" = ["D401"] -"airflow/executors/debug_executor.py" = ["D401"] -"airflow/models/baseoperator.py" = ["D401"] -"airflow/models/dagrun.py" = ["D401"] -"airflow/models/taskinstance.py" = ["D401"] -"airflow/operators/python.py" = ["D401"] -"airflow/plugins_manager.py" = ["D401"] -"airflow/providers_manager.py" = ["D401"] -"airflow/serialization/serde.py" = ["D401"] -"airflow/utils/log/task_context_logger.py" = ["D401"] -"airflow/utils/sqlalchemy.py" = ["D401"] -"airflow/www/auth.py" = ["D401"] -"airflow/www/blueprints.py" = ["D401"] -"airflow/www/views.py" = ["D401"] -"airflow/utils/file.py" = ["D401"] "airflow/providers/airbyte/hooks/airbyte.py" = ["D401"] "airflow/providers/airbyte/operators/airbyte.py" = ["D401"] "airflow/providers/airbyte/sensors/airbyte.py" = ["D401"] From 17945fc5edcec619ccde0fbab0d6fb8e0eb206cd Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 9 Feb 2024 20:16:33 +0530 Subject: [PATCH 5/8] D401 fixes in Pinecone provider (#37270) --- airflow/providers/pinecone/hooks/pinecone.py | 14 +++++++------- pyproject.toml | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index f15605556cf871..6a116250f49228 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -45,7 +45,7 @@ class PineconeHook(BaseHook): @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: - """Returns connection widgets to add to connection form.""" + """Return connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField @@ -60,7 +60,7 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: - """Returns custom field behaviour.""" + """Return custom field behaviour.""" return { "hidden_fields": ["port", "schema"], "relabeling": {"login": "Pinecone Environment", "password": "Pinecone API key"}, @@ -108,7 +108,7 @@ def upsert( **kwargs: Any, ) -> UpsertResponse: """ - The upsert operation writes vectors into a namespace. + Write vectors into a namespace. If a new value is upserted for an existing vector id, it will overwrite the previous value. @@ -204,7 +204,7 @@ def delete_index(index_name: str, timeout: int | None = None) -> None: @staticmethod def configure_index(index_name: str, replicas: int | None = None, pod_type: str | None = "") -> None: """ - Changes current configuration of the index. + Change the current configuration of the index. :param index_name: The name of the index to configure. :param replicas: The new number of replicas. @@ -258,7 +258,7 @@ def query_vector( sparse_vector: SparseValues | dict[str, list[float] | list[int]] | None = None, ) -> QueryResponse: """ - The Query operation searches a namespace, using a query vector. + Search a namespace using query vector. 
It retrieves the ids of the most similar items in a namespace, along with their similarity scores. API reference: https://docs.pinecone.io/reference/query @@ -288,7 +288,7 @@ def query_vector( @staticmethod def _chunks(iterable: list[Any], batch_size: int = 100) -> Any: - """Helper function to break an iterable into chunks of size batch_size.""" + """Break an iterable into chunks of size batch_size.""" it = iter(iterable) chunk = tuple(itertools.islice(it, batch_size)) while chunk: @@ -329,7 +329,7 @@ def describe_index_stats( **kwargs: Any, ) -> DescribeIndexStatsResponse: """ - Describes the index statistics. + Describe the index statistics. Returns statistics about the index's contents. For example: The vector count per namespace and the number of dimensions. diff --git a/pyproject.toml b/pyproject.toml index 94d629db758968..1856aac2a4f34c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1546,7 +1546,6 @@ combine-as-imports = true "airflow/providers/pagerduty/hooks/pagerduty.py" = ["D401"] "airflow/providers/pagerduty/hooks/pagerduty_events.py" = ["D401"] "airflow/providers/papermill/hooks/kernel.py" = ["D401"] -"airflow/providers/pinecone/hooks/pinecone.py" = ["D401"] "airflow/providers/postgres/hooks/postgres.py" = ["D401"] "airflow/providers/presto/hooks/presto.py" = ["D401"] "airflow/providers/qdrant/hooks/qdrant.py" = ["D401"] From 8317ed93a58900d922ab4ca8da02ed1c6050252c Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 9 Feb 2024 20:17:38 +0530 Subject: [PATCH 6/8] Updating the README and visuals for breeze build-docs (#37276) --- .../airflow_breeze/commands/developer_commands.py | 6 +++++- docs/README.rst | 13 ++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/dev/breeze/src/airflow_breeze/commands/developer_commands.py b/dev/breeze/src/airflow_breeze/commands/developer_commands.py index 92882a781a367d..e9eb6290c56d9b 100644 --- a/dev/breeze/src/airflow_breeze/commands/developer_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/developer_commands.py @@ -656,7 +656,11 @@ def build_docs( fix_ownership_using_docker() if result.returncode == 0: get_console().print( - "[info]Start the webserver in breeze and view the built docs at http://localhost:28080/docs/[/]" + "[info]To view the built documentation, you have two options:\n\n" + "1. Start the webserver in breeze and access the built docs at " + "http://localhost:28080/docs/\n" + "2. Alternatively, you can run ./docs/start_docs_server.sh for a lighter resource option and view" + "the built docs at http://localhost:8000" ) sys.exit(result.returncode) diff --git a/docs/README.rst b/docs/README.rst index e16f22499e8945..d5cddec4864f54 100644 --- a/docs/README.rst +++ b/docs/README.rst @@ -162,19 +162,18 @@ Running the Docs Locally ------------------------ After you build the documentation, you can check the formatting, style, and documentation build at ``http://localhost:28080/docs/`` -by starting a Breeze environment or by running the following command from the root directory. - -You need to have Python installed to run the command: +by starting a Breeze environment. Alternatively, you can run the following command from the root directory: .. code-block:: bash docs/start_doc_server.sh +This command requires Python to be installed. This method is lighter on the system resources as you do not need to +launch the webserver just to view docs. -Then, view your docs at ``localhost:8000``. 
If you use a virtual machine, like WSL2, -you need to find the WSL2 machine IP address and replace "0.0.0.0" in your browser with it. The address looks like -``http://n.n.n.n:8000``, where n.n.n.n is the IP of the WSL2. - +Once the server is running, you can view your documentation at http://localhost:8000. If you're using a virtual machine +like WSL2, you'll need to find the IP address of the WSL2 machine and replace "0.0.0.0" in your browser with it. +The address will look like http://n.n.n.n:8000, where n.n.n.n is the IP of the WSL2 machine. Cross-referencing syntax ======================== From 48bfb1a970f5b47ba1b385ad809b8324923ddf3e Mon Sep 17 00:00:00 2001 From: Niko Oliveira Date: Fri, 9 Feb 2024 08:43:32 -0800 Subject: [PATCH 7/8] Merge all ECS executor configs following recursive python dict update (#37137) Also document the behaviour and interaction between exec_config and run_task_kwargs config --- .../amazon/aws/executors/ecs/ecs_executor.py | 22 +- .../executors/ecs-executor.rst | 8 +- .../aws/executors/ecs/test_ecs_executor.py | 218 ++++++++++++++++++ 3 files changed, 239 insertions(+), 9 deletions(-) diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 2f0564ed9a3400..e6594e270f8be5 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -48,6 +48,7 @@ ) from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils import timezone +from airflow.utils.helpers import merge_dicts from airflow.utils.state import State if TYPE_CHECKING: @@ -408,8 +409,8 @@ def _run_task( The command and executor config will be placed in the container-override section of the JSON request before calling Boto3's "run_task" function. """ - run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config) - boto_run_task = self.ecs.run_task(**run_task_api) + run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config) + boto_run_task = self.ecs.run_task(**run_task_kwargs) run_task_response = BotoRunTaskSchema().load(boto_run_task) return run_task_response @@ -421,17 +422,17 @@ def _run_task_kwargs( One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client. 
""" - run_task_api = deepcopy(self.run_task_kwargs) - container_override = self.get_container(run_task_api["overrides"]["containerOverrides"]) + run_task_kwargs = deepcopy(self.run_task_kwargs) + run_task_kwargs = merge_dicts(run_task_kwargs, exec_config) + container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"]) container_override["command"] = cmd - container_override.update(exec_config) # Inject the env variable to configure logging for containerized execution environment if "environment" not in container_override: container_override["environment"] = [] container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}) - return run_task_api + return run_task_kwargs def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None): """Save the task to be executed in the next sync by inserting the commands into a queue.""" @@ -484,6 +485,11 @@ def _load_run_kwargs(self) -> dict: def get_container(self, container_list): """Searches task list for core Airflow container.""" for container in container_list: - if container["name"] == self.container_name: - return container + try: + if container["name"] == self.container_name: + return container + except KeyError: + raise EcsExecutorException( + 'container "name" must be provided in "containerOverrides" configuration' + ) raise KeyError(f"No such container found by container name: {self.container_name}") diff --git a/docs/apache-airflow-providers-amazon/executors/ecs-executor.rst b/docs/apache-airflow-providers-amazon/executors/ecs-executor.rst index a6062a630437bd..d8d3764f5e0429 100644 --- a/docs/apache-airflow-providers-amazon/executors/ecs-executor.rst +++ b/docs/apache-airflow-providers-amazon/executors/ecs-executor.rst @@ -73,6 +73,9 @@ In the case of conflicts, the order of precedence from lowest to highest is: 3. Load any values provided in the RUN_TASK_KWARGS option if one is provided. +.. note:: + ``exec_config`` is an optional parameter that can be provided to operators. It is a dictionary type and in the context of the ECS Executor it represents a ``run_task_kwargs`` configuration which is then updated over-top of the ``run_task_kwargs`` specified in Airflow config above (if present). It is a recursive update which essentially applies Python update to each nested dictionary in the configuration. Loosely approximated as: ``run_task_kwargs.update(exec_config)`` + Required config options: ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -88,7 +91,7 @@ Optional config options: - ASSIGN_PUBLIC_IP - Whether to assign a public IP address to the containers launched by the ECS executor. Defaults to "False". -- CONN_ID - The Airflow connection (i.e. credentials) used by the ECS +- AWS_CONN_ID - The Airflow connection (i.e. credentials) used by the ECS executor to make API calls to AWS ECS. Defaults to "aws_default". - LAUNCH_TYPE - Launch type can either be 'FARGATE' OR 'EC2'. Defaults to "FARGATE". @@ -113,6 +116,9 @@ For a more detailed description of available options, including type hints and examples, see the ``config_templates`` folder in the Amazon provider package. +.. note:: + ``exec_config`` is an optional parameter that can be provided to operators. It is a dictionary type and in the context of the ECS Executor it represents a ``run_task_kwargs`` configuration which is then updated over-top of the ``run_task_kwargs`` specified in Airflow config above (if present). 
It is a recursive update which essentially applies Python update to each nested dictionary in the configuration. Loosely approximated as: ``run_task_kwargs.update(exec_config)`` + .. _dockerfile_for_ecs_executor: Dockerfile for ECS Executor diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index 04e777455572d5..8766659c05cbb3 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -34,6 +34,7 @@ from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor +from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema from airflow.providers.amazon.aws.executors.ecs.ecs_executor import ( @@ -1156,3 +1157,220 @@ def test_providing_no_capacity_provider_no_lunch_type_no_cluster_default(self, m task_kwargs = ecs_executor_config.build_task_kwargs() assert task_kwargs["launchType"] == "FARGATE" + + @pytest.mark.parametrize( + "run_task_kwargs, exec_config, expected_result", + [ + # No input run_task_kwargs or executor overrides + ( + {}, + {}, + { + "taskDefinition": "some-task-def", + "launchType": "FARGATE", + "cluster": "some-cluster", + "platformVersion": "LATEST", + "count": 1, + "overrides": { + "containerOverrides": [ + { + "command": ["command"], + "name": "container-name", + "environment": [{"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}], + } + ] + }, + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["sub1", "sub2"], + "securityGroups": ["sg1", "sg2"], + "assignPublicIp": "DISABLED", + } + }, + }, + ), + # run_task_kwargs provided, not exec_config + ( + { + "startedBy": "Banana", + "tags": [{"key": "FOO", "value": "BAR"}], + "overrides": { + "containerOverrides": [ + { + "name": "container-name", + "memory": 500, + "cpu": 10, + "environment": [{"name": "X", "value": "Y"}], + } + ] + }, + }, + {}, + { + "startedBy": "Banana", + "tags": [{"key": "FOO", "value": "BAR"}], + "taskDefinition": "some-task-def", + "launchType": "FARGATE", + "cluster": "some-cluster", + "platformVersion": "LATEST", + "count": 1, + "overrides": { + "containerOverrides": [ + { + "memory": 500, + "cpu": 10, + "command": ["command"], + "name": "container-name", + "environment": [ + {"name": "X", "value": "Y"}, + # Added by the ecs executor + {"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}, + ], + } + ] + }, + # Added by the ecs executor + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["sub1", "sub2"], + "securityGroups": ["sg1", "sg2"], + "assignPublicIp": "DISABLED", + } + }, + }, + ), + # exec_config provided, no run_task_kwargs + ( + {}, + { + "startedBy": "Banana", + "tags": [{"key": "FOO", "value": "BAR"}], + "overrides": { + "containerOverrides": [ + { + "name": "container-name", + "memory": 500, + "cpu": 10, + "environment": [{"name": "X", "value": "Y"}], + } + ] + }, + }, + { + "startedBy": "Banana", + "tags": [{"key": "FOO", "value": "BAR"}], + "taskDefinition": "some-task-def", + "launchType": "FARGATE", + "cluster": "some-cluster", + "platformVersion": "LATEST", + "count": 1, + "overrides": { + "containerOverrides": [ + { + "memory": 500, + "cpu": 10, + "command": ["command"], + "name": "container-name", + "environment": [ + {"name": "X", "value": "Y"}, + # Added 
by the ecs executor + {"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}, + ], + } + ] + }, + # Added by the ecs executor + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["sub1", "sub2"], + "securityGroups": ["sg1", "sg2"], + "assignPublicIp": "DISABLED", + } + }, + }, + ), + # Both run_task_kwargs and executor_config provided. The latter should override the former, + # following a recursive python dict update strategy + ( + { + "startedBy": "Banana", + "tags": [{"key": "FOO", "value": "BAR"}], + "taskDefinition": "foobar", + "overrides": { + "containerOverrides": [ + { + "name": "container-name", + "memory": 500, + "cpu": 10, + "environment": [{"name": "X", "value": "Y"}], + } + ] + }, + }, + { + "startedBy": "Fish", + "tags": [{"key": "X", "value": "Y"}, {"key": "W", "value": "Z"}], + "overrides": { + "containerOverrides": [ + { + "name": "container-name", + "memory": 300, + "environment": [{"name": "W", "value": "Z"}], + } + ] + }, + }, + { + # tags and startedBy are overridden by exec_config + "startedBy": "Fish", + # List types overwrite entirely, as python dict update would do + "tags": [{"key": "X", "value": "Y"}, {"key": "W", "value": "Z"}], + # taskDefinition remains since it is not a list type and not overridden by exec config + "taskDefinition": "foobar", + "launchType": "FARGATE", + "cluster": "some-cluster", + "platformVersion": "LATEST", + "count": 1, + "overrides": { + "containerOverrides": [ + { + "memory": 300, + # cpu is not present because it was missing from the container overrides in + # the exec_config + "command": ["command"], + "name": "container-name", + "environment": [ + # Overridden list type + {"name": "W", "value": "Z"}, # Only new env vars present, overwritten + # Added by the ecs executor + {"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}, + ], + } + ] + }, + # Added by the ecs executor + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["sub1", "sub2"], + "securityGroups": ["sg1", "sg2"], + "assignPublicIp": "DISABLED", + } + }, + }, + ), + ], + ) + def test_run_task_kwargs_exec_config_overrides( + self, set_env_vars, run_task_kwargs, exec_config, expected_result + ): + run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper() + os.environ[run_task_kwargs_env_key] = json.dumps(run_task_kwargs) + + mock_ti_key = mock.Mock(spec=TaskInstanceKey) + command = ["command"] + + executor = AwsEcsExecutor() + + final_run_task_kwargs = executor._run_task_kwargs(mock_ti_key, command, "queue", exec_config) + + assert final_run_task_kwargs == expected_result From 00ed46769eaea24251fc4726a46df1f54f27c4bd Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Fri, 9 Feb 2024 13:13:36 -0500 Subject: [PATCH 8/8] D401 support in fab provider (#37283) --- .../fab/auth_manager/decorators/auth.py | 2 +- .../fab/auth_manager/fab_auth_manager.py | 2 +- .../auth_manager/security_manager/override.py | 52 +++++++++---------- pyproject.toml | 3 -- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/airflow/providers/fab/auth_manager/decorators/auth.py b/airflow/providers/fab/auth_manager/decorators/auth.py index 95f97c8e795904..7089be08fc56df 100644 --- a/airflow/providers/fab/auth_manager/decorators/auth.py +++ b/airflow/providers/fab/auth_manager/decorators/auth.py @@ -67,7 +67,7 @@ def decorated(*args, **kwargs): def _has_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: """ - Factory for decorator that 
checks current user's permissions against required permissions. + Check current user's permissions against required permissions. This decorator is only kept for backward compatible reasons. The decorator ``airflow.www.auth.has_access``, which redirects to this decorator, is widely used in user plugins. diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index dfa53ef78b5b08..696709ae6cad8e 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -446,7 +446,7 @@ def _get_fab_resource_types(dag_access_entity: DagAccessEntity) -> tuple[str, .. def _resource_name_for_dag(self, dag_id: str) -> str: """ - Returns the FAB resource name for a DAG id. + Return the FAB resource name for a DAG id. :param dag_id: the DAG id diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py index 9fe89f8a69edbf..6f5c0f72c66985 100644 --- a/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/airflow/providers/fab/auth_manager/security_manager/override.py @@ -531,7 +531,7 @@ def auth_rate_limit(self) -> str: @property def auth_role_public(self): - """Gets the public role.""" + """Get the public role.""" return self.appbuilder.app.config["AUTH_ROLE_PUBLIC"] @property @@ -571,7 +571,7 @@ def auth_ldap_tls_demand(self): @property def auth_ldap_server(self): - """Gets the LDAP server object.""" + """Get the LDAP server object.""" return self.appbuilder.get_app.config["AUTH_LDAP_SERVER"] @property @@ -650,7 +650,7 @@ def api_login_allow_multiple_providers(self): @property def auth_username_ci(self): - """Gets the auth username for CI.""" + """Get the auth username for CI.""" return self.appbuilder.get_app.config.get("AUTH_USERNAME_CI", True) @property @@ -685,7 +685,7 @@ def auth_roles_sync_at_login(self) -> bool: @property def auth_role_admin(self): - """Gets the admin role.""" + """Get the admin role.""" return self.appbuilder.get_app.config["AUTH_ROLE_ADMIN"] @property @@ -697,7 +697,7 @@ def oauth_whitelists(self): return self.oauth_allow_list def create_builtin_roles(self): - """Returns FAB builtin roles.""" + """Return FAB builtin roles.""" return self.appbuilder.app.config.get("FAB_ROLES", {}) @property @@ -1445,7 +1445,7 @@ def add_user( password="", hashed_password="", ): - """Generic function to create user.""" + """Create a user.""" try: user = self.user_model() user.first_name = first_name @@ -1504,7 +1504,7 @@ def add_register_user(self, username, first_name, last_name, email, password="", return None def find_user(self, username=None, email=None): - """Finds user by username or email.""" + """Find user by username or email.""" if username: try: if self.auth_username_ci: @@ -1549,7 +1549,7 @@ def update_user(self, user: User) -> bool: def del_register_user(self, register_user): """ - Deletes registration object from database. + Delete registration object from database. :param register_user: RegisterUser object to delete """ @@ -1598,7 +1598,7 @@ def update_user_auth_stat(self, user, success=True): def get_action(self, name: str) -> Action: """ - Gets an existing action record. + Get an existing action record. :param name: name """ @@ -1606,7 +1606,7 @@ def get_action(self, name: str) -> Action: def create_action(self, name): """ - Adds an action to the backend, model action. + Add an action to the backend, model action. 
:param name: name of the action: 'can_add','can_edit' etc... @@ -1626,7 +1626,7 @@ def create_action(self, name): def delete_action(self, name: str) -> bool: """ - Deletes a permission action. + Delete a permission action. :param name: Name of action to delete (e.g. can_read). """ @@ -1659,7 +1659,7 @@ def delete_action(self, name: str) -> bool: def get_resource(self, name: str) -> Resource: """ - Returns a resource record by name, if it exists. + Return a resource record by name, if it exists. :param name: Name of resource """ @@ -1685,12 +1685,12 @@ def create_resource(self, name) -> Resource: return resource def get_all_resources(self) -> list[Resource]: - """Gets all existing resource records.""" + """Get all existing resource records.""" return self.get_session.query(self.resource_model).all() def delete_resource(self, name: str) -> bool: """ - Deletes a Resource from the backend. + Delete a Resource from the backend. :param name: name of the resource @@ -1728,7 +1728,7 @@ def get_permission( resource_name: str, ) -> Permission | None: """ - Gets a permission made with the given action->resource pair, if the permission already exists. + Get a permission made with the given action->resource pair, if the permission already exists. :param action_name: Name of action :param resource_name: Name of resource @@ -1753,7 +1753,7 @@ def get_resource_permissions(self, resource: Resource) -> Permission: def create_permission(self, action_name, resource_name) -> Permission | None: """ - Adds a permission on a resource to the backend. + Add a permission on a resource to the backend. :param action_name: name of the action to add: 'can_add','can_edit' etc... @@ -1781,7 +1781,7 @@ def create_permission(self, action_name, resource_name) -> Permission | None: def delete_permission(self, action_name: str, resource_name: str) -> None: """ - Deletes the permission linking an action->resource pair. + Delete the permission linking an action->resource pair. Doesn't delete the underlying action or resource. @@ -1846,7 +1846,7 @@ def remove_permission_from_role(self, role: Role, permission: Permission) -> Non self.get_session.rollback() def get_oid_identity_url(self, provider_name: str) -> str | None: - """Returns the OIDC identity provider URL.""" + """Return the OIDC identity provider URL.""" for provider in self.openid_providers: if provider.get("name") == provider_name: return provider.get("url") @@ -2091,7 +2091,7 @@ def oauth_user_info_getter( func: Callable[[AirflowSecurityManagerV2, str, dict[str, Any] | None], dict[str, Any]], ): """ - Decorator function to be the OAuth user info getter for all the providers. + Get OAuth user info for all the providers. Receives provider and response return a dict with the information returned from the provider. The returned user info dict should have its keys with the same name as the User Model. 
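As the docstring above says, a getter registered through this decorator receives the provider name and the raw OAuth response and must return a plain dict keyed by User model column names. A hedged sketch of such a mapper for a made-up provider; the "examplecorp" provider name and the response field names are assumptions, not a real provider contract:

.. code-block:: python

    from __future__ import annotations

    from typing import Any


    def example_oauth_user_info(sm: Any, provider: str, response: dict[str, Any] | None) -> dict[str, Any]:
        """Map a raw OAuth response to keys matching the FAB User model columns."""
        if provider != "examplecorp" or not response:
            return {}
        return {
            "username": response.get("preferred_username", ""),
            "email": response.get("email", ""),
            "first_name": response.get("given_name", ""),
            "last_name": response.get("family_name", ""),
            "role_keys": response.get("groups", []),
        }

In a real deployment a function of this shape would be registered through the ``oauth_user_info_getter`` decorator shown in this hunk, so that ``auth_user_oauth`` can create or update the user from the returned dict.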
@@ -2210,7 +2210,7 @@ def get_oauth_user_info(self, provider: str, resp: dict[str, Any]) -> dict[str, @staticmethod def oauth_token_getter(): - """Authentication (OAuth) token getter function.""" + """Get authentication (OAuth) token.""" token = session.get("oauth") log.debug("Token Get: %s", token) return token @@ -2220,7 +2220,7 @@ def check_authorization( perms: Sequence[tuple[str, str]] | None = None, dag_id: str | None = None, ) -> bool: - """Checks that the logged in user has the specified permissions.""" + """Check the logged-in user has the specified permissions.""" if not perms: return True @@ -2254,7 +2254,7 @@ def set_oauth_session(self, provider, oauth_response): def get_oauth_token_key_name(self, provider): """ - Returns the token_key name for the oauth provider. + Return the token_key name for the oauth provider. If none is configured defaults to oauth_token this is configured using OAUTH_PROVIDERS and token_key key. @@ -2275,7 +2275,7 @@ def get_oauth_token_secret_name(self, provider): def auth_user_oauth(self, userinfo): """ - Method for authenticating user with OAuth. + Authenticate user with OAuth. :userinfo: dict with user information (keys are the same as User model columns) @@ -2608,7 +2608,7 @@ def _get_user_permission_resources( return result def _has_access_builtin_roles(self, role, action_name: str, resource_name: str) -> bool: - """Checks permission on builtin role.""" + """Check permission on builtin role.""" perms = self.builtin_roles.get(role.name, []) for _resource_name, _action_name in perms: if re2.match(_resource_name, resource_name) and re2.match(_action_name, action_name): @@ -2647,7 +2647,7 @@ def _get_all_non_dag_permissions(self) -> dict[tuple[str, str], Permission]: """ Get permissions except those that are for specific DAGs. - Returns a dict with a key of (action_name, resource_name) and value of permission + Return a dict with a key of (action_name, resource_name) and value of permission with all permissions except those that are for specific DAGs. """ return { @@ -2689,7 +2689,7 @@ def _get_root_dag_id(self, dag_id: str) -> str: @staticmethod def _cli_safe_flash(text: str, level: str) -> None: - """Shows a flash in a web context or prints a message if not.""" + """Show a flash in a web context or prints a message if not.""" if has_request_context(): flash(Markup(text), level) else: diff --git a/pyproject.toml b/pyproject.toml index 1856aac2a4f34c..cad94c51fec142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1373,9 +1373,6 @@ combine-as-imports = true "airflow/providers/common/io/xcom/backend.py" = ["D401"] "airflow/providers/databricks/hooks/databricks.py" = ["D401"] "airflow/providers/databricks/operators/databricks.py" = ["D401"] -"airflow/providers/fab/auth_manager/decorators/auth.py" = ["D401"] -"airflow/providers/fab/auth_manager/fab_auth_manager.py" = ["D401"] -"airflow/providers/fab/auth_manager/security_manager/override.py" = ["D401"] "airflow/providers/google/cloud/hooks/automl.py" = ["D401"] "airflow/providers/google/cloud/hooks/bigquery.py" = ["D401"] "airflow/providers/google/cloud/hooks/bigquery_dts.py" = ["D401"]