diff --git a/comtypes/client/_generate.py b/comtypes/client/_generate.py index f4f5fcfd..f6827bec 100644 --- a/comtypes/client/_generate.py +++ b/comtypes/client/_generate.py @@ -3,6 +3,7 @@ import sys import comtypes.client import comtypes.tools.codegenerator +import imp import logging logger = logging.getLogger(__name__) @@ -23,7 +24,11 @@ def _my_import(fullname): if comtypes.client.gen_dir \ and comtypes.client.gen_dir not in comtypes.gen.__path__: comtypes.gen.__path__.append(comtypes.client.gen_dir) - return __import__(fullname, globals(), locals(), ['DUMMY']) + + mod = imp.reload(eval(fullname)) if fullname in sys.modules \ + else __import__(fullname, globals(), locals(), ['DUMMY']) + mod._comtypes_validate_file() + return mod def _name_module(tlib): # Determine the name of a typelib wrapper module. @@ -152,8 +157,10 @@ def _CreateWrapper(tlib, pathname=None): # helper which creates and imports the real typelib wrapper module. fullname = _name_module(tlib) try: - return sys.modules[fullname] - except KeyError: + mod = sys.modules[fullname] + mod._comtypes_validate_file() + return mod + except (KeyError, AttributeError): pass modname = fullname.split(".")[-1] @@ -165,11 +172,9 @@ def _CreateWrapper(tlib, pathname=None): # generate the module since it doesn't exist or is out of date from comtypes.tools.tlbparser import generate_module - if comtypes.client.gen_dir is None: - import cStringIO - ofi = cStringIO.StringIO() - else: - ofi = open(os.path.join(comtypes.client.gen_dir, modname + ".py"), "w") + import cStringIO + ofi = cStringIO.StringIO() + # XXX use logging! if __verbose__: print "# Generating comtypes.gen.%s" % modname @@ -184,7 +189,8 @@ def _CreateWrapper(tlib, pathname=None): sys.modules[fullname] = mod setattr(comtypes.gen, modname, mod) else: - ofi.close() + with open(os.path.join(comtypes.client.gen_dir, modname + ".py"), "w") as fd: + fd.write(ofi.getvalue()) mod = _my_import(fullname) return mod diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index be0f0566..97436c78 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -258,6 +258,7 @@ def generate_code(self, items, filename=None): for line in wrapper.wrap(text): print >> self.output, line print >> self.output, "from comtypes import _check_version; _check_version(%r)" % version + print >> self.output, "def _comtypes_validate_file(): pass" return loops def type_name(self, t, generate=True): @@ -627,7 +628,7 @@ def External(self, ext): modname = comtypes.client._generate._name_module(ext.tlib) ext.name = "%s.%s" % (modname, ext.symbol_name) self._externals[libdesc] = modname - print >> self.imports, "import", modname + print >> self.imports, "import %s; %s._comtypes_validate_file()" % (modname, modname) comtypes.client.GetModule(ext.tlib) def Constant(self, tp):