Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TypeError when defining enumeration types #525

Merged
merged 2 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions comtypes/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ def test_no_replacing_Patch_namespace(self):
# NOTE: `WindowsInstaller`, which has `Patch` definition in dll.
comtypes.client.GetModule("msi.dll")

def test_the_name_of_the_enum_member_and_the_coclass_are_duplicated(self):
# NOTE: In `MSHTML`, the name `htmlInputImage` is used both as a member of
# the `_htmlInput` enum type and as a CoClass that has `IHTMLElement` and
# others as interfaces.
# If a CoClass is assigned where an integer should be assigned, such as in
# the definition of an enumeration, the generation of the module will fail.
# See also https://github.com/enthought/comtypes/issues/524
with contextlib.redirect_stdout(None): # supress warnings
comtypes.client.GetModule("mshtml.tlb")

def test_abstracted_wrapper_module_in_friendly_module(self):
mod = comtypes.client.GetModule("scrrun.dll")
self.assertTrue(hasattr(mod, "__wrapper_module__"))
Expand Down
24 changes: 12 additions & 12 deletions comtypes/tools/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None:
tp_name = self._to_type_name(tp)
print("%s = %d" % (tp_name, value), file=self.stream)
if tp.enumeration.name:
self.enums.add(tp.enumeration.name, tp_name)
self.enums.add(tp.enumeration.name, tp_name, value)
self.names.add(tp_name)

def Enumeration(self, tp: typedesc.Enumeration) -> None:
Expand Down Expand Up @@ -1546,28 +1546,28 @@ def getvalue(self):

class EnumerationNamespaces(object):
def __init__(self):
self.data: Dict[str, List[str]] = {}
self.data: Dict[str, List[Tuple[str, int]]] = {}

def add(self, enum_name: str, member_name: str) -> None:
def add(self, enum_name: str, member_name: str, value: int) -> None:
"""Adds a namespace will be enumeration and its member.

Examples:
>>> enums = EnumerationNamespaces()
>>> enums.add('Foo', 'ham')
>>> enums.add('Foo', 'spam')
>>> enums.add('Bar', 'bacon')
>>> enums.add('Foo', 'ham', 1)
>>> enums.add('Foo', 'spam', 2)
>>> enums.add('Bar', 'bacon', 3)
>>> assert 'Foo' in enums
>>> assert 'Baz' not in enums
>>> print(enums.getvalue()) # <BLANKLINE> is necessary for doctest
class Foo(IntFlag):
ham = __wrapper_module__.ham
spam = __wrapper_module__.spam
ham = 1
spam = 2
<BLANKLINE>
<BLANKLINE>
class Bar(IntFlag):
bacon = __wrapper_module__.bacon
bacon = 3
"""
self.data.setdefault(enum_name, []).append(member_name)
self.data.setdefault(enum_name, []).append((member_name, value))

def __contains__(self, item: str) -> bool:
return item in self.data
Expand All @@ -1580,7 +1580,7 @@ def getvalue(self) -> str:
for enum_name, enum_members in self.data.items():
lines = []
lines.append(f"class {enum_name}(IntFlag):")
for member_name in enum_members:
lines.append(f" {member_name} = __wrapper_module__.{member_name}")
for member_name, value in enum_members:
lines.append(f" {member_name} = {value}")
blocks.append("\n".join(lines))
return "\n\n\n".join(blocks)
Loading