Added refitting acceleration #2983

Merged: 9 commits, Aug 13, 2024
py/torch_tensorrt/dynamo/_refit.py: 163 changes (134 additions, 29 deletions)
@@ -3,7 +3,7 @@
 import collections.abc
 import copy
 import logging
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import numpy as np
 import tensorrt as trt
@@ -13,7 +13,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._exporter import inline_torch_modules
-from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
     return weight_map
 
 
+def construct_refit_mapping_from_weight_name_map(
+    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
+) -> dict[Any, Any]:
+    engine_weight_map = {}
+    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
+        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch norm layer: fold weight/bias/running stats into SCALE and SHIFT
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w]
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+            shift = params["bias"] - params["running_mean"] * scale
+            # Pick the local `scale` or `shift` tensor based on the weight name suffix
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+
+        elif sd_weight_name not in state_dict:
+            # If the weight is not in the state_dict, leave it unchanged
+            continue
+        else:
+            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
+
+        engine_weight_map[engine_weight_name] = (
+            engine_weight_map[engine_weight_name]
+            .clone()
+            .reshape(-1)
+            .contiguous()
+            .to(torch_dtype),
+            trt_dtype,
+        )
+
+    return engine_weight_map
+
+
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
-    input_list: Tuple[Any, ...],
+    input_list: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    weight_name_map: Optional[dict[str, List[str]]] = None,
 ) -> None:
     """
     Refit a TensorRT Engine in place
     """
-    # Get the refitting mapping
-    mapping = construct_refit_mapping(new_gm, input_list, settings)
 
     refitted = set()
 
-    trt_wt_location = trt.TensorLocation.HOST
     refitter = trt.Refitter(old_engine, TRT_LOGGER)
     weight_list = refitter.get_all_weights()
 
-    for layer_name in weight_list:
-        if layer_name not in mapping:
-            raise AssertionError(f"{layer_name} is not found in weight mapping")
-        # Use Numpy to create weights
-        weight, datatype = mapping[layer_name]
-        trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-        refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        refitted.add(layer_name)
+    if weight_name_map:
+        # Get the refitting mapping
+        trt_wt_location = trt.TensorLocation.DEVICE
+        mapping = construct_refit_mapping_from_weight_name_map(
+            weight_name_map, new_gm.state_dict()
+        )
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                logger.warning(f"{layer_name} is not found in weight mapping.")
+                continue
+            # Use PyTorch tensors on the device to create weights
+            weight, weight_dtype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(
+                weight_dtype, weight.data_ptr(), torch.numel(weight)
+            )
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+        assert (
+            len(refitter.get_missing_weights()) == 0
+        ), "Fast refitting failed due to incomplete mapping"
 
-    if len(refitted) != len(weight_list):
-        logger.warning("Not all weights have been refitted!!!")
+    else:
+        mapping = construct_refit_mapping(new_gm, input_list, settings)
+        trt_wt_location = trt.TensorLocation.HOST
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                raise AssertionError(f"{layer_name} is not found in weight mapping")
+            # Use Numpy to create weights
+            weight, datatype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            refitted.add(layer_name)
+
+        if len(refitted) != len(weight_list):
+            logger.warning("Not all weights have been refitted!!!")
 
     if not refitter.refit_cuda_engine():
         logger.error("Error: failed to refit new weights.")
-        exit(0)
+        raise AssertionError("Refitting failed.")
 
 
 def refit_module_weights(
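A note on the SCALE/SHIFT branch above: TensorRT folds each batch norm into a per-channel scale and shift, so the new mapping recomputes those two tensors from the module's weight, bias, running_mean, and running_var. A minimal sketch of the same fold (tensor values are illustrative, not from this PR; it assumes the engine was built with the same 1e-7 epsilon):

```python
import torch

# Hypothetical batch norm parameters standing in for the state_dict entries
weight = torch.tensor([1.0, 2.0])        # gamma
bias = torch.tensor([0.5, -0.5])         # beta
running_mean = torch.tensor([0.1, 0.2])
running_var = torch.tensor([1.0, 4.0])

# Same fold as in construct_refit_mapping_from_weight_name_map
scale = weight / torch.sqrt(running_var + 1e-7)
shift = bias - running_mean * scale

# scale * x + shift reproduces inference-mode batch norm with eps = 1e-7
x = torch.tensor([[0.3, 0.7]])
reference = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=1e-7
)
assert torch.allclose(scale * x + shift, reference)
```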
@@ -148,6 +207,8 @@ def refit_module_weights(
     arg_inputs: Optional[Tuple[Any, ...]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     verify_output: bool = False,
+    use_weight_map_cache: bool = True,
+    in_place: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@
     if len(list(compiled_module.named_children())) == 0:
         inline_module = True
 
-    compiled_module = copy.deepcopy(compiled_module)
+    if not in_place:
+        compiled_module = copy.deepcopy(compiled_module)
+    elif inline_module:
+        raise AssertionError(
+            "Exported programs do not support in-place modification. Please set in_place to False and use the returned graph module."
+        )
 
     # Get the settings and check that they are uniform
     settings: CompilationSettings = None
@@ -182,13 +248,14 @@
             for name, engine in compiled_module.__dict__.items()
             if "engine" in name
         ]
-        encoded_settings = compiled_submodules[0][1].__getstate__()[0][
+        # [('_run_on_acc_0', inline_module)]
+        encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
             SERIALIZED_METADATA_IDX
         ]
         assert (
-            encoded_settings != ""
-        ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
-        settings = TorchTensorRTModule.decode_metadata(encoded_settings)
+            encoded_metadata != ""
+        ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old. Please recompile with make_refitable=True using the latest version."
+        settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
         # Handle torch modules
         compiled_submodules_map = dict(compiled_submodules)
         for name, submodule in compiled_module.named_children():
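As the decode_metadata calls in this diff suggest, the serialized metadata now bundles the compilation settings together with the weight name map. A sketch of the assumed decoded shape (the two keys are taken from the accesses in this diff; the values are illustrative):

```python
from torch_tensorrt.dynamo._settings import CompilationSettings

# Assumed result shape of TorchTensorRTModule.decode_metadata(encoded_metadata)
metadata = {
    "settings": CompilationSettings(),  # restored compilation settings
    "weight_name_map": {},              # engine weight name -> state_dict mapping
}
settings = metadata["settings"]
weight_name_map = metadata["weight_name_map"]
```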
@@ -287,6 +354,7 @@ def refit_module_weights(
         # Extract engine from the submodule
         try:
             if inline_module:
+                weight_name_map = None
                 compiled_submodule = compiled_submodules_map[name]
                 # If this is a torch module, load the old state_dict
                 if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@
                 engine = get_engine_from_encoded_engine(
                     engine_info[ENGINE_IDX], runtime
                 )
+                if use_weight_map_cache:
+                    encoded_metadata = compiled_submodule.__getstate__()[0][
+                        SERIALIZED_METADATA_IDX
+                    ]
+                    weight_name_map = TorchTensorRTModule.decode_metadata(
+                        encoded_metadata
+                    )["weight_name_map"]
+                    if not weight_name_map:
+                        use_weight_map_cache = False
+                        logger.warning(
+                            "This engine does not have a weight map cache. Rebuilding the weight map."
+                        )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                weight_name_map = None
+                if use_weight_map_cache:
+                    try:
+                        weight_name_map = compiled_submodule.weight_name_map
+                    except AttributeError:
+                        logger.warning(
+                            "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                        )
+                    if not weight_name_map:
+                        use_weight_map_cache = False
+                        logger.warning(
+                            "This engine does not have a weight map cache. Rebuilding the weight map."
+                        )
                 if isinstance(compiled_submodule, PythonTorchTensorRTModule):
                     engine = compiled_submodule.engine
                 elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@ def refit_module_weights(
                 to_torch_device(settings.device),
                 name,
             )
-
-        _refit_single_trt_engine_with_gm(
-            new_gm=new_submodule,
-            old_engine=engine,
-            input_list=submodule_inputs,
-            settings=settings,
-        )
+        try:
+            _refit_single_trt_engine_with_gm(
+                new_gm=new_submodule,
+                old_engine=engine,
+                input_list=submodule_inputs,
+                settings=settings,
+                weight_name_map=weight_name_map,
+            )
+        except AssertionError as e:
+            # If fast refit was used and failed, fall back to a regular refit
+            logger.warning(e)
+            if use_weight_map_cache and weight_name_map:
+                _refit_single_trt_engine_with_gm(
+                    new_gm=new_submodule,
+                    old_engine=engine,
+                    input_list=submodule_inputs,
+                    settings=settings,
+                    weight_name_map=None,
+                )
 
         if isinstance(compiled_submodule, TorchTensorRTModule):
             serialized_engine = bytes(engine.serialize())
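For reviewers trying the new flags end to end, a minimal usage sketch follows. The model, shapes, and the new_weight_module parameter name are illustrative assumptions rather than verbatim from this diff; make_refitable=True at compile time is required, per the assertion messages above:

```python
import torch
import torch_tensorrt as torch_trt

class TwoLayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 16)
        self.fc2 = torch.nn.Linear(16, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))

inputs = [torch.randn(4, 16).cuda()]
model = TwoLayer().eval().cuda()

# Compile once with refitting enabled; any later refit requires this flag
compiled = torch_trt.compile(model, arg_inputs=inputs, make_refitable=True)

# Swap in new weights, re-export, and refit without rebuilding the engine
model.load_state_dict(TwoLayer().state_dict())
exported = torch.export.export(model, tuple(inputs))
refitted = torch_trt.dynamo.refit_module_weights(
    compiled_module=compiled,
    new_weight_module=exported,
    arg_inputs=inputs,
    use_weight_map_cache=True,  # fast path; falls back to a full remap on failure
    in_place=False,             # deepcopy; in-place is rejected for inline modules
)
```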