Skip to content

Commit

Permalink
Fix TypeError when defining enumeration types (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
junkmd authored Apr 10, 2024
1 parent 7fa88e1 commit e3da62b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
10 changes: 10 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ def test_no_replacing_Patch_namespace(self):
# NOTE: `WindowsInstaller`, which has `Patch` definition in dll.
comtypes.client.GetModule("msi.dll")

def test_the_name_of_the_enum_member_and_the_coclass_are_duplicated(self):
# NOTE: In `MSHTML`, the name `htmlInputImage` is used both as a member of
# the `_htmlInput` enum type and as a CoClass that has `IHTMLElement` and
# others as interfaces.
# If a CoClass is assigned where an integer should be assigned, such as in
# the definition of an enumeration, the generation of the module will fail.
# See also https://github.com/enthought/comtypes/issues/524
with contextlib.redirect_stdout(None): # supress warnings
comtypes.client.GetModule("mshtml.tlb")

def test_abstracted_wrapper_module_in_friendly_module(self):
mod = comtypes.client.GetModule("scrrun.dll")
self.assertTrue(hasattr(mod, "__wrapper_module__"))
Expand Down
24 changes: 12 additions & 12 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None:
tp_name = self._to_type_name(tp)
print("%s = %d" % (tp_name, value), file=self.stream)
if tp.enumeration.name:
self.enums.add(tp.enumeration.name, tp_name)
self.enums.add(tp.enumeration.name, tp_name, value)
self.names.add(tp_name)

def Enumeration(self, tp: typedesc.Enumeration) -> None:
Expand Down Expand Up @@ -1546,28 +1546,28 @@ def getvalue(self):

class EnumerationNamespaces(object):
def __init__(self):
self.data: Dict[str, List[str]] = {}
self.data: Dict[str, List[Tuple[str, int]]] = {}

def add(self, enum_name: str, member_name: str) -> None:
def add(self, enum_name: str, member_name: str, value: int) -> None:
"""Adds a namespace will be enumeration and its member.
Examples:
>>> enums = EnumerationNamespaces()
>>> enums.add('Foo', 'ham')
>>> enums.add('Foo', 'spam')
>>> enums.add('Bar', 'bacon')
>>> enums.add('Foo', 'ham', 1)
>>> enums.add('Foo', 'spam', 2)
>>> enums.add('Bar', 'bacon', 3)
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.getvalue()) # <BLANKLINE> is necessary for doctest
class Foo(IntFlag):
ham = __wrapper_module__.ham
spam = __wrapper_module__.spam
ham = 1
spam = 2
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = __wrapper_module__.bacon
bacon = 3
"""
self.data.setdefault(enum_name, []).append(member_name)
self.data.setdefault(enum_name, []).append((member_name, value))

def __contains__(self, item: str) -> bool:
return item in self.data
Expand All @@ -1580,7 +1580,7 @@ def getvalue(self) -> str:
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"class {enum_name}(IntFlag):")
for member_name in enum_members:
lines.append(f" {member_name} = __wrapper_module__.{member_name}")
for member_name, value in enum_members:
lines.append(f" {member_name} = {value}")
blocks.append("\n".join(lines))
return "\n\n\n".join(blocks)

0 comments on commit e3da62b

Please sign in to comment.