Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional deps #61

Merged
merged 13 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ The main supports the following additional arguments:
- `--track_import_stack`: Store the stack trace of imports belonging to the tracked module
- `--detect_transitive`: Mark each dependency as either "direct" (imported directly) or "transitive" (inherited from a direct import)
- `--full_depth`: Track all dependencies, including transitive dependencies of direct third-party deps
- `--show_optional`: Show whether each dependency is optional or required

## Integrating `import_tracker` into a project

Expand Down
8 changes: 8 additions & 0 deletions import_tracker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ def main():
default=False,
help="Detect whether each dependency is 'direct' or 'transitive'",
)
parser.add_argument(
"--show_optional",
"-o",
action="store_true",
default=False,
help="Show whether each dependency is optional or required",
)
parser.add_argument(
"--log_level",
"-l",
Expand Down Expand Up @@ -109,6 +116,7 @@ def main():
track_import_stack=args.track_import_stack,
full_depth=args.full_depth,
detect_transitive=args.detect_transitive,
show_optional=args.show_optional,
),
indent=args.indent,
)
Expand Down
5 changes: 5 additions & 0 deletions import_tracker/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
# Labels for direct vs transitive dependencies (the values reported under the
# INFO_TYPE key of a dependency's info dict)
TYPE_DIRECT = "direct"
TYPE_TRANSITIVE = "transitive"

# Info section headers: the keys of the per-dependency info dict that is
# produced when any of the extra tracking options is enabled
INFO_TYPE = "type"
INFO_STACK = "stack"
INFO_OPTIONAL = "optional"
154 changes: 123 additions & 31 deletions import_tracker/import_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
"""
# Standard
from types import ModuleType
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import dis
import importlib
import os
import sys

# Local
from .constants import THIS_PACKAGE, TYPE_DIRECT, TYPE_TRANSITIVE
from . import constants
from .log import log

## Public ######################################################################
Expand All @@ -24,6 +24,7 @@ def track_module(
track_import_stack: bool = False,
full_depth: bool = False,
detect_transitive: bool = False,
show_optional: bool = False,
) -> Union[Dict[str, List[str]], Dict[str, Dict[str, Any]]]:
"""Track the dependencies of a single python module

