Model Weight Refit #2900
Initiative
After a model is compiled to an FX graph module, we may still want to make modifications, such as updating the weights or even changing the original structure. We want to design a handy way for users to interact with and easily modify compiled graph modules.
Design Concept
We create a wrapper around the compiled module. Conceptually, this module would present just as an nn.Module to the user in all ways, except that when it is run it calls the compiled module under the hood. Users can also use inheritance and create custom subclasses that achieve custom nn.Module behaviors.
Example Usecase 1
Users can load a LoRA into the nn.Module of a Stable Diffusion pipeline (which will be in host memory), and that automatically triggers a refit; alternatively, users can initiate the refit manually.
Example Usecase 2
In model training, after a backward propagation the state_dict gets updated, and that automatically triggers a refit.
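A minimal sketch of that wrapper idea, assuming a hypothetical class name (RefittableModule) and using load_state_dict as the refit trigger; none of these names are settled API:

```python
import torch.nn as nn

class RefittableModule(nn.Module):
    """Hypothetical wrapper: looks like an nn.Module in every way, but
    forward() dispatches to the compiled TensorRT graph module."""

    def __init__(self, pytorch_module: nn.Module, compiled_module: nn.Module):
        super().__init__()
        self.pytorch_module = pytorch_module    # original weights, host memory
        self.compiled_module = compiled_module  # TRT-accelerated graph module

    def forward(self, *args, **kwargs):
        # Execution always runs through the compiled module under the hood
        return self.compiled_module(*args, **kwargs)

    def load_state_dict(self, state_dict, strict: bool = True):
        # Loading new weights (e.g. after merging a LoRA) triggers a refit
        result = self.pytorch_module.load_state_dict(state_dict, strict=strict)
        self.refit()
        return result

    def refit(self):
        # Placeholder: propagate self.pytorch_module's weights into the
        # TensorRT engines inside self.compiled_module (can also be
        # called manually by the user)
        ...
```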
Model Weight Refit
TL;DR
TensorRT supports updating engine weights after compilation via the nvinfer1::IRefitter class (available in both the C++ and Python APIs). This could be a beneficial feature to bring into Torch-TensorRT, specifically the FX path, since models which are pre-compiled and saved can easily be refitted with new training weights, so long as the model architecture is unchanged. This can save hours of compilation time and enables extensions such as LoRA and serving pre-compiled TensorRT engines from the cloud.
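For reference, a minimal sketch of the underlying mechanism in the TensorRT Python API; the engine must have been built with the REFIT builder flag, and the weight names and arrays here are placeholders:

```python
import numpy as np
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def refit_engine(engine: trt.ICudaEngine, new_weights: dict) -> bool:
    """Push new values into an already-built engine without recompiling.

    The engine must have been built with config.set_flag(trt.BuilderFlag.REFIT).
    new_weights maps engine weight names to numpy arrays of matching
    shape and dtype.
    """
    refitter = trt.Refitter(engine, TRT_LOGGER)
    for name, array in new_weights.items():
        # Stage the new value for the named weight
        refitter.set_named_weights(name, trt.Weights(np.ascontiguousarray(array)))
    # Apply all staged weights; returns True on success
    return refitter.refit_cuda_engine()
```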
User scenario
Alex is a digital artist who frequently uses Stable Diffusion to generate AI-powered artwork. To achieve various artistic styles, Alex utilizes different LoRA (Low-Rank Adaptation) configurations. Each LoRA configuration provides a unique style, enabling Alex to create diverse visual effects and aesthetics in his artwork. To maximize creative efficiency, he uses Torch-TensorRT to accelerate Stable Diffusion and make the creation process faster.
However, every time he applies a new LoRA, recompiling the TensorRT module takes around 10 minutes. This significantly slows down Alex's workflow: instead of focusing on creating new art, Alex spends a considerable amount of time waiting for the model to recompile.
With the proposed engine refit feature, Alex can apply a LoRA within one minute. This significantly cuts the wait time, and he can switch between different combinations of LoRAs whenever he wants.
Problem
Building a TensorRT engine is time-consuming due to complex procedures like kernel auto-tuning. For instance, compiling a large language model can take several hours. When model weights are frequently updated, such as during A/B testing of different versions or when adding adapters for various purposes, the need to repeatedly recompile the engine becomes highly inefficient. This can potentially cause significant delays in both development and deployment workflows.
Motivations and usecases
Model weight refit will reduce time spent compiling models with Torch-TensorRT by 80%, since models would only need to be compiled once per architecture; subsequent weight updates can be propagated into the compiled model post-compilation, without the overhead of recompiling. This also enables extensions such as LoRA and serving pre-compiled TensorRT engines from the cloud.
Proposed APIs
Users of this feature would first have a compiled TensorRT graph module ready and would save it; a sketch of this first compilation follows below. Then, at a later time when loading the compiled model, if the weights have been updated from their original values, the user could call the model weight refit function to refit the stored weights in the TensorRT graph module.
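For example, a compile-and-save sketch; MyModel is a placeholder, and the make_refittable flag name is an assumption about how refit support would be requested at compile time:

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # placeholder user model
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Compile once; the engine must be marked refittable so weights can be
# swapped later without rebuilding
trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    make_refittable=True,  # assumed flag name for this proposal
)

torch.save(trt_gm, "trt_model.pt")  # illustrative; actual serialization may differ
```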
Example Workflow
The user can first load the previously compiled TensorRT graph module and apply the new weights using
trt_gm.load_state_dict(state_dict)
The additional step, as per the proposed API, would be to call the proposed refit_module_weights function:
This function would parse the new exported program, determine the mapping between weights in the exported program and weights in the TensorRT engines, and return a copy of the graph module with updated weights, using the TRT Python API for TRT-accelerated submodules and the Torch API for non-accelerated submodules.
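Putting it together, a sketch that continues the compile example above and assumes refit_module_weights accepts the previously compiled module plus a re-exported program carrying the new weights, returning an updated copy (the exact signature is part of this proposal, not a settled API):

```python
import torch
import torch_tensorrt

# Apply the new weights to the eager model and re-export it
model.load_state_dict(new_state_dict)
exported_program = torch.export.export(model, tuple(inputs))

# Proposed API: map the exported program's weights onto the stored
# engines and return a graph module with refitted weights
refitted_gm = torch_tensorrt.dynamo.refit_module_weights(
    trt_gm,            # previously compiled TensorRT graph module
    exported_program,  # carries the updated weight values
)
```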
Implementation Design
High Level Explanation
After the export of the PyTorch model, the newly exported graph module will first go through ATen tracing and lowering. After that, the graph will be partitioned into several subgraphs if there are any graph breaks resulting from unsupported operations. Each of these subgraphs is then converted to an INetworkDefinition. These INetworkDefinitions are eventually used to refit the weights in each TensorRT engine in the compiled graph module.
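In pseudocode, the refit path mirrors the compile path up to conversion and then diverges; every helper below is a hypothetical stand-in for an internal stage, not a public API:

```python
# Hypothetical stand-ins for internal stages of the Dynamo path
def aten_trace_and_lower(exported_program, settings): ...   # ATen tracing + lowering
def partition(graph_module, settings): ...                  # split at graph breaks
def convert_to_network_definition(subgraph, settings): ...  # build INetworkDefinition
def refit_engine_from_network(engine, network): ...         # push weights into engine

def refit_pipeline(trt_gm, exported_program, settings):
    gm = aten_trace_and_lower(exported_program, settings)
    subgraphs = partition(gm, settings)  # {submodule_name: subgraph}
    for name, subgraph in subgraphs.items():
        if name in trt_gm.trt_engines:   # TRT-accelerated submodule: refit engine
            network = convert_to_network_definition(subgraph, settings)
            refit_engine_from_network(trt_gm.trt_engines[name], network)
        else:                            # non-accelerated submodule: Torch API
            trt_gm.get_submodule(name).load_state_dict(subgraph.state_dict())
    return trt_gm
```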
Model interpretation
Model interpretation is mostly the same process as when the model is first compiled into a TRT engine. To make sure that the model with the new weights is interpreted under the same compilation settings as the originally compiled model, the old settings are stored and re-used; references to the settings are kept in the runtime graph modules (PythonTensorRTModule/TensorRTModule).
Mapping Construction
After the INetworkDefinition is constructed, the different layers of the INetworkDefinition are examined to extract the weights to be refitted. Specifically, weight-bearing values (for example, convolution kernels and biases or constant layer values) are extracted and mapped to the keys of the TensorRT engine weights.
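As a sketch of the mapping step, the TensorRT Python API exposes the engine side via Refitter.get_all_weights(); the pairing with network-extracted values below uses a plain exact-name match, which is an assumption about the real matching logic:

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_weight_mapping(engine: trt.ICudaEngine, extracted_weights: dict) -> dict:
    """Pair each refittable engine weight name with a freshly extracted value.

    extracted_weights maps names gathered from the new INetworkDefinition's
    layers to their weight arrays.
    """
    refitter = trt.Refitter(engine, TRT_LOGGER)
    mapping = {}
    for weight_name in refitter.get_all_weights():  # engine-side weight names
        if weight_name in extracted_weights:        # exact-name match (assumed)
            mapping[weight_name] = extracted_weights[weight_name]
    return mapping
```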
Extensions Required to Core API Implementations
The existing library should not require many changes, as this add-on would simply add functionality while preserving existing core APIs.
One small change is that we want to store the settings used to compile a module and attach a reference to them on the module. In this way, during the second parsing of the model, we can re-use the settings.
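A sketch of that bookkeeping; the field names and the settings attribute are illustrative assumptions:

```python
from dataclasses import dataclass, field

@dataclass
class StoredCompilationSettings:
    """Illustrative subset of compile-time settings to re-use at refit time."""
    enabled_precisions: set = field(default_factory=lambda: {"fp16"})
    workspace_size: int = 0
    truncate_long_and_double: bool = False

def attach_settings(trt_module, settings: StoredCompilationSettings):
    # Keep a reference on the runtime module so a later refit pass can
    # reproduce identical lowering and partitioning without user input
    trt_module.settings = settings
    return trt_module
```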
Target Platforms
This feature targets Stable Diffusion refitting on Windows and LLM refitting on Linux.
Implementation Phases
Prototype - Small/Medium
MVP (1.4.0) - Medium
- refit_module_weights function, including refitting weights with multiple TRT-accelerated submodules and multiple Torch/FX non-accelerated submodules
Extension Phase 1 [Potential] - Medium
- LoRA for Stable Diffusion 3
- Refitting acceleration: directly map the weights between the state_dict and the TRT engine; if that is successful, no re-interpretation is needed (see the sketch after this list)
- LLM Parameter Efficient Fine Tuning
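A sketch of that acceleration: attempt a direct name match between the state_dict and the engine's refittable weights, and fall back to full re-interpretation only when the fast path fails; the fallback helper named below is hypothetical:

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def try_fast_refit(engine: trt.ICudaEngine, state_dict: dict) -> bool:
    """Fast path: refit straight from the state_dict when names line up."""
    refitter = trt.Refitter(engine, TRT_LOGGER)
    engine_weight_names = refitter.get_all_weights()
    if not set(engine_weight_names).issubset(state_dict.keys()):
        return False  # names do not line up; caller must re-interpret
    for name in engine_weight_names:
        tensor = state_dict[name].detach().cpu().contiguous()
        refitter.set_named_weights(name, trt.Weights(tensor.numpy()))
    return refitter.refit_cuda_engine()

# Usage: only pay for re-interpretation when the fast path fails
# if not try_fast_refit(engine, new_state_dict):
#     reinterpret_and_refit(...)  # hypothetical slow path
```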