Skip to content

Commit

Permalink
script: fix edge case with Major.minor only
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 18, 2024
1 parent 1d0e380 commit bb536fd
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions scripts/adjust-torch-versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ def _determine_torchaudio(torch_version: str) -> str:
_version_exceptions = {
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
if torch_version in _version_exceptions:
return _version_exceptions[torch_version]
ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
ta_ver_array = [ver_major, ver_minor, ver_bugfix]
if ver_major == 1:
ta_ver_array[0] = 0
Expand All @@ -52,11 +50,9 @@ def _determine_torchtext(torch_version: str) -> str:
"2.0.0": "0.15.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
if torch_version in _version_exceptions:
return _version_exceptions[torch_version]
ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
tt_ver_array = [0, 0, 0]
if ver_major == 1:
tt_ver_array[1] = ver_minor + 1
Expand Down Expand Up @@ -91,11 +87,9 @@ def _determine_torchvision(torch_version: str) -> str:
"1.10.0": "0.11.1",
"1.8.2": "0.9.1",
}
# drop all except semantic version
torch_ver = re.search(r"([\.\d]+)", torch_version).groups()[0]
if torch_ver in _version_exceptions:
return _version_exceptions[torch_ver]
ver_major, ver_minor, ver_bugfix = map(int, torch_ver.split("."))
if torch_version in _version_exceptions:
return _version_exceptions[torch_version]
ver_major, ver_minor, ver_bugfix = map(int, torch_version.split("."))
tv_ver_array = [0, 0, 0]
if ver_major == 1:
tv_ver_array[1] = ver_minor + 1
Expand All @@ -111,7 +105,12 @@ def find_latest(ver: str) -> Dict[str, str]:
"""Find the latest version.
>>> from pprint import pprint
>>> pprint(find_latest("2.1.0"))
>>> pprint(find_latest("2.4.1"))
{'torch': '2.4.1',
'torchaudio': '2.4.1',
'torchtext': '0.18.0',
'torchvision': '0.19.1'}
>>> pprint(find_latest("2.1"))
{'torch': '2.1.0',
'torchaudio': '2.1.0',
'torchtext': '0.16.0',
Expand All @@ -122,6 +121,8 @@ def find_latest(ver: str) -> Dict[str, str]:
ver = re.search(r"([\.\d]+)", ver).groups()[0]
# in case there remaining dot at the end - e.g "1.9.0.dev20210504"
ver = ver[:-1] if ver[-1] == "." else ver
if not re.match(r"\d+\.\d+\.\d+", ver):
ver += ".0" # add missing bugfix
logging.debug(f"finding ecosystem versions for: {ver}")

# find first match
Expand Down

0 comments on commit bb536fd

Please sign in to comment.