diff --git a/docs/website/docs/guides/ml-frameworks/pytorch.md b/docs/website/docs/guides/ml-frameworks/pytorch.md index 739db736a10c..e80dec9e852d 100644 --- a/docs/website/docs/guides/ml-frameworks/pytorch.md +++ b/docs/website/docs/guides/ml-frameworks/pytorch.md @@ -321,6 +321,127 @@ their values independently at runtime. self.value = new_value ``` +#### :octicons-file-symlink-file-16: Using external parameters + +Model parameters can be stored in standalone files that can be efficiently +stored and loaded separately from model compute graphs. See the +[Parameters guide](../parameters.md) for more general information about +parameters in IREE. + +When using iree-turbine, the `aot.externalize_module_parameters()` function +separates parameters from program modules and encodes a symbolic relationship +between them so they can be loaded at runtime. + +We use [Safetensors](https://huggingface.co/docs/safetensors/) here to store the +models parameters on disk, so that they can be loaded later during runtime. + +```python +import torch +from safetensors.torch import save_file +import numpy as np +import shark_turbine.aot as aot + +class LinearModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(in_features, out_features)) + self.bias = torch.nn.Parameter(torch.randn(out_features)) + + def forward(self, input): + return (input @ self.weight) + self.bias + +linear_module = LinearModule(4,3) + +# Create a params dictionary. Note that the keys here match LinearModule's +# attributes. We will use the saved safetensor file for use from the command +# line. +wt = linear_module.weight.t().contiguous() +bias = linear_module.bias.t().contiguous() +params = { "weight": wt, "bias": bias } +save_file(params, "params.safetensors") + +# Externalize the model parameters. This removes weight tensors from the IR +# module, allowing them to be loaded at runtime. Symbolic references to these +# parameters are still retained in the IR. +aot.externalize_module_parameters(linear_module) + +input = torch.randn(4) +exported_module = aot.export(linear_module, input) + +# Compile the exported module, to generate the binary. When `save_to` is +# not None, the binary will be stored at the path passed in to `save_to`. +# Here, we pass in None, so that the binary can stored in a variable. +binary = exported_module.compile(save_to=None) + +# Save the input as an npy tensor, so that it can be passed in through the +# command line to `iree-run-module`. +input_np = input.numpy() +np.save("input.npy", input_np) +``` + +=== "Python runtime" + + Runtime invocation now requires loading the parameters as a separate module. + To get the parameters as a module, iree.runtime provides a convenient method, + called `create_io_parameters_module()`. + + ```python + import iree.runtime as ireert + + # To load the parameters, we need to define ParameterIndex for each + # parameter class. + idx = ireert.ParameterIndex() + idx.add_buffer("weight", wt.detach().numpy().tobytes()) + idx.add_buffer("bias", bias.detach().numpy().tobytes()) + + + # Create the runtime instance, and load the runtime. + config = ireert.Config(driver_name="local-task") + instance = config.vm_instance + + param_module = ireert.create_io_parameters_module( + instance, idx.create_provider(scope="model"), + ) + + # Load the runtime. There are essentially two modules to load, one for the + # weights, and one for the main module. Ensure that the VMFB file is not + # already open or deleted before use. + vm_modules = ireert.load_vm_modules( + param_module, + ireert.create_hal_module(instance, config.device), + ireert.VmModule.copy_buffer(instance, binary.map_memory()), + config=config, + ) + + # vm_modules is a list of modules. The last module in the list is the one + # generated from the binary, so we use that to generate an output. + result = vm_modules[-1].main(input) + print(result.to_host()) + ``` + +=== "Command line tools" + + It is also possible to save the VMFB binary to disk, then call `iree-run-module` + through the command line to generate outputs. + + ```python + # When save_to is not None, the binary is saved to the given path, + # and a None value is returned. + binary = exported_module.compile(save_to="compiled_module.vmfb") + ``` + + The stored safetensors file, the input tensor, and the VMFB can now be passed + in to IREE through the command line. + + ```bash + iree-run-module --module=compiled_module.vmfb --parameters=model=params.safetensors \ + --input=@input.npy + ``` + + Note here that the `--parameters` flag has `model=` following it immediately. + This simply specifies the scope of the parameters, and is reflected in the + compiled module. + #### :octicons-code-16: Samples | Code samples | |