Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Commit

Permalink
Fix typing issues of StringSetFlag (#107)
Browse files Browse the repository at this point in the history
* Selectively ignore override errors as this is a custom class
* Use a separate interface declaration file (.pyi) by refactoring out
  StringSetFlag class from the utils module to overcome limitation of
  mypy semantic analysis on "__ror__ = __or__" in subclasses of
  Enum/Flag classes.
  • Loading branch information
achimnol committed Jan 5, 2022
1 parent 9c4ee94 commit ec96e50
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 56 deletions.
1 change: 1 addition & 0 deletions changes/107.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix typing issues of `StringSetFlag` by refactoring it using a separate interface definition file
57 changes: 57 additions & 0 deletions src/ai/backend/common/enum_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import enum

__all__ = (
'StringSetFlag',
)


class StringSetFlag(enum.Flag):

def __eq__(self, other):
return self.value == other

def __hash__(self):
return hash(self.value)

def __or__(self, other):
if isinstance(other, type(self)):
other = other.value
if not isinstance(other, (set, frozenset)):
other = set((other,))
return set((self.value,)) | other

__ror__ = __or__

def __and__(self, other):
if isinstance(other, (set, frozenset)):
return self.value in other
if isinstance(other, str):
return self.value == other
raise TypeError

__rand__ = __and__

def __xor__(self, other):
if isinstance(other, (set, frozenset)):
return set((self.value,)) ^ other
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __rxor__(self, other):
if isinstance(other, (set, frozenset)):
return other ^ set((self.value,))
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __str__(self):
return self.value
22 changes: 22 additions & 0 deletions src/ai/backend/common/enum_extension.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import enum


class StringSetFlag(enum.Flag):
def __eq__(self, other: object) -> bool: ...
def __hash__(self) -> int: ...
def __or__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> set[str]: ...
def __and__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> bool: ...
def __xor__( # type: ignore[override]
self,
other: StringSetFlag | str | set[str] | frozenset[str],
) -> set[str]: ...
def __ror__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
def __rand__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> bool: ...
def __rxor__(self, other: StringSetFlag | str | set[str] | frozenset[str]) -> set[str]: ...
def __str__(self) -> str: ...
52 changes: 1 addition & 51 deletions src/ai/backend/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import base64
from collections import OrderedDict
from datetime import timedelta
import enum
from itertools import chain
import numbers
import random
Expand Down Expand Up @@ -34,6 +33,7 @@
current_loop,
run_through,
)
from .enum_extension import StringSetFlag # for legacy imports # noqa
from .files import AsyncFileWriter # for legacy imports # noqa
from .networking import ( # for legacy imports # noqa
curl,
Expand Down Expand Up @@ -198,56 +198,6 @@ def str_to_timedelta(tstr: str) -> timedelta:
return timedelta(**params) # type: ignore


class StringSetFlag(enum.Flag):

def __eq__(self, other):
return self.value == other

def __hash__(self):
return hash(self.value)

def __or__(self, other):
if isinstance(other, type(self)):
other = other.value
if not isinstance(other, (set, frozenset)):
other = set((other,))
return set((self.value,)) | other

__ror__ = __or__

def __and__(self, other):
if isinstance(other, (set, frozenset)):
return self.value in other
if isinstance(other, str):
return self.value == other
raise TypeError

__rand__ = __and__

def __xor__(self, other):
if isinstance(other, (set, frozenset)):
return set((self.value,)) ^ other
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __rxor__(self, other):
if isinstance(other, (set, frozenset)):
return other ^ set((self.value,))
if isinstance(other, str):
if other == self.value:
return set()
else:
return other
raise TypeError

def __str__(self):
return self.value


class FstabEntry:
"""
Entry class represents a non-comment line on the `fstab` file.
Expand Down
8 changes: 3 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import pytest

from ai.backend.common.asyncio import AsyncBarrier, run_through
from ai.backend.common.enum_extension import StringSetFlag
from ai.backend.common.files import AsyncFileWriter
from ai.backend.common.networking import curl
from ai.backend.common.utils import (
odict, dict2kvlist, nmget,
generate_uuid, get_random_seq,
readable_size_to_bytes,
str_to_timedelta,
StringSetFlag,
)
from ai.backend.common.testutils import (
mock_corofunc, mock_awaitable, AsyncContextManagerMock,
Expand Down Expand Up @@ -156,9 +156,7 @@ async def test_curl_returns_default_value_if_not_success(mocker) -> None:

def test_string_set_flag() -> None:

# FIXME: Remove "type: ignore" when mypy gets released with
# python/mypy#11579.
class MyFlags(StringSetFlag): # type: ignore
class MyFlags(StringSetFlag):
A = 'a'
B = 'b'

Expand All @@ -182,7 +180,7 @@ class MyFlags(StringSetFlag): # type: ignore
assert {'b'} == MyFlags.A ^ {'a', 'b'}
assert {'a', 'b', 'c'} == MyFlags.A ^ {'b', 'c'}
with pytest.raises(TypeError):
123 & MyFlags.A
123 & MyFlags.A # type: ignore[operator]

assert {'a', 'c'} & MyFlags.A
assert not {'a', 'c'} & MyFlags.B
Expand Down

0 comments on commit ec96e50

Please sign in to comment.