[TP] Add API logging for TP high level API
[ghstack-poisoned]
fduwjj committed May 24, 2023
1 parent e3d97b6 commit e0921a2
Showing 2 changed files with 5 additions and 0 deletions.
torch/distributed/tensor/parallel/api.py (2 additions, 0 deletions)
@@ -80,6 +80,8 @@ def parallelize_module(  # type: ignore[return]
     granularity, you need to pass in a dict of module FQN and parallel style instead.
     """
 
+    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")
+
     if device_mesh.ndim > 1:
         device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)
 
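For context, torch._C._log_api_usage_once records a usage event only the first time a given key is seen in a process, so the added call lets telemetry count how often this TP entry point is exercised without spamming logs. Below is a minimal sketch of a call that would trigger the new log line; the model, mesh, and ColwiseParallel plan are illustrative choices, not part of this commit, and a real run needs an initialized distributed process group:

import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Illustrative 1-D mesh over the local GPUs.
mesh = DeviceMesh("cuda", list(range(torch.cuda.device_count())))
model = nn.Linear(16, 16)

# The first call per process emits the
# "torch.distributed.tensor.parallel.parallelize_module" API-usage log,
# then shards the module according to the plan as before.
model = parallelize_module(model, mesh, ColwiseParallel())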
torch/distributed/tensor/parallel/fsdp.py (3 additions, 0 deletions)
@@ -43,6 +43,9 @@ def enable_2d_with_fsdp() -> bool:
     Return:
         A `bool` indicated whether extension registration succeeds or not.
     """
+
+    torch._C._log_api_usage_once("torch.distributed.tensor.parallel.enable_2d_with_fsdp")
+
     try:
         from torch.distributed.fsdp._fsdp_extensions import (
             _set_fsdp_extensions,
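The fsdp.py change mirrors this for the 2D (TP + FSDP) entry point: the first call to enable_2d_with_fsdp in a process is logged before it attempts to register the FSDP extensions. A hedged usage sketch follows; the error handling is illustrative, not from this commit:

from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp

# Emits the "torch.distributed.tensor.parallel.enable_2d_with_fsdp" usage log
# on the first call, then returns whether the FSDP extension hooks were
# registered successfully.
if not enable_2d_with_fsdp():
    raise RuntimeError("FSDP extensions unavailable; cannot enable 2D parallelism")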
