diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index a85964af4f64ac..c3d8b671fbb089 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]:
         # XCom.get_many, # Not supported because it returns query
         XCom.clear,
         XCom.set,
-        Variable.set,
-        Variable.update,
-        Variable.delete,
+        Variable._set,
+        Variable._update,
+        Variable._delete,
         DAG.fetch_callback,
         DAG.fetch_dagrun,
         DagRun.fetch_task_instances,
@@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
         response = json.dumps(output_json) if output_json is not None else None
         log.debug("Sending response: %s", response)
         return Response(response=response, headers={"Content-Type": "application/json"})
-    except AirflowException as e:  # In case of AirflowException transport the exception class back to caller
+    # In case of AirflowException or other selective known types, transport the exception class back to caller
+    except (KeyError, AttributeError, AirflowException) as e:
         exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
         response = json.dumps(exception_json)
         log.debug("Sending exception response: %s", response)
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index fc0945b3c0fe0b..8838377877becc 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -159,7 +159,7 @@ def wrapper(*args, **kwargs):
         if result is None or result == b"":
             return None
         result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True)
-        if isinstance(result, AirflowException):
+        if isinstance(result, (KeyError, AttributeError, AirflowException)):
             raise result
         return result
 
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 63b71303bc8034..563cac46e8c848 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -154,7 +154,6 @@ def get(
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def set(
         key: str,
         value: Any,
@@ -167,6 +166,35 @@ def set(
 
         This operation overwrites an existing variable.
 
+        :param key: Variable Key
+        :param value: Value to set for the Variable
+        :param description: Description of the Variable
+        :param serialize_json: Serialize the value to a JSON string
+        :param session: Session
+        """
+        Variable._set(
+            key=key, value=value, description=description, serialize_json=serialize_json, session=session
+        )
+        # invalidate key in cache for faster propagation
+        # we cannot save the value set because it's possible that it's shadowed by a custom backend
+        # (see call to check_for_write_conflict above)
+        SecretCache.invalidate_variable(key)
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _set(
+        key: str,
+        value: Any,
+        description: str | None = None,
+        serialize_json: bool = False,
+        session: Session = None,
+    ) -> None:
+        """
+        Set a value for an Airflow Variable with a given Key.
+
+        This operation overwrites an existing variable.
+
         :param key: Variable Key
         :param value: Value to set for the Variable
         :param description: Description of the Variable
@@ -190,7 +218,6 @@ def set(
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def update(
         key: str,
         value: Any,
@@ -200,6 +227,27 @@ def update(
         """
         Update a given Airflow Variable with the Provided value.
 
+        :param key: Variable Key
+        :param value: Value to set for the Variable
+        :param serialize_json: Serialize the value to a JSON string
+        :param session: Session
+        """
+        Variable._update(key=key, value=value, serialize_json=serialize_json, session=session)
+        # We need to invalidate the cache for internal API cases on the client side
+        SecretCache.invalidate_variable(key)
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _update(
+        key: str,
+        value: Any,
+        serialize_json: bool = False,
+        session: Session = None,
+    ) -> None:
+        """
+        Update a given Airflow Variable with the Provided value.
+
         :param key: Variable Key
         :param value: Value to set for the Variable
         :param serialize_json: Serialize the value to a JSON string
@@ -219,11 +267,23 @@ def update(
 
     @staticmethod
     @provide_session
-    @internal_api_call
     def delete(key: str, session: Session = None) -> int:
         """
         Delete an Airflow Variable for a given key.
 
+        :param key: Variable Keys
+        """
+        rows = Variable._delete(key=key, session=session)
+        SecretCache.invalidate_variable(key)
+        return rows
+
+    @staticmethod
+    @provide_session
+    @internal_api_call
+    def _delete(key: str, session: Session = None) -> int:
+        """
+        Delete an Airflow Variable for a given key.
+
         :param key: Variable Keys
         """
         rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount
diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index a5bd5e3646e836..f216ce73161038 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum):
     RELATIVEDELTA = "relativedelta"
     BASE_TRIGGER = "base_trigger"
     AIRFLOW_EXC_SER = "airflow_exc_ser"
+    BASE_EXC_SER = "base_exc_ser"
     DICT = "dict"
     SET = "set"
     TUPLE = "tuple"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6d0bbd4e23fd82..84ad5679182bb2 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -692,6 +692,15 @@ def serialize(
                 ),
                 type_=DAT.AIRFLOW_EXC_SER,
             )
+        elif isinstance(var, (KeyError, AttributeError)):
+            return cls._encode(
+                cls.serialize(
+                    {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}},
+                    use_pydantic_models=use_pydantic_models,
+                    strict=strict,
+                ),
+                type_=DAT.BASE_EXC_SER,
+            )
         elif isinstance(var, BaseTrigger):
             return cls._encode(
                 cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict),
@@ -834,13 +843,16 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
             return decode_timezone(var)
         elif type_ == DAT.RELATIVEDELTA:
             return decode_relativedelta(var)
-        elif type_ == DAT.AIRFLOW_EXC_SER:
+        elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
             deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
             exc_cls_name = deser["exc_cls_name"]
             args = deser["args"]
             kwargs = deser["kwargs"]
             del deser
-            exc_cls = import_string(exc_cls_name)
+            if type_ == DAT.AIRFLOW_EXC_SER:
+                exc_cls = import_string(exc_cls_name)
+            else:
+                exc_cls = import_string(f"builtins.{exc_cls_name}")
             return exc_cls(*args, **kwargs)
         elif type_ == DAT.BASE_TRIGGER:
             tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index ba5ac7c7b54ad7..cbd38e9b390c84 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1444,7 +1444,10 @@ def test_check_task_dependencies(
     # Parameterized tests to check for the correct firing
     # of the trigger_rule under various circumstances of mapped task
     # Numeric fields are in order:
-    # successes, skipped, failed, upstream_failed, done,removed
+    # successes, skipped, failed, upstream_failed, done,remove
+    # Does not work for database isolation mode because there is local test monkeypatching of upstream_failed
+    # That never gets propagated to internal_api
+    @pytest.mark.skip_if_database_isolation_mode
     @pytest.mark.parametrize(
         "trigger_rule, upstream_states, flag_upstream_failed, expect_state, expect_completed",
         [
@@ -1540,8 +1543,10 @@ def do_something_else(i):
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
         ti = dr.get_task_instance("do_something_else", session=session)
         ti.map_index = 0
+        base_task = ti.task
+
         for map_index in range(1, 5):
-            ti = TaskInstance(dr.task_instances[-1].task, run_id=dr.run_id, map_index=map_index)
+            ti = TaskInstance(base_task, run_id=dr.run_id, map_index=map_index)
             session.add(ti)
             ti.dag_run = dr
             session.flush()
diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py
index e3d5c023a24ab7..6fb6fa15f214c0 100644
--- a/tests/models/test_variable.py
+++ b/tests/models/test_variable.py
@@ -47,6 +47,7 @@ def setup_test_cases(self):
         db.clear_db_variables()
         crypto._fernet = None
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
     def test_variable_no_encryption(self, session):
         """
@@ -60,6 +61,7 @@ def test_variable_no_encryption(self, session):
         # should mask anything. That logic is tested in test_secrets_masker.py
         self.mask_secret.assert_called_once_with("value", "key")
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
     def test_variable_with_encryption(self, session):
         """
@@ -70,6 +72,7 @@ def test_variable_with_encryption(self, session):
         assert test_var.is_encrypted
         assert test_var.val == "value"
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other fernet
     @pytest.mark.parametrize("test_value", ["value", ""])
     def test_var_with_encryption_rotate_fernet_key(self, test_value, session):
         """
@@ -152,6 +155,7 @@ def test_variable_update(self, session):
         Variable.update(key="test_key", value="value2", session=session)
         assert "value2" == Variable.get("test_key")
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, API server has other ENV
     def test_variable_update_fails_on_non_metastore_variable(self, session):
         with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
             with pytest.raises(AttributeError):
@@ -281,6 +285,7 @@ def test_caching_caches(self, mock_ensure_secrets: mock.Mock):
         mock_backend.get_variable.assert_called_once()  # second call was not made because of cache
         assert first == second
 
+    @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode, internal API has other env
     def test_cache_invalidation_on_set(self, session):
         with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
             a = Variable.get("key")  # value is saved in cache
@@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
         val=variable_value,
     )
     session.add(var)
-    session.flush()
+    session.commit()
     # Make sure we re-load it, not just get the cached object back
     session.expunge(var)
     _secrets_masker().patterns = set()
@@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
         for expected_masked_value in expected_masked_values:
             assert expected_masked_value in _secrets_masker().patterns
     finally:
-        session.rollback()
         db.clear_db_variables()
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 07ed3ab321d3aa..0ca13b343f2302 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -809,6 +809,7 @@ def test_catch_invalid_allowed_states(self):
                 dag=self.dag,
             )
 
+    @pytest.mark.skip_if_database_isolation_mode  # Test is broken in db isolation mode
     def test_external_task_sensor_waits_for_task_check_existence(self):
         op = ExternalTaskSensor(
             task_id="test_external_task_sensor_check",
@@ -821,6 +822,7 @@ def test_external_task_sensor_waits_for_task_check_existence(self):
         with pytest.raises(AirflowException):
             op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
+    @pytest.mark.skip_if_database_isolation_mode  # Test is broken in db isolation mode
     def test_external_task_sensor_waits_for_dag_check_existence(self):
         op = ExternalTaskSensor(
             task_id="test_external_task_sensor_check",