Skip to content

Commit

Permalink
[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast prope…
Browse files Browse the repository at this point in the history
…rly with ring buffer. (vllm-project#5905)

Signed-off-by: Xiaowei Jiang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
2 people authored and prashantgupta24 committed Jun 28, 2024
1 parent bc30b64 commit 9bafffb
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GraphCaptureContext:


def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
Expand Down Expand Up @@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:

def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
Expand Down Expand Up @@ -558,9 +558,9 @@ def broadcast_tensor_dict(

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the local rank of the source rank.
"""
Expand Down Expand Up @@ -599,7 +599,7 @@ def send_tensor_dict(
def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
Expand All @@ -615,15 +615,15 @@ def recv_tensor_dict(
assert src < self.world_size, f"Invalid src rank ({src})"

recv_metadata_list = self.recv_object(src=src)
tensor_dict = {}
tensor_dict: Dict[str, Any] = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
_update_nested_dict(tensor_dict, key, tensor)
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
Expand All @@ -633,9 +633,9 @@ def recv_tensor_dict(
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=src, group=group)
tensor_dict[key] = tensor
_update_nested_dict(tensor_dict, key, tensor)
else:
tensor_dict[key] = value
_update_nested_dict(tensor_dict, key, value)
return tensor_dict

def barrier(self):
Expand Down

0 comments on commit 9bafffb

Please sign in to comment.