From f33b68656a5828e713fecf7a144ba44a36e1f8a2 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 21 Feb 2024 17:04:36 +0100 Subject: [PATCH] Simplify checks for package versions (#37585) Replaces more complex package version checks with one-liners. --- airflow/utils/pydantic.py | 12 +++---- airflow/utils/sqlalchemy.py | 11 ++----- airflow/utils/timezone.py | 4 ++- .../serializers/test_serializers.py | 4 ++- tests/utils/test_sqlalchemy.py | 32 ++++--------------- 5 files changed, 19 insertions(+), 44 deletions(-) diff --git a/airflow/utils/pydantic.py b/airflow/utils/pydantic.py index 0ec184672c2c36..13ab7911663f2e 100644 --- a/airflow/utils/pydantic.py +++ b/airflow/utils/pydantic.py @@ -24,18 +24,14 @@ from __future__ import annotations +from importlib import metadata -def is_pydantic_2_installed() -> bool: - import sys +from packaging import version - from packaging.version import Version - if sys.version_info >= (3, 9): - from importlib.metadata import distribution - else: - from importlib_metadata import distribution +def is_pydantic_2_installed() -> bool: try: - return Version(distribution("pydantic").version) >= Version("2.0.0") + return version.parse(metadata.version("pydantic")).major == 2 except ImportError: return False diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 3c72d5ce8952d4..d803d8244be87f 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -22,11 +22,11 @@ import datetime import json import logging -from importlib.metadata import version +from importlib import metadata from typing import TYPE_CHECKING, Any, Generator, Iterable, overload from dateutil import relativedelta -from packaging.version import Version, parse as parse_version +from packaging import version from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_ from sqlalchemy.dialects import mysql from sqlalchemy.types import JSON, Text, TypeDecorator @@ -555,10 +555,5 @@ def get_orm_mapper(): return sqlalchemy.orm.mapper if is_sqlalchemy_v1() else sqlalchemy.orm.Mapper -def _get_lib_major_version(lib_name: str) -> int: - ver: Version = parse_version(version(lib_name)) - return ver.major - - def is_sqlalchemy_v1() -> bool: - return _get_lib_major_version("sqlalchemy") == 1 + return version.parse(metadata.version("sqlalchemy")).major == 1 diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py index 966c4bbdc1a15b..7fc4bd6ac12ee4 100644 --- a/airflow/utils/timezone.py +++ b/airflow/utils/timezone.py @@ -18,10 +18,12 @@ from __future__ import annotations import datetime as dt +from importlib import metadata from typing import TYPE_CHECKING, overload import pendulum from dateutil.relativedelta import relativedelta +from packaging import version from pendulum.datetime import DateTime if TYPE_CHECKING: @@ -29,7 +31,7 @@ from airflow.typing_compat import Literal -_PENDULUM3 = pendulum.__version__.startswith("3") +_PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3 # UTC Timezone as a tzinfo instance. Actual value depends on pendulum version: # - Timezone("UTC") in pendulum 3 # - FixedTimezone(0, "UTC") in pendulum 2 diff --git a/tests/serialization/serializers/test_serializers.py b/tests/serialization/serializers/test_serializers.py index 63627b89d18d16..cb0b03b324900c 100644 --- a/tests/serialization/serializers/test_serializers.py +++ b/tests/serialization/serializers/test_serializers.py @@ -18,6 +18,7 @@ import datetime import decimal +from importlib import metadata from unittest.mock import patch import numpy as np @@ -26,6 +27,7 @@ import pytest from dateutil.tz import tzutc from deltalake import DeltaTable +from packaging import version from pendulum import DateTime from pendulum.tz.timezone import FixedTimezone, Timezone @@ -38,7 +40,7 @@ else: from backports.zoneinfo import ZoneInfo -PENDULUM3 = pendulum.__version__.startswith("3") +PENDULUM3 = version.parse(metadata.version("pendulum")).major == 3 class TestSerializers: diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index c4af2a4084d7cc..1136a106cb9e19 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -35,7 +35,6 @@ from airflow.settings import Session from airflow.utils.sqlalchemy import ( ExecutorConfigType, - _get_lib_major_version, ensure_pod_is_valid_after_unpickling, is_sqlalchemy_v1, prohibit_commit, @@ -317,32 +316,13 @@ def test_result_processor_bad_pickled_obj(self): @pytest.mark.parametrize( - "version_string, expected_major_version", + "mock_version, expected_result", [ - ("1.4.22", 1), # Test 1: "1.4.22" parsed as 1 - ("10.4.22", 10), # Test 2: "10.4.22" not parsed as 1 - ("invalid", None), # Test 3: Invalid version string - ("3.x.x", None), # Test 4: Malformed version + ("1.0.0", True), # Test 1: v1 identified as v1 + ("2.3.4", False), # Test 2: v2 not identified as v1 ], ) -def test_get_lib_major_version(version_string, expected_major_version): - with mock.patch("airflow.utils.sqlalchemy.version") as mock_version: - mock_version.return_value = version_string - if expected_major_version is not None: - assert _get_lib_major_version("dummy_module") == expected_major_version - else: - with pytest.raises(ValueError): - _get_lib_major_version("dummy_module") - - -@pytest.mark.parametrize( - "major_version, expected_result", - [ - (1, True), # Test 1: v1 identified as v1 - (2, False), # Test 2: v2 not identified as v1 - ], -) -def test_is_sqlalchemy_v1(major_version, expected_result): - with mock.patch("airflow.utils.sqlalchemy._get_lib_major_version") as mock_get_major_version: - mock_get_major_version.return_value = major_version +def test_is_sqlalchemy_v1(mock_version, expected_result): + with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata: + mock_metadata.version.return_value = mock_version assert is_sqlalchemy_v1() == expected_result