diff --git a/recipes_source/distributed_comm_debug_mode.rst b/recipes_source/distributed_comm_debug_mode.rst
new file mode 100644
index 0000000000..db79cdc899
--- /dev/null
+++ b/recipes_source/distributed_comm_debug_mode.rst
@@ -0,0 +1,210 @@
+Getting Started with ``CommDebugMode``
+=====================================================
+
+**Author**: `Anshul Sinha `__
+
+
+In this tutorial, we will explore how to use ``CommDebugMode`` with PyTorch's
+DistributedTensor (DTensor) to debug distributed training by tracking the collective
+operations that happen under the hood.
+
+Prerequisites
+---------------------
+
+* Python 3.8 - 3.11
+* PyTorch 2.2 or later
+
+
+What is ``CommDebugMode`` and why is it useful
+----------------------------------------------------
+As the size of models continues to increase, users are seeking to leverage various
+combinations of parallel strategies to scale up distributed training. However, the lack of
+interoperability between existing solutions poses a significant challenge, primarily due to
+the absence of a unified abstraction that can bridge these different parallelism strategies.
+To address this issue, PyTorch has proposed `DistributedTensor (DTensor)
+`_,
+which abstracts away the complexities of tensor communication in distributed training,
+providing a seamless user experience. However, whether you are working with existing
+parallelism solutions or developing new ones on top of a unified abstraction like DTensor,
+the lack of transparency about which collective communications happen under the hood, and
+when, can make it challenging for advanced users to identify and resolve issues. To address
+this challenge, ``CommDebugMode``, a Python context manager, serves as one of the primary
+debugging tools for DTensor, enabling users to view when and why collective operations
+happen.
+
+
+Using ``CommDebugMode``
+------------------------
+
+Here is how you can use ``CommDebugMode``:
+
+.. code-block:: python
+
+    # CommDebugMode lives under torch.distributed._tensor.debug in the PyTorch
+    # versions this recipe targets; check the path against your installed version
+    from torch.distributed._tensor.debug import CommDebugMode
+
+    # The model used in this example is an MLPModule applying Tensor Parallel
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = model(inp)
+
+    # print the operation-level collective tracing information
+    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
+
+    # log the operation-level collective tracing information to a file
+    comm_mode.log_comm_debug_tracing_table_to_file(
+        noise_level=1, file_name="transformer_operation_log.txt"
+    )
+
+    # dump the operation-level collective tracing information to a JSON file,
+    # used in the visual browser below
+    comm_mode.generate_json_dump(noise_level=2)
+
+This is what the output looks like for an MLPModule at noise level 0:
+
+.. code-block:: python
+
+    Expected Output:
+        Global
+          FORWARD PASS
+            *c10d_functional.all_reduce: 1
+            MLPModule
+              FORWARD PASS
+                *c10d_functional.all_reduce: 1
+                MLPModule.net1
+                MLPModule.relu
+                MLPModule.net2
+                  FORWARD PASS
+                    *c10d_functional.all_reduce: 1
+
+To use ``CommDebugMode``, wrap the code that runs the model in the ``CommDebugMode``
+context manager and call the API you want to use to display the data. You can also pass a
+``noise_level`` argument to control the verbosity of the displayed information. Here is
+what each noise level displays (a runnable end-to-end sketch of this example follows the
+list):
+
+| 0. Prints module-level collective counts
+| 1. Prints DTensor operations (not including trivial operations), module sharding information
+| 2. Prints tensor operations (not including trivial operations)
+| 3. Prints all operations
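+If you want to reproduce the ``MLPModule`` example end to end, below is a minimal sketch
+of one possible setup. The layer sizes, the ``torchrun`` launch, the device choice, and
+the import paths are illustrative assumptions rather than part of this recipe; adapt them
+to your environment.
+
+.. code-block:: python
+
+    # Minimal sketch: a Tensor Parallel MLP traced with CommDebugMode.
+    # Assumes a launch such as `torchrun --nproc-per-node=2 example.py`
+    # with one GPU per rank.
+    import os
+
+    import torch
+    import torch.nn as nn
+    from torch.distributed._tensor.debug import CommDebugMode
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import (
+        ColwiseParallel,
+        RowwiseParallel,
+        parallelize_module,
+    )
+
+
+    class MLPModule(nn.Module):
+        def __init__(self, device):
+            super().__init__()
+            self.net1 = nn.Linear(10, 16, device=device)
+            self.relu = nn.ReLU()
+            self.net2 = nn.Linear(16, 10, device=device)
+
+        def forward(self, x):
+            return self.net2(self.relu(self.net1(x)))
+
+
+    # Bind each rank to its GPU, then build a one-dimensional device mesh
+    # spanning all ranks launched by torchrun
+    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
+
+    # Shard net1 column-wise and net2 row-wise; the all-reduce reported by
+    # CommDebugMode comes from combining the partial outputs of net2
+    model = parallelize_module(
+        MLPModule("cuda"), mesh, {"net1": ColwiseParallel(), "net2": RowwiseParallel()}
+    )
+
+    inp = torch.rand(20, 10, device="cuda")
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = model(inp)
+
+    print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
+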
+
+In the example above, you can see that the collective operation ``all_reduce`` occurs
+once in the forward pass of the ``MLPModule``. Furthermore, you can use ``CommDebugMode``
+to pinpoint that the all-reduce operation happens in the second linear layer of the
+``MLPModule`` (``MLPModule.net2``).
+
+
+Below is the interactive module tree visualization that you can use to upload your own JSON dump:
+
+.. raw:: html
+
+    <!-- Interactive "CommDebugMode Module Tree" browser: an embedded HTML page with a
+         "Drag file here" drop zone for uploading the JSON dump generated above.
+         The full HTML/CSS/JS payload is omitted here. -->
+
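+In addition to the rendered tables and the visual browser, the collected counts can be
+queried programmatically. A minimal sketch, assuming the ``get_total_counts`` and
+``get_comm_counts`` accessors exposed by recent PyTorch builds (verify the names against
+your installed version):
+
+.. code-block:: python
+
+    import torch
+
+    # Reuses `model` and `inp` from the example above
+    comm_mode = CommDebugMode()
+    with comm_mode:
+        output = model(inp)
+
+    # Total number of collective calls observed inside the context manager
+    print(comm_mode.get_total_counts())
+
+    # Per-collective counts keyed by the functional collective op; for the
+    # MLPModule example this maps all_reduce to 1
+    counts = comm_mode.get_comm_counts()
+    print(counts[torch.ops.c10d_functional.all_reduce])
+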
+
+Conclusion
+------------------------------------------
+
+In this recipe, we have learned how to use ``CommDebugMode`` to debug Distributed
+Tensors and parallelism solutions that use communication collectives with PyTorch.
+You can also use your own JSON outputs in the embedded visual browser.
+
+For more detailed information about ``CommDebugMode``, see
+`comm_mode_features_example.py
+`_.
diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst
index 8959ea98a3..d94d7d5c22 100644
--- a/recipes_source/recipes_index.rst
+++ b/recipes_source/recipes_index.rst
@@ -395,6 +395,13 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    :link: ../recipes/distributed_async_checkpoint_recipe.html
    :tags: Distributed-Training
 
+.. customcarditem::
+   :header: Getting Started with CommDebugMode
+   :card_description: Learn how to use CommDebugMode for DTensors
+   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: ../recipes/distributed_comm_debug_mode.html
+   :tags: Distributed-Training
+
 .. TorchServe
 
 .. customcarditem::
@@ -449,3 +456,4 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu
    /recipes/cuda_rpc
    /recipes/distributed_optim_torchscript
    /recipes/mobile_interpreter
+   /recipes/distributed_comm_debug_mode