diff --git a/delvewheel/_wheel_repair.py b/delvewheel/_wheel_repair.py index 1668ef1..095d4d4 100644 --- a/delvewheel/_wheel_repair.py +++ b/delvewheel/_wheel_repair.py @@ -288,8 +288,8 @@ def _patch_py_contents(self, at_start: bool, libs_dir: str, load_order_filename: def _patch_py_file(self, py_path: str, libs_dir: str, load_order_filename: typing.Optional[str], depth: int) -> None: """Given the path to a .py file, create or patch the file so that vendored DLLs can be loaded at runtime. The patch is placed at the - topmost location after the docstring (if any) and any - "from __future__ import" statements. + topmost location after the shebang (if any), docstring or header + comments (if any), and any "from __future__ import" statements. py_path is the path to the .py file to create or patch libs_dir is the name of the directory where DLLs are stored. @@ -334,16 +334,7 @@ def _patch_py_file(self, py_path: str, libs_dir: str, load_order_filename: typin if remainder: file.write('\n') file.write(remainder) - elif docstring is None: - # prepend patch - patch_py_contents = self._patch_py_contents(True, libs_dir, load_order_filename, depth) - with open(py_path, 'w', newline=newline) as file: - file.write(patch_py_contents) - remainder = py_contents.lstrip() - if remainder: - file.write('\n') - file.write(remainder) - else: + elif docstring is not None: # place patch just after docstring patch_py_contents = self._patch_py_contents(False, libs_dir, load_order_filename, depth) if len(children) == 0 or not isinstance(children[0], ast.Expr) or ast.literal_eval(children[0].value) != docstring: @@ -391,6 +382,53 @@ def _patch_py_file(self, py_path: str, libs_dir: str, load_order_filename: typin file.write(patch_py_contents) file.write('\n') file.write(py_contents[docstring_end_index:].lstrip()) + else: + py_contents_lines = py_contents.splitlines() + start = 0 + if py_contents_lines and py_contents_lines[0].startswith('#!'): + start = 1 + while start < len(py_contents_lines) and py_contents_lines[start].strip() in ('', '#'): + start += 1 + if start < len(py_contents_lines) and py_contents_lines[start][:1] == '#': + # insert patch after header comments + end = start + 1 + while end < len(py_contents_lines) and py_contents_lines[end][:1] == '#': + end += 1 + patch_py_contents = self._patch_py_contents(False, libs_dir, load_order_filename, depth) + with open(py_path, 'w', newline=newline) as file: + file.write('\n'.join(py_contents_lines[:end]).rstrip()) + file.write('\n\n\n') + file.write(patch_py_contents) + remainder = '\n'.join(py_contents_lines[end:]).lstrip() + if remainder: + file.write('\n') + file.write(remainder) + if not remainder.endswith('\n'): + file.write('\n') + elif py_contents_lines and py_contents_lines[0].startswith('#!'): + # insert patch after shebang + patch_py_contents = self._patch_py_contents(False, libs_dir, load_order_filename, depth) + with open(py_path, 'w', newline=newline) as file: + file.write(py_contents_lines[0].rstrip()) + file.write('\n\n\n') + file.write(patch_py_contents) + remainder = '\n'.join(py_contents_lines[1:]).lstrip() + if remainder: + file.write('\n') + file.write(remainder) + if not remainder.endswith('\n'): + file.write('\n') + else: + # prepend patch + patch_py_contents = self._patch_py_contents(True, libs_dir, load_order_filename, depth) + with open(py_path, 'w', newline=newline) as file: + file.write(patch_py_contents) + remainder = py_contents.lstrip() + if remainder: + file.write('\n') + file.write(remainder) + if not remainder.endswith('\n'): + file.write('\n') # verify that the file can be parsed properly with open(py_path) as file: diff --git a/tests/run_tests.py b/tests/run_tests.py index 5832278..6e9ad7a 100644 --- a/tests/run_tests.py +++ b/tests/run_tests.py @@ -525,8 +525,15 @@ def test_init_patch(self): 5. 1 future import 6. multiple future imports 7. docstring and multiple future imports - 8. escaped quotes at docstring end""" - cases = 9 + 8. escaped quotes at docstring end + 9. comment after docstring + 10. comment without docstring + 11. blank line, multiline comment, no docstring + 12. comment, no docstring, code + 13. shebang + 14. shebang, comment, code + 15. shebang, split comments, code""" + cases = 16 check_call(['delvewheel', 'repair', '--add-path', 'simpleext/x64', '--no-mangle-all', 'simpleext/simpleext-0.0.1-0init-cp310-cp310-win_amd64.whl']) self.assertTrue(import_simpleext_successful('0init', [f'simpleext{x}.simpleext' for x in range(cases)])) diff --git a/tests/simpleext/simpleext-0.0.1-0init-cp310-cp310-win_amd64.whl b/tests/simpleext/simpleext-0.0.1-0init-cp310-cp310-win_amd64.whl index b70e4c8..301e024 100644 Binary files a/tests/simpleext/simpleext-0.0.1-0init-cp310-cp310-win_amd64.whl and b/tests/simpleext/simpleext-0.0.1-0init-cp310-cp310-win_amd64.whl differ