diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index afa5588919..01e6a82926 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -70,3 +70,5 @@ pycocotools
 semilearn==0.3.2
 torchao==0.5.0
 segment_anything==1.0
+torchrec==0.8.0
+fbgemm-gpu==0.8.0
diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json
index 2f1a9933aa..9218ff6669 100644
--- a/.jenkins/metadata.json
+++ b/.jenkins/metadata.json
@@ -28,6 +28,9 @@
     "intermediate_source/model_parallel_tutorial.py": {
         "needs": "linux.16xlarge.nvidia.gpu"
     },
+    "intermediate_source/torchrec_intro_tutorial.py": {
+        "needs": "linux.g5.4xlarge.nvidia.gpu"
+    },
     "recipes_source/torch_export_aoti_python.py": {
         "needs": "linux.g5.4xlarge.nvidia.gpu"
     },
diff --git a/en-wordlist.txt b/en-wordlist.txt
index e69cbaa1a5..bdc49536e2 100644
--- a/en-wordlist.txt
+++ b/en-wordlist.txt
@@ -619,3 +619,32 @@ warmup
 webp
 wsi
 wsis
+Meta's
+RecSys
+TorchRec
+sharding
+TBE
+dtype
+EBC
+sharder
+hyperoptimized
+DMP
+unsharded
+lookups
+KJTs
+amongst
+async
+everytime
+prototyped
+GBs
+HBM
+gloo
+nccl
+Localhost
+gpu
+torchmetrics
+url
+colab
+sharders
+Criteo
+torchrec
\ No newline at end of file
diff --git a/index.rst b/index.rst
index 95c4a8f3ef..bd967142cb 100644
--- a/index.rst
+++ b/index.rst
@@ -846,7 +846,7 @@ Welcome to PyTorch Tutorials
    :header: Introduction to TorchRec
    :card_description: TorchRec is a PyTorch domain library built to provide common sparsity & parallelism primitives needed for large-scale recommender systems.
    :image: _static/img/thumbnails/torchrec.png
-   :link: intermediate/torchrec_tutorial.html
+   :link: intermediate/torchrec_intro_tutorial.html
    :tags: TorchRec,Recommender

 .. customcarditem::
@@ -1180,7 +1180,7 @@ Additional Resources
    :hidden:
    :caption: Recommendation Systems

-   intermediate/torchrec_tutorial
+   intermediate/torchrec_intro_tutorial
    advanced/sharding

 .. toctree::
diff --git a/intermediate_source/torchrec_intro_tutorial.py b/intermediate_source/torchrec_intro_tutorial.py
new file mode 100644
index 0000000000..70b2a5b472
--- /dev/null
+++ b/intermediate_source/torchrec_intro_tutorial.py
@@ -0,0 +1,1109 @@
+"""
+Introduction to TorchRec
+==================================
+
+**TorchRec** is a PyTorch library tailored for building scalable and efficient recommendation systems using embeddings.
+This tutorial guides you through the installation process, introduces the concept of embeddings, and highlights their importance in
+recommendation systems. It offers practical demonstrations on implementing embeddings with PyTorch
+and TorchRec, focusing on handling large embedding tables through distributed training and advanced optimizations.
+
+.. grid:: 2
+
+    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
+       :class-card: card-prerequisites
+
+       * Fundamentals of embeddings and their role in recommendation systems
+       * How to set up TorchRec to manage and implement embeddings in PyTorch environments
+       * Explore advanced techniques for distributing large embedding tables across multiple GPUs
+
+    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
+       :class-card: card-prerequisites
+
+       * PyTorch v2.5 or later with CUDA 11.8 or later
+       * Python 3.9 or later
+       * `FBGEMM `__
+
+
+"""
+
+###############################################
+# Install Dependencies
+# ^^^^^^^^^^^^^^^^^^^^
+#
+# Before running this tutorial in Google Colab or another environment, install the
+# following dependencies:
+#
+# .. code-block:: sh
+#
+#    !pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
+#    !pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
+#    !pip3 install torchmetrics==1.0.3
+#    !pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
+
+
+
+######################################################################
+# Embeddings
+# ~~~~~~~~~~
+#
+# When building recommendation systems, categorical features typically
+# have massive cardinality: posts, users, ads, and so on.
+#
+# In order to represent these entities and model these relationships,
+# **embeddings** are used. In machine learning, **embeddings are vectors
+# of real numbers in a high-dimensional space used to represent meaning in
+# complex data like words, images, or users**.
+#
+# Embeddings in RecSys
+# ~~~~~~~~~~~~~~~~~~~~
+#
+# Now you might wonder, how are these embeddings generated in the first
+# place? Well, embeddings are represented as individual rows in an
+# **Embedding Table**, also referred to as embedding weights. The reason
+# for this is that embeddings or embedding table weights are trained just
+# like all of the other weights of the model, via gradient descent!
+#
+# Embedding tables are simply a large matrix for storing embeddings, with
+# two dimensions (B, N), where:
+#
+# * B is the number of embeddings stored by the table
+# * N is the number of dimensions per embedding (N-dimensional embedding).
+#
+# The inputs to embedding tables represent embedding lookups to retrieve
+# the embedding for a specific index or row. In recommendation systems, such
+# as those used in many large systems, unique IDs are not only used for
+# specific users, but also across entities like posts and ads to serve as
+# lookup indices to respective embedding tables!
+#
+# Embeddings are trained in RecSys through the following process:
+#
+# * **Input/lookup indices are fed into the model as unique IDs**. IDs are
+#   hashed to the total size of the embedding table to prevent issues when
+#   the ID exceeds the number of rows (see the short sketch after the
+#   ``import`` below).
+#
+# * Embeddings are then retrieved and **pooled, such as taking the sum or
+#   mean of the embeddings**. This is required as there can be a variable number of
+#   embeddings per example while the model expects consistent shapes.
+#
+# * The **embeddings are used in conjunction with the rest of the model to
+#   produce a prediction**, such as `Click-Through Rate
+#   (CTR) `__
+#   for an ad.
+#
+# * The loss is calculated with the prediction and the label
+#   for an example, and **all weights of the model are updated through
+#   gradient descent and backpropagation, including the embedding weights**
+#   that were associated with the example.
+#
+# These embeddings are crucial for representing categorical features, such
+# as users, posts, and ads, in order to capture relationships and make
+# good recommendations. The `Deep learning recommendation
+# model `__ (DLRM) paper talks more
+# about the technical details of using embedding tables in RecSys.
+#
+# This tutorial introduces the concept of embeddings, showcases
+# TorchRec-specific modules and data types, and depicts how distributed training
+# works with TorchRec.
+#
+
+import torch
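+
+# A minimal sketch (plain PyTorch, not TorchRec code) of the hash-and-pool
+# steps described above: raw IDs are hashed into the table's bounds, their
+# rows are fetched, and a variable number of embeddings is pooled into one
+# fixed-size vector. All sizes and IDs here are made up for illustration.
+num_rows, dim = 16, 4
+toy_table = torch.rand(num_rows, dim)
+raw_ids = torch.tensor([103, 64027, 4923])  # arbitrary entity IDs
+hashed_ids = raw_ids % num_rows             # keep IDs within table bounds
+pooled_embedding = toy_table[hashed_ids].mean(dim=0)
+print(pooled_embedding.shape)               # torch.Size([4])
+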
+
+
+######################################################################
+# Embeddings in PyTorch
+# ---------------------
+#
+# In PyTorch, we have the following types of embeddings:
+#
+# * :class:`torch.nn.Embedding`: An embedding table where the forward pass returns the
+#   embeddings themselves as is.
+#
+# * :class:`torch.nn.EmbeddingBag`: An embedding table where the forward pass returns
+#   embeddings that are then pooled, for example, by sum or mean, otherwise known
+#   as **Pooled Embeddings**.
+#
+# In this section, we will go over a very brief introduction to performing
+# embedding lookups by passing in indices into the table.
+#
+
+num_embeddings, embedding_dim = 10, 4
+
+# Initialize our embedding table
+weights = torch.rand(num_embeddings, embedding_dim)
+print("Weights:", weights)
+
+# Pass in pre-generated weights just for example; typically weights are randomly initialized
+embedding_collection = torch.nn.Embedding(
+    num_embeddings, embedding_dim, _weight=weights
+)
+embedding_bag_collection = torch.nn.EmbeddingBag(
+    num_embeddings, embedding_dim, _weight=weights
+)
+
+# Print out the tables; we should see the same weights as above
+print("Embedding Collection Table: ", embedding_collection.weight)
+print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
+
+# Lookup rows (ids for embedding ids) from the embedding tables
+# 2D tensor with shape (batch_size, ids for each batch)
+ids = torch.tensor([[1, 3]])
+print("Input row IDS: ", ids)
+
+embeddings = embedding_collection(ids)
+
+# Print out the embedding lookups
+# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
+print("Embedding Collection Results: ")
+print(embeddings)
+print("Shape: ", embeddings.shape)
+
+# ``nn.EmbeddingBag`` default pooling is mean, so the result should be the mean of the batch dimension of the values above
+pooled_embeddings = embedding_bag_collection(ids)
+
+print("Embedding Bag Collection Results: ")
+print(pooled_embeddings)
+print("Shape: ", pooled_embeddings.shape)
+
+# ``nn.EmbeddingBag`` is the same as ``nn.Embedding``, just with pooling (mean, sum, and so on)
+# We can see that the mean of the embeddings from embedding_collection is the same as the output of the embedding_bag_collection
+print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
+
+
+######################################################################
+# Congratulations! Now you have a basic understanding of how to use
+# embedding tables --- one of the foundations of modern recommendation
+# systems! These tables represent entities and their relationships. For
+# example, the relationship between a given user and the pages and posts
+# they have liked.
+#
+
+
+######################################################################
+# TorchRec Features Overview
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# In the section above we learned how to use embedding tables, one of the foundations of
+# modern recommendation systems! These tables represent entities and
+# relationships, such as users, pages, and posts. Given that these
+# entities are always increasing, a **hash** function is typically applied
+# to make sure the IDs are within the bounds of a certain embedding table.
+# However, in order to represent a vast number of entities and reduce hash
+# collisions, these tables can become quite massive (think about the number of ads,
+# for example). In fact, these tables can become so massive that they
+# won't be able to fit on one GPU, even with 80 GB of memory, as the quick
+# arithmetic below shows.
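+
+# Back-of-the-envelope arithmetic behind that claim (hypothetical sizes):
+# a single table with 500 million rows of 64-dimensional FP32 embeddings
+# already exceeds the HBM of an 80 GB GPU.
+num_rows, dim, fp32_bytes = 500_000_000, 64, 4
+table_gb = num_rows * dim * fp32_bytes / 1e9
+print(f"One table: ~{table_gb:.0f} GB")  # ~128 GB
+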
+
+
+######################################################################
+# In order to train models with massive embedding tables, sharding these
+# tables across GPUs is required, which introduces a whole new set of
+# problems and opportunities in parallelism and optimization. Luckily, the
+# TorchRec library has encountered, consolidated, and addressed
+# many of these concerns. TorchRec serves as a **library that provides
+# primitives for large scale distributed embeddings**.
+#
+# Next, we will explore the major features of the TorchRec
+# library. We will start with ``torch.nn.Embedding``, extend it to
+# custom TorchRec modules, explore the distributed training environment by
+# generating a sharding plan for embeddings, look at inherent TorchRec
+# optimizations, and extend the model to be ready for inference in C++.
+# Below is a quick outline of what this section consists of:
+#
+# * TorchRec Modules and Data Types
+# * Distributed Training, Sharding, and Optimizations
+# * Inference
+#
+# Let's begin with importing TorchRec:
+
+import torchrec
+
+
+######################################################################
+# TorchRec Modules and Data Types
+# ----------------------------------
+#
+# This section goes over TorchRec modules and data types, including such
+# entities as ``EmbeddingCollection`` and ``EmbeddingBagCollection``,
+# ``JaggedTensor``, ``KeyedJaggedTensor``, ``KeyedTensor`` and more.
+#
+# From ``EmbeddingBag`` to ``EmbeddingBagCollection``
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We have already explored :class:`torch.nn.Embedding` and :class:`torch.nn.EmbeddingBag`.
+# TorchRec extends these modules by creating collections of embeddings, in
+# other words modules that can hold multiple embedding tables, with
+# ``EmbeddingCollection`` and ``EmbeddingBagCollection``.
+# We will use ``EmbeddingBagCollection`` to represent a group of
+# embedding bags.
+#
+# In the example code below, we create an ``EmbeddingBagCollection`` (EBC)
+# with two embedding bags, one representing **products** and one representing **users**.
+# Each table, ``product_table`` and ``user_table``, holds 4096 embeddings
+# of dimension 64.
+#
+
+ebc = torchrec.EmbeddingBagCollection(
+    device="cpu",
+    tables=[
+        torchrec.EmbeddingBagConfig(
+            name="product_table",
+            embedding_dim=64,
+            num_embeddings=4096,
+            feature_names=["product"],
+            pooling=torchrec.PoolingType.SUM,
+        ),
+        torchrec.EmbeddingBagConfig(
+            name="user_table",
+            embedding_dim=64,
+            num_embeddings=4096,
+            feature_names=["user"],
+            pooling=torchrec.PoolingType.SUM,
+        )
+    ]
+)
+print(ebc.embedding_bags)
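+
+# Conceptually, the EBC above behaves like a keyed group of ``nn.EmbeddingBag``
+# modules, one per table, with feature names routing inputs to the right
+# table. This is only a mental model, not TorchRec's actual implementation.
+conceptual_view = torch.nn.ModuleDict({
+    "product": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
+    "user": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
+})
+print(conceptual_view)
+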
+
+
+######################################################################
+# Let’s inspect the forward method for ``EmbeddingBagCollection`` and the
+# module’s inputs and outputs:
+#
+
+import inspect
+
+# Let's look at the ``EmbeddingBagCollection`` forward method
+# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
+print(inspect.getsource(ebc.forward))
+
+
+######################################################################
+# TorchRec Input/Output Data Types
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# TorchRec has distinct data types for input and output of its modules:
+# ``JaggedTensor``, ``KeyedJaggedTensor``, and ``KeyedTensor``. Now you
+# might ask, why create new data types to represent sparse features? To
+# answer that question, we must understand how sparse features are
+# represented in code.
+#
+# Sparse features are otherwise known as ``id_list_feature`` and
+# ``id_score_list_feature``, and are the **IDs** that will be used as
+# indices to an embedding table to retrieve the embedding for that ID. To
+# give a very simple example, imagine a single sparse feature being the ads
+# that a user interacted with. The input itself would be a set of ad IDs
+# that a user interacted with, and the embeddings retrieved would be a
+# semantic representation of those ads. The tricky part of representing
+# these features in code is that in each input example, **the number of
+# IDs is variable**. One day a user might have interacted with only one ad
+# while the next day they interact with three.
+#
+# A simple representation is shown below, where we have a ``lengths``
+# tensor denoting how many indices are in an example for a batch and a
+# ``values`` tensor containing the indices themselves.
+#
+
+# Batch Size 2
+# 1 ID in example 1, 2 IDs in example 2
+id_list_feature_lengths = torch.tensor([1, 2])
+
+# Values (IDs) tensor: ID 5 is in example 1, IDs 7 and 1 are in example 2
+id_list_feature_values = torch.tensor([5, 7, 1])
+
+
+######################################################################
+# Next, let's look at the offsets as well as what is contained in each batch
+#
+
+# Lengths can be converted to offsets for easy indexing of values
+id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
+
+print("Offsets: ", id_list_feature_offsets)
+print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
+print(
+    "Second Batch: ",
+    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
+)
+
+from torchrec import JaggedTensor
+
+# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
+jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
+
+# Automatically compute offsets from lengths
+print("Offsets: ", jt.offsets())
+
+# Convert to list of values
+print("List of Values: ", jt.to_dense())
+
+# ``__str__`` representation
+print(jt)
+
+from torchrec import KeyedJaggedTensor
+
+# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
+# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors``, one per feature
+# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
+
+product_jt = JaggedTensor(
+    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
+)
+user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
+
+# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
+kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
+
+# Look at our feature keys for the ``KeyedJaggedTensor``
+print("Keys: ", kjt.keys())
+
+# Look at the overall lengths for the ``KeyedJaggedTensor``
+print("Lengths: ", kjt.lengths())
+
+# Look at all values for ``KeyedJaggedTensor``
+print("Values: ", kjt.values())
+
+# Can convert ``KeyedJaggedTensor`` to dictionary representation
+print("to_dict: ", kjt.to_dict())
+
+# ``KeyedJaggedTensor`` string representation
+print(kjt)
+
+# Q2: What are the offsets for the ``KeyedJaggedTensor``?
+
+# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
+result = ebc(kjt)
+result
+
+# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
+print(result.keys())
+
+# The result's shape is [2, 128]: 2 is the batch size (reread the previous section if you need a
+# refresher on how the batch size is determined) and 128 is the embedding dimension. Where we
+# initialized the ``EmbeddingBagCollection``, the two tables "product" and "user" each had a
+# dimension of 64, meaning the embeddings for both features are of size 64, and 64 + 64 = 128.
+print(result.values().shape)
+
+# Nice to_dict method to determine the embeddings that belong to each feature
+result_dict = result.to_dict()
+for key, embedding in result_dict.items():
+    print(key, embedding.shape)
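+
+# An optional sanity check on that [2, 128] layout: slicing the ``KeyedTensor``
+# values into two 64-column halves should recover the per-feature embeddings,
+# in the order given by ``result.keys()``. This assumes the key order printed
+# above is ("product", "user").
+first_half, second_half = torch.split(result.values(), [64, 64], dim=1)
+print(torch.equal(result_dict["product"], first_half))
+print(torch.equal(result_dict["user"], second_half))
+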
+
+
+######################################################################
+# Congrats! You now understand TorchRec modules and data types.
+# Give yourself a pat on the back for making it this far. Next, we will
+# learn about distributed training and sharding.
+#
+
+
+######################################################################
+# Distributed Training and Sharding
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Now that we have a grasp on TorchRec modules and data types, it's time
+# to take it to the next level.
+#
+# Remember, the main purpose of TorchRec is to provide primitives for
+# distributed embeddings. So far, we've only worked with embedding tables
+# on a single device. This has been possible given how small the embedding tables
+# have been, but in a production setting this isn't generally the case.
+# Embedding tables often get massive, to the point where one table can't fit on a single
+# GPU, creating the requirement for multiple devices and a distributed
+# environment.
+#
+# In this section, we will explore setting up a distributed environment,
+# exactly how actual production training is done, and explore sharding
+# embedding tables, all with TorchRec.
+#
+# **This section will also only use 1 GPU, though it will be treated in a
+# distributed fashion. This is only a limitation for training, as training
+# has a process per GPU. Inference does not run into this requirement.**
+#
+# In the example code below, we set up our PyTorch distributed environment.
+#
+# .. warning::
+#    If you are running this in Google Colab, you can only call this cell once;
+#    calling it again will cause an error, as you can only initialize the process
+#    group once.
+
+import os
+
+import torch.distributed as dist
+
+# Set up environment variables for distributed training
+# RANK is which GPU we are on, default 0
+os.environ["RANK"] = "0"
+# How many devices are in our "world"; a colab notebook can only handle 1 process
+os.environ["WORLD_SIZE"] = "1"
+# Localhost as we are training locally
+os.environ["MASTER_ADDR"] = "localhost"
+# Port for distributed training
+os.environ["MASTER_PORT"] = "29500"
+
+# nccl backend is for GPUs, gloo is for CPUs
+dist.init_process_group(backend="gloo")
+
+print(f"Distributed environment initialized: {dist}")
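+
+# Optional: confirm what we just initialized, using standard
+# torch.distributed introspection calls (world size, our rank, and the backend).
+print(dist.get_world_size(), dist.get_rank(), dist.get_backend())
+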
+
+
+######################################################################
+# Distributed Embeddings
+# ~~~~~~~~~~~~~~~~~~~~~~
+#
+# We have already worked with the main TorchRec module:
+# ``EmbeddingBagCollection``. We have examined how it works along with how
+# data is represented in TorchRec. However, we have not yet explored one
+# of the main parts of TorchRec, which is **distributed embeddings**.
+#
+# GPUs are the most popular choice for ML workloads by far today, as they
+# are able to do magnitudes more floating point operations/s
+# (`FLOPs `__) than CPU. However,
+# GPUs come with the limitation of scarce fast memory (HBM, which is
+# analogous to RAM for CPU), typically ~10s of GBs.
+#
+# A RecSys model can contain embedding tables that far exceed the memory
+# limit for 1 GPU, hence the need for distribution of the embedding tables
+# across multiple GPUs, otherwise known as **model parallel**. On the
+# other hand, **data parallel** is where the entire model is replicated on
+# each GPU, with each GPU taking in a distinct batch of data for
+# training and syncing gradients on the backwards pass.
+#
+# Parts of the model that **require less compute but more memory
+# (embeddings) are distributed with model parallel**, while parts that
+# **require more compute and less memory (dense layers, MLP, etc.) are
+# distributed with data parallel**.
+#
+# Sharding
+# ~~~~~~~~
+#
+# In order to distribute an embedding table, we split up the embedding
+# table into parts and place those parts onto different devices, also
+# known as “sharding”.
+#
+# There are many ways to shard embedding tables. The most common ways are
+# the following (a toy illustration follows this list):
+#
+# * Table-Wise: the table is placed entirely onto one device
+# * Column-Wise: columns of embedding tables are sharded
+# * Row-Wise: rows of embedding tables are sharded
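+
+# A toy illustration (plain PyTorch, not TorchRec internals) of how one
+# weight matrix could be split under each scheme across 2 hypothetical devices:
+toy_table = torch.rand(8, 4)
+table_wise = [toy_table]                              # whole table on one device
+column_wise = list(torch.chunk(toy_table, 2, dim=1))  # each device gets half the columns
+row_wise = list(torch.chunk(toy_table, 2, dim=0))     # each device gets half the rows
+print([t.shape for t in column_wise], [t.shape for t in row_wise])
+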
+
+
+######################################################################
+# Sharded Modules
+# ~~~~~~~~~~~~~~~
+#
+# While all of this seems like a lot to deal with and implement, you're in
+# luck. **TorchRec provides all the primitives for easy distributed
+# training and inference**! In fact, TorchRec modules have two corresponding
+# classes for working with any TorchRec module in a distributed
+# environment:
+#
+# * **The module sharder**: This class exposes a ``shard`` API
+#   that handles sharding a TorchRec module, producing a sharded module.
+#
+#   * For ``EmbeddingBagCollection``, the sharder is `EmbeddingBagCollectionSharder `__
+#
+# * **The sharded module**: This class is a sharded variant of a TorchRec module.
+#   It has the same input/output as the regular TorchRec module, but it is much
+#   more optimized and works in a distributed environment.
+#
+#   * For ``EmbeddingBagCollection``, the sharded variant is `ShardedEmbeddingBagCollection `__
+#
+# Every TorchRec module has an unsharded and sharded variant.
+#
+# * The unsharded version is meant to be prototyped and experimented with.
+# * The sharded version is meant to be used in a distributed environment for
+#   distributed training and inference.
+#
+# The sharded versions of TorchRec modules, for example
+# ``EmbeddingBagCollection``, will handle everything that is needed for Model
+# Parallelism, such as communication between GPUs for distributing
+# embeddings to the correct GPUs.
+#
+
+# Refresher of our ``EmbeddingBagCollection`` module
+ebc
+
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
+from torchrec.distributed.types import ShardingEnv
+
+# Corresponding sharder for ``EmbeddingBagCollection`` module
+sharder = EmbeddingBagCollectionSharder()
+
+# ``ProcessGroup`` from torch.distributed initialized 2 cells above
+pg = dist.GroupMember.WORLD
+assert pg is not None, "Process group is not initialized"
+
+print(f"Process Group: {pg}")
+
+
+######################################################################
+# Planner
+# ~~~~~~~
+#
+# Before we can show how sharding works, we must know about the
+# **planner**, which helps us determine the best sharding configuration.
+#
+# Given a number of embedding tables and a number of ranks, there are many
+# different sharding configurations that are possible. For example, given
+# 2 embedding tables and 2 GPUs, you can:
+#
+# * Place 1 table on each GPU
+# * Place both tables on a single GPU and no tables on the other
+# * Place certain rows and columns on each GPU
+#
+# Given all of these possibilities, we typically want a sharding
+# configuration that is optimal for performance.
+#
+# That is where the planner comes in. The planner is able to determine,
+# given the number of embedding tables and the number of GPUs, the optimal
+# configuration. It turns out this is incredibly difficult to do manually,
+# with tons of factors that engineers have to consider to ensure an
+# optimal sharding plan. Luckily, TorchRec provides an automatic planner.
+#
+# The TorchRec planner:
+#
+# * Assesses memory constraints of hardware
+# * Estimates compute based on memory fetches as embedding lookups
+# * Addresses data specific factors
+# * Considers other hardware specifics like bandwidth to generate an optimal sharding plan
+#
+# In order to take into consideration all these variables, the TorchRec
+# planner can take in `various amounts of data for embedding tables,
+# constraints, hardware information, and
+# topology `__
+# to aid in generating the optimal sharding plan for a model, which is
+# routinely provided across stacks.
+#
+# To learn more about sharding, see our `sharding
+# tutorial `__.
+#
+
+# In our case, 1 GPU and compute on CUDA device
+planner = EmbeddingShardingPlanner(
+    topology=Topology(
+        world_size=1,
+        compute_device="cuda",
+    )
+)
+
+# Run planner to get plan for sharding
+plan = planner.collective_plan(ebc, [sharder], pg)
+
+print(f"Sharding Plan generated: {plan}")
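+
+# A hedged sketch: the planner can also be nudged with per-table constraints,
+# for example restricting "product_table" to table-wise sharding. The exact
+# constraint fields shown here are assumptions; check the planner docs for
+# your TorchRec version.
+from torchrec.distributed.planner.types import ParameterConstraints
+from torchrec.distributed.types import ShardingType
+
+constrained_planner = EmbeddingShardingPlanner(
+    topology=Topology(world_size=1, compute_device="cuda"),
+    constraints={
+        "product_table": ParameterConstraints(
+            sharding_types=[ShardingType.TABLE_WISE.value]
+        ),
+    },
+)
+constrained_plan = constrained_planner.collective_plan(ebc, [sharder], pg)
+print(f"Constrained plan: {constrained_plan}")
+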
+
+
+######################################################################
+# Planner Result
+# ~~~~~~~~~~~~~~
+#
+# As you can see above, when running the planner there is quite a bit of output.
+# We can see a lot of stats being calculated, along with where our
+# tables end up being placed.
+#
+# The result of running the planner is a static plan, which can be reused
+# for sharding! This allows sharding to be static for production models
+# instead of determining a new sharding plan every time. Below, we use the
+# sharding plan to finally generate our ``ShardedEmbeddingBagCollection``.
+#
+
+# The static plan that was generated
+plan
+
+env = ShardingEnv.from_process_group(pg)
+
+# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
+sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
+
+print(f"Sharded EBC Module: {sharded_ebc}")
+
+
+######################################################################
+# GPU Training with ``LazyAwaitable``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Remember that TorchRec is a highly optimized library for distributed
+# embeddings. A concept that TorchRec introduces to enable higher
+# performance for training on GPU is a
+# `LazyAwaitable `__.
+# You will see ``LazyAwaitable`` types as outputs of various sharded
+# TorchRec modules. All a ``LazyAwaitable`` type does is delay calculating some
+# result as long as possible, which it does by acting like an async type.
+#
+
+from typing import List
+
+from torchrec.distributed.types import LazyAwaitable
+
+
+# Demonstrate a ``LazyAwaitable`` type:
+class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
+    def __init__(self, size: List[int]) -> None:
+        super().__init__()
+        self._size = size
+
+    def _wait_impl(self) -> torch.Tensor:
+        return torch.ones(self._size)
+
+
+awaitable = ExampleAwaitable([3, 2])
+awaitable.wait()
+
+kjt = kjt.to("cuda")
+output = sharded_ebc(kjt)
+# The output of our sharded ``EmbeddingBagCollection`` module is an ``Awaitable``?
+print(output)
+
+kt = output.wait()
+# Now we have our ``KeyedTensor`` after calling ``.wait()``
+# If you are confused as to why we have a ``KeyedTensor`` output,
+# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
+print(type(kt))
+
+print(kt.keys())
+
+print(kt.values().shape)
+
+# Same output format as unsharded ``EmbeddingBagCollection``
+result_dict = kt.to_dict()
+for key, embedding in result_dict.items():
+    print(key, embedding.shape)
+
+
+######################################################################
+# Anatomy of Sharded TorchRec modules
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We have now successfully sharded an ``EmbeddingBagCollection`` given a
+# sharding plan that we generated! The sharded module has common APIs from
+# TorchRec which abstract away distributed communication/compute amongst
+# multiple GPUs. In fact, these APIs are highly optimized for performance
+# in training and inference. **Below are the three common APIs for
+# distributed training/inference** that are provided by TorchRec:
+#
+# * ``input_dist``: Handles distributing inputs from GPU to GPU.
+# * ``lookups``: Does the actual embedding lookup in an optimized,
+#   batched manner using FBGEMM TBE (more on this later).
+# * ``output_dist``: Handles distributing outputs from GPU to GPU.
+#
+# The distribution of inputs and outputs is done through `NCCL
+# Collectives `__,
+# namely
+# `All-to-Alls `__,
+# in which all GPUs send and receive data to and from one another.
+# TorchRec interfaces with PyTorch distributed for collectives and
+# provides clean abstractions to the end users, removing the concern for
+# the lower level details.
+#
+# The backwards pass does all of these collectives but in the reverse
+# order for distribution of gradients. ``input_dist``, ``lookup``, and
+# ``output_dist`` all depend on the sharding scheme. Since we sharded in a
+# table-wise fashion, these APIs are modules that are constructed by
+# `TwPooledEmbeddingSharding `__.
+#
+
+sharded_ebc
+
+# Distribute input KJTs to all other GPUs and receive KJTs
+sharded_ebc._input_dists
+
+# Distribute output embeddings to all other GPUs and receive embeddings
+sharded_ebc._output_dists
+
+
+######################################################################
+# Optimizing Embedding Lookups
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# In performing lookups for a collection of embedding tables, a trivial
+# solution would be to iterate through all the ``nn.EmbeddingBags`` and do
+# a lookup per table. This is exactly what the standard, unsharded
+# ``EmbeddingBagCollection`` does. However, while this solution
+# is simple, it is extremely slow.
+#
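+
+# A hedged sketch of that naive per-table loop, using toy ``nn.EmbeddingBag``
+# modules on CPU (not TorchRec internals): one separate lookup per feature,
+# which is exactly what table batching avoids.
+naive_bags = {
+    "product": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
+    "user": torch.nn.EmbeddingBag(4096, 64, mode="sum"),
+}
+for feature, jt_f in kjt.to("cpu").to_dict().items():
+    lengths = jt_f.lengths()
+    # ``nn.EmbeddingBag`` wants per-bag start offsets, e.g. lengths [3, 1] -> [0, 3]
+    starts = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)[:-1]])
+    print(feature, naive_bags[feature](jt_f.values(), starts).shape)
+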
+
+######################################################################
+# `FBGEMM `__ is a
+# library that provides highly optimized GPU operators (otherwise known as kernels).
+# One of these operators, known as **Table Batched
+# Embedding** (TBE), provides two major optimizations:
+#
+# - Table batching, which allows you to look up multiple embeddings with
+#   one kernel call.
+# - Optimizer Fusion, which allows the module to update itself given the
+#   canonical PyTorch optimizers and arguments.
+#
+# The ``ShardedEmbeddingBagCollection`` uses the FBGEMM TBE as the lookup
+# instead of traditional ``nn.EmbeddingBags`` for optimized embedding
+# lookups.
+#
+
+sharded_ebc._lookups
+
+
+######################################################################
+# ``DistributedModelParallel``
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# We have now explored sharding a single ``EmbeddingBagCollection``! We were
+# able to take the ``EmbeddingBagCollectionSharder`` and use the unsharded
+# ``EmbeddingBagCollection`` to generate a
+# ``ShardedEmbeddingBagCollection`` module. This workflow is fine, but
+# typically when implementing model parallel,
+# `DistributedModelParallel `__
+# (DMP) is used as the standard interface. When wrapping your model (in
+# our case ``ebc``) with DMP, the following will occur:
+#
+# 1. Decide how to shard the model. DMP will collect the available
+#    sharders and come up with a plan of the optimal way to shard the
+#    embedding table(s) (for example, ``EmbeddingBagCollection``)
+# 2. Actually shard the model. This includes allocating memory for each
+#    embedding table on the appropriate device(s).
+#
+# DMP takes in everything that we've just experimented with, like a static
+# sharding plan, a list of sharders, etc. However, it also has some nice
+# defaults to seamlessly shard a TorchRec model. In this toy example,
+# since we have two embedding tables and one GPU, TorchRec will place both
+# on the single GPU.
+#
+
+ebc
+
+model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
+
+out = model(kjt)
+out.wait()
+
+model
+
+
+######################################################################
+# Sharding Best Practices
+# ~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Currently, our configuration is only sharding on 1 GPU (or rank), which
+# is trivial: just place all the tables on 1 GPU's memory. However, in real
+# production use cases, embedding tables are **typically sharded on
+# hundreds of GPUs**, with different sharding methods such as table-wise,
+# row-wise, and column-wise. It is incredibly important to determine a
+# proper sharding configuration (to prevent out of memory issues) while
+# keeping it balanced not only in terms of memory but also compute for
+# optimal performance.
+#
+
+
+######################################################################
+# Adding in the Optimizer
+# ~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Remember that TorchRec modules are hyperoptimized for large scale
+# distributed training. An important optimization concerns the
+# optimizer.
+#
+# TorchRec modules provide a seamless API to fuse the
+# backwards pass and optimizer step in training, providing a significant
+# optimization in performance and decreasing the memory used, alongside
+# granularity in assigning distinct optimizers to distinct model
+# parameters.
+#
+# Optimizer Classes
+# ^^^^^^^^^^^^^^^^^
+#
+# TorchRec uses ``CombinedOptimizer``, which contains a collection of
+# ``KeyedOptimizers``. A ``CombinedOptimizer`` effectively makes it easy
+# to handle multiple optimizers for various sub groups in the model. A
+# ``KeyedOptimizer`` extends ``torch.optim.Optimizer`` and is
+# initialized through a dictionary of the parameters it exposes.
+# Each ``TBE`` module in an ``EmbeddingBagCollection`` will have its own
+# ``KeyedOptimizer``, and these combine into one ``CombinedOptimizer``.
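+
+# A quick, optional look at the fused optimizer that DMP assembled for the
+# model we created above: a ``CombinedOptimizer`` wrapping the per-TBE
+# ``KeyedOptimizers``.
+print(type(model.fused_optimizer))
+print(model.fused_optimizer)
+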
+
+
+######################################################################
+# Fused optimizer in TorchRec
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Using ``DistributedModelParallel``, the **optimizer is fused, which
+# means that the optimizer update is done in the backward pass**. This is an
+# optimization in TorchRec and FBGEMM, where the embedding
+# gradients are not materialized and are instead applied directly to the parameters.
+# This brings significant memory savings, as embedding gradients are
+# typically the same size as the parameters themselves.
+#
+# You can, however, choose to make the optimizer ``dense``, which does not
+# apply this optimization and lets you inspect the embedding gradients or
+# apply computations to them as you wish. A dense optimizer in this case
+# would be your `canonical PyTorch model training loop with
+# optimizer. `__
+#
+# Once the optimizer is created through ``DistributedModelParallel``, you
+# still need to manage an optimizer for the other parameters not
+# associated with TorchRec embedding modules. To find the other
+# parameters,
+# use ``in_backward_optimizer_filter(model.named_parameters())``.
+# Apply an optimizer to those parameters as you would a normal Torch
+# optimizer and combine this and the ``model.fused_optimizer`` into one
+# ``CombinedOptimizer`` that you can use in your training loop to
+# ``zero_grad`` and ``step`` through.
+#
+# Adding an Optimizer to ``EmbeddingBagCollection``
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# We will do this in two ways, which are equivalent, but give you options
+# depending on your preferences:
+#
+# 1. Passing optimizer kwargs through ``fused_params`` in the sharder.
+# 2. Through ``apply_optimizer_in_backward``, which converts the optimizer
+#    parameters to ``fused_params`` to pass to the ``TBE`` in the ``EmbeddingBagCollection`` or ``EmbeddingCollection``.
+#
+
+# Option 1: Passing optimizer kwargs through fused parameters
+from torchrec.optim.optimizers import in_backward_optimizer_filter
+from fbgemm_gpu.split_embedding_configs import EmbOptimType
+
+
+# We initialize the sharder with the optimizer settings below as ``fused_params``
+fused_params = {
+    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+    "learning_rate": 0.02,
+    "eps": 0.002,
+}
+
+# Initialize sharder with ``fused_params``
+sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
+
+# We'll use the same plan and unsharded EBC as before, but this time with our new sharder
+sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
+
+# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
+# We can also look at the TBE logs in the cell output to see that our new optimizer is indeed being applied.
+print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
+print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
+
+print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
+
+from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
+import copy
+# Option 2: Applying optimizer through apply_optimizer_in_backward
+# Note: we need to call apply_optimizer_in_backward on the unsharded model first and then shard it
+
+# We can achieve the same result as we did with Option 1
+ebc_apply_opt = copy.deepcopy(ebc)
+optimizer_kwargs = {"lr": 0.5}
+
+for name, param in ebc_apply_opt.named_parameters():
+    print(f"{name=}")
+    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
+
+sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
+
+# Now when we print the optimizer, we will see our new learning rate; you can verify momentum through the TBE logs as well, if outputted
+print(sharded_ebc_apply_opt.fused_optimizer)
+print(type(sharded_ebc_apply_opt.fused_optimizer))
+
+# We can also check through the filter for other parameters that aren't associated with the "fused" optimizer(s)
+# Practically, these are just non-TorchRec module parameters. Since our module is just a TorchRec EBC,
+# there are no other parameters that aren't associated with TorchRec
+print("Non Fused Model Parameters:")
+print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
+
+# Here we do a dummy backwards call and see that parameter updates for fused
+# optimizers happen as a result of the backward pass
+
+ebc_output = sharded_ebc_fused_params(kjt).wait().values()
+loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
+print(f"First Iteration Loss: {loss}")
+
+loss.backward()
+
+ebc_output = sharded_ebc_fused_params(kjt).wait().values()
+loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
+# We don't call an optimizer.step(), so for the loss to have changed here,
+# the gradients must have been applied somehow, which is exactly what the
+# fused optimizer automatically handles for us
+print(f"Second Iteration Loss: {loss}")
+
+
+######################################################################
+# Inference
+# ~~~~~~~~~~~~
+#
+# Now that we are able to train distributed embeddings, how can we take
+# the trained model and optimize it for inference? Inference is typically
+# very sensitive to **performance and size of the model**. Running just
+# the trained model in a Python environment is incredibly inefficient.
+# There are two key differences between inference and training
+# environments (a quick arithmetic sketch follows this list):
+#
+# * **Quantization**: Inference models are typically
+#   quantized, where model parameters lose precision for lower latency in
+#   predictions and reduced model size, for example from FP32 (4 bytes) per
+#   embedding weight in the trained model to INT8 (1 byte). This is also
+#   necessary given the vast scale of embedding tables, as we want to use as
+#   few devices as possible for inference to minimize latency.
+#
+# * **C++ environment**: Inference latency is very important, so in order to ensure
+#   ample performance, the model is typically run in a C++ environment,
+#   which also covers situations where we don't have a Python runtime, like on
+#   device.
+#
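+
+# Quick arithmetic on the quantization claim above (toy sizes from our EBC):
+# per-weight storage drops 4x going from FP32 to INT8, ignoring the small
+# per-row scale/offset metadata that quantized tables also store.
+fp32_mb = 4096 * 64 * 4 / 1e6
+int8_mb = 4096 * 64 * 1 / 1e6
+print(f"{fp32_mb:.2f} MB -> {int8_mb:.2f} MB per table")
+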
+
+######################################################################
+# TorchRec provides primitives for converting a TorchRec model into being
+# inference ready with:
+#
+# * APIs for quantizing the model, introducing
+#   optimizations automatically with FBGEMM TBE
+# * Sharding embeddings for distributed inference
+# * Compiling the model to `TorchScript `__
+#   (compatible with C++)
+#
+# In this section, we will go over this entire workflow of:
+#
+# * Quantizing the model
+# * Sharding the quantized model
+# * Compiling the sharded quantized model into TorchScript
+#
+
+ebc
+
+class InferenceModule(torch.nn.Module):
+    def __init__(self, ebc: torchrec.EmbeddingBagCollection):
+        super().__init__()
+        self.ebc_ = ebc
+
+    def forward(self, kjt: KeyedJaggedTensor):
+        return self.ebc_(kjt)
+
+module = InferenceModule(ebc)
+for name, param in module.named_parameters():
+    # Here, the parameters should still be FP32, as we are using a standard EBC
+    # FP32 is the default and is regularly used for training
+    print(name, param.shape, param.dtype)
+
+
+######################################################################
+# Quantization
+# ~~~~~~~~~~~~
+#
+# As you can see above, the normal EBC contains embedding table weights in
+# FP32 precision (32 bits for each weight). Here, we will use the TorchRec
+# inference library to quantize the embedding weights of the model to INT8.
+#
+
+from torch import quantization as quant
+from torchrec.modules.embedding_configs import QuantConfig
+from torchrec.quant.embedding_modules import (
+    EmbeddingBagCollection as QuantEmbeddingBagCollection,
+)
+
+
+quant_dtype = torch.int8
+
+
+qconfig = QuantConfig(
+    # dtype of the result of the embedding lookup, post activation
+    # torch.float generally for compatibility with the rest of the model,
+    # as the rest of the model here usually isn't quantized
+    activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
+    # quantized type for embedding weights, aka parameters to actually quantize
+    weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
+)
+qconfig_spec = {
+    # Map of module type to qconfig
+    torchrec.EmbeddingBagCollection: qconfig,
+}
+mapping = {
+    # Map of module type to quantized module type
+    torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
+}
+
+
+module = InferenceModule(ebc)
+
+# Quantize the module
+qebc = quant.quantize_dynamic(
+    module,
+    qconfig_spec=qconfig_spec,
+    mapping=mapping,
+    inplace=False,
+)
+
+
+print(f"Quantized EBC: {qebc}")
+
+kjt = kjt.to("cpu")
+
+qebc(kjt)
+
+# Once quantized, the weights go from parameters to buffers, as they are no longer trainable
+for name, buffer in qebc.named_buffers():
+    # The shapes of the tables should be the same, but the dtype should now be int8
+    # post quantization
+    print(name, buffer.shape, buffer.dtype)
+
+
+######################################################################
+# Shard
+# ~~~~~
+#
+# Here we perform sharding of the TorchRec quantized model. This is to
+# ensure we are using the performant module through FBGEMM TBE. Here we
+# are using one device to be consistent with training (1 TBE).
+#
+
+from torchrec import distributed as trec_dist
+from torchrec.distributed.shard import _shard_modules
+
+
+sharded_qebc = _shard_modules(
+    module=qebc,
+    device=torch.device("cpu"),
+    env=trec_dist.ShardingEnv.from_local(
+        1,
+        0,
+    ),
+)
+
+
+print(f"Sharded Quantized EBC: {sharded_qebc}")
+
+sharded_qebc(kjt)
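+
+# Optional check: compare per-table storage before and after quantization.
+# This tallies raw tensor bytes only; quantized tables also carry a small
+# amount of scale/offset metadata, so treat the numbers as approximate.
+fp32_bytes = sum(p.numel() * p.element_size() for p in module.parameters())
+int8_bytes = sum(b.numel() * b.element_size() for b in qebc.buffers())
+print(f"FP32 parameter bytes: {fp32_bytes}, quantized buffer bytes: {int8_bytes}")
+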
+
+
+######################################################################
+# Compilation
+# ~~~~~~~~~~~
+#
+# Now we have the optimized eager TorchRec inference model. The next step
+# is to ensure that this model is loadable in C++, as currently it is only
+# runnable in a Python runtime.
+#
+# The recommended method of compilation at Meta is twofold: `torch.fx
+# tracing `__ (generating an
+# intermediate representation of the model) and converting the result to
+# TorchScript, where TorchScript is C++ compatible.
+#
+
+from torchrec.fx import Tracer
+
+
+tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
+
+graph = tracer.trace(sharded_qebc)
+gm = torch.fx.GraphModule(sharded_qebc, graph)
+
+print("Graph Module Created!")
+
+print(gm.code)
+
+scripted_gm = torch.jit.script(gm)
+print("Scripted Graph Module Created!")
+
+print(scripted_gm.code)
+
+
+######################################################################
+# Conclusion
+# ^^^^^^^^^^
+#
+# In this tutorial, you have gone from training a distributed RecSys model all the way
+# to making it inference ready. The `TorchRec repo
+# `__ has a
+# full example of how to load a TorchRec TorchScript model into C++ for
+# inference.
+#
+
+
+######################################################################
+# See Also
+# --------------
+#
+# For more information, please see our
+# `dlrm `__
+# example, which includes multinode training on the Criteo 1TB
+# dataset using the methods described in `Deep Learning Recommendation Model
+# for Personalization and Recommendation Systems `__.
+#
diff --git a/intermediate_source/torchrec_tutorial.rst b/intermediate_source/torchrec_tutorial.rst
deleted file mode 100644
index 6a450b1659..0000000000
--- a/intermediate_source/torchrec_tutorial.rst
+++ /dev/null
@@ -1,244 +0,0 @@
-Introduction to TorchRec
-========================
-
-.. tip::
-   To get the most of this tutorial, we suggest using this
-   `Colab Version `__.
-   This will allow you to experiment with the information presented below.
-
-Follow along with the video below or on `youtube `__.
-
-.. raw:: html
-
-   
- -
- -When building recommendation systems, we frequently want to represent -entities like products or pages with embeddings. For example, see Meta -AI’s `Deep learning recommendation -model `__, or DLRM. As the number of -entities grow, the size of the embedding tables can exceed a single -GPU’s memory. A common practice is to shard the embedding table across -devices, a type of model parallelism. To that end, TorchRec introduces -its primary API -called |DistributedModelParallel|_, -or DMP. Like PyTorch’s DistributedDataParallel, DMP wraps a model to -enable distributed training. - -Installation ------------- - -Requirements: python >= 3.7 - -We highly recommend CUDA when using TorchRec (If using CUDA: cuda >= 11.0). - - -.. code:: shell - - # install pytorch with cudatoolkit 11.3 - conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y - # install TorchRec - pip3 install torchrec-nightly - - -Overview --------- - -This tutorial will cover three pieces of TorchRec: the ``nn.module`` |EmbeddingBagCollection|_, the |DistributedModelParallel|_ API, and -the datastructure |KeyedJaggedTensor|_. - - -Distributed Setup -~~~~~~~~~~~~~~~~~ - -We setup our environment with torch.distributed. For more info on -distributed, see this -`tutorial `__. - -Here, we use one rank (the colab process) corresponding to our 1 colab -GPU. - -.. code:: python - - import os - import torch - import torchrec - import torch.distributed as dist - - os.environ["RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "29500" - - # Note - you will need a V100 or A100 to run tutorial as as! - # If using an older GPU (such as colab free K80), - # you will need to compile fbgemm with the appripriate CUDA architecture - # or run with "gloo" on CPUs - dist.init_process_group(backend="nccl") - - -From EmbeddingBag to EmbeddingBagCollection -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -PyTorch represents embeddings through |torch.nn.Embedding|_ and |torch.nn.EmbeddingBag|_. -EmbeddingBag is a pooled version of Embedding. - -TorchRec extends these modules by creating collections of embeddings. We -will use |EmbeddingBagCollection|_ to represent a group of EmbeddingBags. - -Here, we create an EmbeddingBagCollection (EBC) with two embedding bags. -Each table, ``product_table`` and ``user_table``, is represented by a 64 -dimension embedding of size 4096. Note how we initially allocate the EBC -on device “meta”. This will tell EBC to not allocate memory yet. - -.. code:: python - - ebc = torchrec.EmbeddingBagCollection( - device="meta", - tables=[ - torchrec.EmbeddingBagConfig( - name="product_table", - embedding_dim=64, - num_embeddings=4096, - feature_names=["product"], - pooling=torchrec.PoolingType.SUM, - ), - torchrec.EmbeddingBagConfig( - name="user_table", - embedding_dim=64, - num_embeddings=4096, - feature_names=["user"], - pooling=torchrec.PoolingType.SUM, - ) - ] - ) - - -DistributedModelParallel -~~~~~~~~~~~~~~~~~~~~~~~~ - -Now, we’re ready to wrap our model with |DistributedModelParallel|_ (DMP). Instantiating DMP will: - -1. Decide how to shard the model. DMP will collect the available - ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the - embedding table(s) (i.e., the EmbeddingBagCollection). -2. Actually shard the model. This includes allocating memory for each - embedding table on the appropriate device(s). - -In this toy example, since we have two EmbeddingTables and one GPU, -TorchRec will place both on the single GPU. - -.. 
code:: python - - model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda")) - print(model) - print(model.plan) - - -Query vanilla nn.EmbeddingBag with input and offsets -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We query |nn.Embedding|_ and |nn.EmbeddingBag|_ -with ``input`` and ``offsets``. Input is a 1-D tensor containing the -lookup values. Offsets is a 1-D tensor where the sequence is a -cumulative sum of the number of values to pool per example. - -Let’s look at an example, recreating the product EmbeddingBag above: - -:: - - |------------| - | product ID | - |------------| - | [101, 202] | - | [] | - | [303] | - |------------| - -.. code:: python - - product_eb = torch.nn.EmbeddingBag(4096, 64) - product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2])) - - -Representing minibatches with KeyedJaggedTensor -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -We need an efficient representation of multiple examples of an arbitrary -number of entity IDs per feature per example. In order to enable this -“jagged” representation, we use the TorchRec datastructure -|KeyedJaggedTensor|_ (KJT). - -Let’s take a look at how to lookup a collection of two embedding -bags, “product” and “user”. Assume the minibatch is made up of three -examples for three users. The first of which has two product IDs, the -second with none, and the third with one product ID. - -:: - - |------------|------------| - | product ID | user ID | - |------------|------------| - | [101, 202] | [404] | - | [] | [505] | - | [303] | [606] | - |------------|------------| - -The query should be: - -.. code:: python - - mb = torchrec.KeyedJaggedTensor( - keys = ["product", "user"], - values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(), - lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(), - ) - - print(mb.to(torch.device("cpu"))) - - -Note that the KJT batch size is -``batch_size = len(lengths)//len(keys)``. In the above example, -batch_size is 3. - - - -Putting it all together, querying our distributed model with a KJT minibatch -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Finally, we can query our model using our minibatch of products and -users. - -The resulting lookup will contain a KeyedTensor, where each key (or -feature) contains a 2D tensor of size 3x64 (batch_size x embedding_dim). - -.. code:: python - - pooled_embeddings = model(mb) - print(pooled_embeddings) - - -More resources --------------- - -For more information, please see our -`dlrm `__ -example, which includes multinode training on the criteo terabyte -dataset, using Meta’s `DLRM `__. - - -.. |DistributedModelParallel| replace:: ``DistributedModelParallel`` -.. _DistributedModelParallel: https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel -.. |EmbeddingBagCollection| replace:: ``EmbeddingBagCollection`` -.. _EmbeddingBagCollection: https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection -.. |KeyedJaggedTensor| replace:: ``KeyedJaggedTensor`` -.. _KeyedJaggedTensor: https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.JaggedTensor -.. |torch.nn.Embedding| replace:: ``torch.nn.Embedding`` -.. _torch.nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html -.. |torch.nn.EmbeddingBag| replace:: ``torch.nn.EmbeddingBag`` -.. 
_torch.nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html -.. |nn.Embedding| replace:: ``nn.Embedding`` -.. _nn.Embedding: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html -.. |nn.EmbeddingBag| replace:: ``nn.EmbeddingBag`` -.. _nn.EmbeddingBag: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html