[Inductor] [Doc] Add tutorial for Inductor Cpp Wrapper as a prototype feature for PyTorch 2.1 #2510

Merged Oct 2, 2023 · 45 commits
159 changes: 159 additions & 0 deletions prototype_source/inductor_cpp_wrapper_tutorial.rst
@@ -0,0 +1,159 @@
Inductor C++ Wrapper Tutorial
==============================================================

**Author**: `Chunyuan Wu <https://github.com/chunyuan-w>`__, `Bin Bao <https://github.com/desertfire>`__, `Jiong Gong <https://github.com/jgong5>`__

Prerequisites:
----------------
- `torch.compile and TorchInductor concepts in PyTorch <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__

Introduction
------------

Python, as the primary interface of PyTorch, is easy to use and efficient for development and debugging.
Inductor's default wrapper generates Python code to invoke generated kernels and external kernels.
However, in deployments requiring high performance, Python, as an interpreted language, runs relatively slowly compared to compiled languages.

We implemented an Inductor C++ wrapper by leveraging the PyTorch C++ APIs
to generate pure C++ code that combines the generated and external kernels.
This allows for the execution of each captured Dynamo graph in pure C++,
thereby reducing the Python overhead within the graph.
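
To make the overhead being targeted here concrete, below is a torch-free, purely illustrative sketch (all names are hypothetical, and this is not Inductor's actual code): a graph executed as several separate Python-level "kernel calls" pays interpreter dispatch cost at every step, while a fused single call pays it once.

```python
import timeit

def python_wrapper_style(xs):
    # Three separate Python-level "kernel invocations" per call:
    # each list comprehension stands in for one wrapper-dispatched kernel.
    ys = [x + 1 for x in xs]
    ys = [y * 2 for y in ys]
    return [y - 3 for y in ys]

def fused_style(xs):
    # One pass: the per-call interpreter overhead is paid only once.
    return [(x + 1) * 2 - 3 for x in xs]

data = list(range(19))
assert python_wrapper_style(data) == fused_style(data)

t_split = timeit.timeit(lambda: python_wrapper_style(data), number=10_000)
t_fused = timeit.timeit(lambda: fused_style(data), number=10_000)
print(f"3-step: {t_split:.4f}s, 1-step: {t_fused:.4f}s")
```

The C++ wrapper applies the same idea one level up: the whole per-graph dispatch sequence moves out of the Python interpreter.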


Enabling the API
----------------
This feature is still in the prototype stage. To enable it, add the following to your code:

.. code:: python

    import torch._inductor.config as config
    config.cpp_wrapper = True

This can speed up your models by reducing the Python overhead of the Inductor wrapper.


Example code
------------

We will use the following frontend code as an example:

.. code:: python

    import torch

    def fn(x):
        return torch.tensor(list(range(2, 40, 2)), device=x.device) + x

    x = torch.randn(1)
    opt_fn = torch.compile()(fn)
    y = opt_fn(x)


**For CPU**

With the default Python wrapper, the main part of the Inductor-generated code looks like this:

.. code:: python

    def call(args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, ), (1, ))
        buf0 = empty_strided((19, ), (1, ), device='cpu', dtype=torch.float32)
        cpp_fused_add_lift_fresh_0(c_void_p(constant0.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
        del arg0_1
        return (buf0, )
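
As a quick torch-free sanity check of where the ``(19, )`` shape in the generated code comes from: it is the length of the constant tensor that was lifted out of ``fn``, broadcast against the shape-``(1,)`` input.

```python
# The constant baked into the graph comes from range(2, 40, 2),
# which is why the wrapper allocates an output buffer of shape (19,).
constant = list(range(2, 40, 2))   # [2, 4, ..., 38]
assert len(constant) == 19

# Broadcasting a single-element input against the 19-element constant
# yields 19 outputs, matching buf0's size in the generated code.
x = 0.5
out = [c + x for c in constant]
assert len(out) == 19
print(out[:3])  # → [2.5, 4.5, 6.5]
```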

By turning on the C++ wrapper, the generated code for the ``call`` function becomes a C++ function
``inductor_entry_cpp`` of the C++ extension ``module``:

.. code:: python

    std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
        at::Tensor arg0_1 = args[0];
        at::Tensor constant0 = args[1];
        auto buf0 = at::empty_strided({19L, }, {1L, }, at::device(at::kCPU).dtype(at::kFloat));
        cpp_fused_add_lift_fresh_0((long*)(constant0.data_ptr()), (float*)(arg0_1.data_ptr()), (float*)(buf0.data_ptr()));
        arg0_1.reset();
        return {buf0};
    }

    module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'c2buojsvlqbywxe3itb43hldieh4jqulk72iswa2awalwev7hjn2', False)

    def _wrap_func(f):
        def g(args):
            args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
            constants_tensor = [constant0]
            args_tensor.extend(constants_tensor)

            return f(args_tensor)
        return g
    call = _wrap_func(module.inductor_entry_cpp)
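
The ``_wrap_func`` shim above only marshals arguments: it forwards the runtime inputs and appends the graph constants so that the compiled entry point receives one flat list. Here is a torch-free sketch of the same pattern, with ``fake_entry`` standing in (hypothetically) for the compiled ``inductor_entry_cpp``:

```python
# Hypothetical stand-in for the compiled entry point: it expects one flat
# list containing both the runtime args and the graph constants.
def fake_entry(args):
    arg0, constant0 = args
    return [c + arg0 for c in constant0]

constant0 = list(range(2, 40, 2))  # the graph's baked-in constant

def _wrap_func(f):
    # Same marshalling as the generated shim: pass runtime args through
    # and append the constants so the entry point sees one flat list.
    def g(args):
        args_with_constants = list(args)
        args_with_constants.append(constant0)
        return f(args_with_constants)
    return g

call = _wrap_func(fake_entry)
print(call([1])[:3])  # → [3, 5, 7]
```

Keeping the constants on the Python side and appending them per call means the compiled entry point stays a pure function of its argument list.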

**For GPU**

Based on the same example code, the generated code for GPU will look like this:

.. code:: python

    def call(args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (1, ), (1, ))
        with torch.cuda._DeviceGuard(0):
            torch.cuda.set_device(0) # no-op to ensure context
            buf0 = empty_strided((19, ), (1, ), device='cuda', dtype=torch.float32)
            # Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
            stream0 = get_cuda_stream(0)
            triton_poi_fused_add_lift_fresh_0.run(constant0, arg0_1, buf0, 19, grid=grid(19), stream=stream0)
            run_intermediate_hooks('add', buf0)
        del arg0_1
        return (buf0, )
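
The ``grid(19)`` helper above computes a 1D launch grid for the 19-element kernel. As an illustration only (Inductor chooses the actual block size per kernel; ``XBLOCK = 128`` is an assumption here), the number of blocks is just the ceiling division of the element count by the block size:

```python
def ceil_div(n, d):
    # Number of blocks needed to cover n elements with d threads each.
    return (n + d - 1) // d

XBLOCK = 128  # assumed block size, for illustration only
assert ceil_div(19, XBLOCK) == 1   # 19 elements fit in a single block
assert ceil_div(129, XBLOCK) == 2  # one extra element forces a second block
```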

With the C++ wrapper turned on, the following equivalent C++ code will be generated:

.. code:: python

    std::vector<at::Tensor> inductor_entry_cpp(const std::vector<at::Tensor>& args) {
        at::Tensor arg0_1 = args[0];
        at::Tensor constant0 = args[1];

        at::cuda::CUDAGuard device_guard(0);
        auto buf0 = at::empty_strided({19L, }, {1L, }, at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat));
        // Source Nodes: [add, tensor], Original ATen: [aten.add, aten.lift_fresh]
        if (triton_poi_fused_add_lift_fresh_0 == nullptr) {
            triton_poi_fused_add_lift_fresh_0 = loadKernel("/tmp/torchinductor_user/mm/cmm6xjgijjffxjku4akv55eyzibirvw6bti6uqmfnruujm5cvvmw.cubin", "triton_poi_fused_add_lift_fresh_0_0d1d2d3");
        }
        CUdeviceptr var_0 = reinterpret_cast<CUdeviceptr>(constant0.data_ptr());
        CUdeviceptr var_1 = reinterpret_cast<CUdeviceptr>(arg0_1.data_ptr());
        CUdeviceptr var_2 = reinterpret_cast<CUdeviceptr>(buf0.data_ptr());
        auto var_3 = 19;
        void* kernel_args_var_0[] = {&var_0, &var_1, &var_2, &var_3};
        cudaStream_t stream0 = at::cuda::getCurrentCUDAStream(0);
        launchKernel(triton_poi_fused_add_lift_fresh_0, 1, 1, 1, 1, 0, kernel_args_var_0, stream0);
        arg0_1.reset();
        return {buf0};
    }

    module = CppWrapperCodeCache.load(cpp_wrapper_src, 'inductor_entry_cpp', 'czbpeilh4qqmbyejdgsbpdfuk2ss5jigl2qjb7xs4gearrjvuwem', True)

    def _wrap_func(f):
        def g(args):
            args_tensor = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]
            constants_tensor = [constant0]
            args_tensor.extend(constants_tensor)

            return f(args_tensor)
        return g
    call = _wrap_func(module.inductor_entry_cpp)


Conclusion
------------

In this tutorial, we introduced a new C++ wrapper in TorchInductor that can speed up your models with just two lines of code changes.
We explained the motivation for this new feature and walked through the easy-to-use API for activating it.
Furthermore, we compared the Inductor-generated code using the default Python wrapper with the new C++ wrapper, on both CPU and GPU,
to showcase the difference between the two wrappers.

This feature is still in the prototype stage. If you have any feature requests or run into any issues, please file a bug report at `GitHub issues <https://github.com/pytorch/pytorch/issues>`_.
10 changes: 10 additions & 0 deletions prototype_source/prototype_index.rst
@@ -185,6 +185,15 @@ Prototype features are not available as part of binary distributions like PyPI o
:link: ../prototype/maskedtensor_adagrad.html
:tags: MaskedTensor

.. Model-Optimization

.. customcarditem::
:header: Inductor Cpp Wrapper Tutorial
:card_description: Speed up your models with Inductor Cpp Wrapper
:image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
:link: ../prototype/inductor_cpp_wrapper_tutorial.html
:tags: Model-Optimization

.. End of tutorial card section

.. raw:: html
@@ -208,6 +217,7 @@ Prototype features are not available as part of binary distributions like PyPI o
prototype/fx_graph_mode_ptq_static.html
prototype/graph_mode_dynamic_bert_tutorial.html
prototype/quantization_in_pytorch_2_0_export_tutorial.html
prototype/inductor_cpp_wrapper_tutorial.html
prototype/ios_gpu_workflow.html
prototype/nnapi_mobilenetv2.html
prototype/tracing_based_selective_build.html