Expand All @@ -46,6 +47,9 @@ def track_module(
library.
detect_transitive: bool
Detect whether each dependency is 'direct' or 'transitive'
show_optional: bool
Show whether each requirement is optional (behind a try/except) or
not

Returns:
import_mapping: Union[Dict[str, List[str]], Dict[str, Dict[str, Any]]]
Expand All @@ -72,8 +76,10 @@ def track_module(
for module_to_check in modules_to_check:

# Figure out all direct imports from this module
module_imports = _get_imports(module_to_check)
module_import_names = {mod.__name__ for mod in module_imports}
req_imports, opt_imports = _get_imports(module_to_check)
opt_dep_names = {mod.__name__ for mod in opt_imports}
all_imports = req_imports.union(opt_imports)
module_import_names = {mod.__name__ for mod in all_imports}
log.debug3(
"Full import names for [%s]: %s",
module_to_check.__name__,
Expand All @@ -84,10 +90,14 @@ def track_module(
non_std_module_names = _get_non_std_modules(module_import_names)
log.debug3("Non std module names: %s", non_std_module_names)
non_std_module_imports = [
mod for mod in module_imports if mod.__name__ in non_std_module_names
mod for mod in all_imports if mod.__name__ in non_std_module_names
]

module_deps_map[module_to_check.__name__] = non_std_module_names
# Set the deps for this module as a mapping from each dep to its
# optional status
module_deps_map[module_to_check.__name__] = {
mod: mod in opt_dep_names for mod in non_std_module_names
}
log.debug2(
"Deps for [%s] -> %s",
module_to_check.__name__,
Expand Down Expand Up @@ -164,11 +174,11 @@ def track_module(
}
log.debug("Raw output deps map: %s", flattened_deps)

# If not detecting transitive or import stacks, the values are simple lists
# of dependency names
if not detect_transitive and not track_import_stack:
# If not displaying any of the extra info, the values are simple lists of
# dependency names
if not any([detect_transitive, track_import_stack, show_optional]):
deps_out = {
mod: list(sorted(deps.keys())) for mod, deps in flattened_deps.items()
mod: list(sorted(deps.keys())) for mod, (deps, _) in flattened_deps.items()
}

# Otherwise, the values will be dicts with some combination of "type" and
Expand All @@ -179,22 +189,32 @@ def track_module(
# If detecting transitive deps, look through the stacks and mark each dep as
# transitive or direct
if detect_transitive:
for mod, deps in flattened_deps.items():
for mod, (deps, _) in flattened_deps.items():
for dep_name, dep_stacks in deps.items():
deps_out.setdefault(mod, {}).setdefault(dep_name, {})["type"] = (
TYPE_DIRECT
deps_out.setdefault(mod, {}).setdefault(dep_name, {})[
constants.INFO_TYPE
] = (
constants.TYPE_DIRECT
if any(len(dep_stack) == 1 for dep_stack in dep_stacks)
else TYPE_TRANSITIVE
else constants.TYPE_TRANSITIVE
)

# If tracking import stacks, move them to the "stack" key in the output
if track_import_stack:
for mod, deps in flattened_deps.items():
for mod, (deps, _) in flattened_deps.items():
for dep_name, dep_stacks in deps.items():
deps_out.setdefault(mod, {}).setdefault(dep_name, {})[
"stack"
constants.INFO_STACK
] = dep_stacks

# If showing optional, add the optional status of each dependency
if show_optional:
for mod, (deps, optional_mapping) in flattened_deps.items():
for dep_name, dep_stacks in deps.items():
deps_out.setdefault(mod, {}).setdefault(dep_name, {})[
constants.INFO_OPTIONAL
] = optional_mapping.get(dep_name, False)

log.debug("Final output: %s", deps_out)
return deps_out

Expand Down Expand Up @@ -280,7 +300,7 @@ def _is_third_party(mod_name: str) -> bool:
mod_name not in sys.modules
or _get_import_parent_path(mod_name) not in [_std_lib_dir, _std_dylib_dir]
)
and mod_pkg != THIS_PACKAGE
and mod_pkg != constants.THIS_PACKAGE
and mod_pkg not in _known_std_pkgs
)

Expand All @@ -299,6 +319,21 @@ def _get_value_col(dis_line: str) -> str:
return ""


def _get_op_number(dis_line: str) -> Optional[int]:
"""Get the opcode number out of the line of `dis` output"""
line_parts = dis_line.split()
if not line_parts:
return None
opcode_idx = min([i for i, val in enumerate(line_parts) if val.isupper()])
assert opcode_idx > 0, f"Opcode found at the beginning of line! [{dis_line}]"
return int(line_parts[opcode_idx - 1])


def _get_try_end_number(dis_line: str) -> int:
    """For a SETUP_FINALLY/SETUP_EXCEPT line, extract the target end line

    The target is read as the last whitespace-separated token of the line's
    value column, converted to an integer.
    """
    value_tokens = _get_value_col(dis_line).split()
    end_token = value_tokens[-1]
    return int(end_token)


def _figure_out_import(
mod: ModuleType,
dots: Optional[int],
Expand Down Expand Up @@ -355,12 +390,13 @@ def _figure_out_import(
return sys.modules.get(import_name)


def _get_imports(mod: ModuleType) -> Set[ModuleType]:
"""Get the list of import string from a module by parsing the module's
bytecode
def _get_imports(mod: ModuleType) -> Tuple[Set[ModuleType], Set[ModuleType]]:
"""Get the sets of required and optional imports for the given module by
parsing its bytecode
"""
log.debug2("Getting imports for %s", mod.__name__)
all_imports = set()
req_imports = set()
opt_imports = set()

# Attempt to disassemble the byte code for this module. If the module has no
# code, we ignore it since it's most likely a c extension
Expand All @@ -369,21 +405,30 @@ def _get_imports(mod: ModuleType) -> Set[ModuleType]:
mod_code = loader.get_code(mod.__name__)
except (AttributeError, ImportError):
log.warning("Couldn't find a loader for %s!", mod.__name__)
return all_imports
return req_imports, opt_imports
if mod_code is None:
log.debug2("No code object found for %s", mod.__name__)
return all_imports
return req_imports, opt_imports
bcode = dis.Bytecode(mod_code)

# Parse all bytecode lines
current_dots = None
current_import_name = None
current_import_from = None
open_import = False
open_tries = set()
log.debug4("Byte Code:")
for line in bcode.dis().split("\n"):
log.debug4(line)
line_val = _get_value_col(line)

# Check whether this line ends a try
op_num = _get_op_number(line)
if op_num in open_tries:
open_tries.remove(op_num)
log.debug3("Closed try %d. Remaining open tries: %s", op_num, open_tries)

# Parse the individual ops
if "LOAD_CONST" in line:
if line_val.isnumeric():
current_dots = int(line_val)
Expand All @@ -394,6 +439,13 @@ def _get_imports(mod: ModuleType) -> Set[ModuleType]:
open_import = True
current_import_from = line_val
else:
# If this is a SETUP_FINALLY/SETUP_EXCEPT (try:), record the end target of
# the newly opened try block
if "SETUP_FINALLY" in line or "SETUP_EXCEPT" in line:
# Get the end target for this try
open_tries.add(_get_try_end_number(line))
log.debug3("Open tries: %s", open_tries)

# This closes an import, so figure out what the module is that is
# being imported!
if open_import:
Expand All @@ -402,7 +454,15 @@ def _get_imports(mod: ModuleType) -> Set[ModuleType]:
)
if import_mod is not None:
log.debug2("Adding import module [%s]", import_mod.__name__)
all_imports.add(import_mod)
if open_tries:
log.debug(
"Found optional dependency of [%s]: %s",
mod.__name__,
import_mod.__name__,
)
opt_imports.add(import_mod)
else:
req_imports.add(import_mod)

# If this is a STORE_NAME, subsequent "from" statements may use the
# same dots and name
Expand All @@ -422,7 +482,7 @@ def _get_imports(mod: ModuleType) -> Set[ModuleType]:
current_import_from,
)

return all_imports
return req_imports, opt_imports


def _find_parent_direct_deps(
Expand All @@ -444,14 +504,16 @@ def _find_parent_direct_deps(
for i in range(1, len(mod_name_parts)):
parent_mod_name = ".".join(mod_name_parts[:i])
parent_deps = module_deps_map.get(parent_mod_name, {})
for dep in parent_deps:
if not dep.startswith(mod_base_name) and dep not in mod_deps:
for dep, parent_dep_opt in parent_deps.items():
currently_optional = mod_deps.get(dep, True)
if not dep.startswith(mod_base_name) and currently_optional:
log.debug3(
"Adding direct-dependency of parent mod [%s]: %s",
"Adding direct-dependency of parent mod [%s] to [%s]: %s",
parent_mod_name,
mod_name,
dep,
)
mod_deps.add(dep)
mod_deps[dep] = currently_optional and parent_dep_opt
parent_direct_deps.setdefault(mod_name, {}).setdefault(
parent_mod_name, set()
).add(dep)
Expand All @@ -463,7 +525,7 @@ def _flatten_deps(
module_name: str,
module_deps_map: Dict[str, List[str]],
parent_direct_deps: Dict[str, Dict[str, List[str]]],
) -> Dict[str, List[str]]:
) -> Tuple[Dict[str, List[str]], Dict[str, bool]]:
"""Flatten the names of all modules that the target module depends on"""

# Look through all modules that are directly required by this target module.
Expand Down Expand Up @@ -521,16 +583,46 @@ def _flatten_deps(
# Create the flattened dependencies with the source lists for each
mod_base_name = module_name.partition(".")[0]
flat_base_deps = {}
optional_deps_map = {}
for dep, dep_sources in all_deps.items():
if not dep.startswith(mod_base_name):
# Truncate the dep_sources entries and trim to avoid duplicates
dep_root_mod_name = dep.partition(".")[0]
flat_dep_sources = flat_base_deps.setdefault(dep_root_mod_name, [])
opt_dep_values = optional_deps_map.setdefault(dep_root_mod_name, [])
for dep_source in dep_sources:
log.debug4("Considering dep source list for %s: %s", dep, dep_source)

# If any link in the dep_source is optional, the whole
# dep_source should be considered optional
is_optional = False
for parent_idx, dep_mod in enumerate(dep_source[1:] + [dep]):
dep_parent = dep_source[parent_idx]
log.debug4(
"Checking whether [%s -> %s] is optional (dep=%s)",
dep_parent,
dep_mod,
dep_root_mod_name,
)
if module_deps_map.get(dep_parent, {}).get(dep_mod, False):
log.debug4("Found optional link %s -> %s", dep_parent, dep_mod)
is_optional = True
break
opt_dep_values.append(
[
is_optional,
dep_source,
]
)

flat_dep_source = dep_source
if dep_root_mod_name in dep_source:
flat_dep_source = dep_source[: dep_source.index(dep_root_mod_name)]
if flat_dep_source not in flat_dep_sources:
flat_dep_sources.append(flat_dep_source)
return flat_base_deps
log.debug3("Optional deps map for [%s]: %s", module_name, optional_deps_map)
optional_deps_map = {
mod: all([opt_val[0] for opt_val in opt_vals])
for mod, opt_vals in optional_deps_map.items()
}
return flat_base_deps, optional_deps_map
Loading