Merge branch 'main' into shanea/delay-device-mesh-import
2015aroras authored Apr 26, 2024
2 parents 3c12813 + 4e8746d commit 67e0b64
Showing 3 changed files with 154 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- Added the original legacy unsharding implementation back as the default. The new
shared-memory implementation can be enabled by passing `--use-legacy-shared-mem-impl` to `unshard.py`.

## [v0.3.0](https://github.com/allenai/OLMo/releases/tag/v0.3.0) - 2024-04-25

### Added
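For context, here is a minimal usage sketch of the new constructor flag, mirroring what `scripts/unshard.py` does further down in this commit. The `olmo.config` import path and the config file location are assumptions for illustration, not part of this diff.

```python
# Hypothetical sketch: choosing between the two unsharding code paths.
from olmo.checkpoint import TorchLegacyShardedCheckpointer
from olmo.config import TrainConfig  # assumed module path

cfg = TrainConfig.load("checkpoint_dir/config.yaml", validate_paths=False)

# Default: the original legacy (thread-based) unsharding implementation.
checkpointer = TorchLegacyShardedCheckpointer(cfg)

# Opt in to the shared-memory implementation, equivalent to passing
# --use-legacy-shared-mem-impl to scripts/unshard.py.
checkpointer = TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=True)
```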
141 changes: 138 additions & 3 deletions olmo/checkpoint.py
@@ -874,6 +874,10 @@ class TorchLegacyShardedCheckpointer(Checkpointer):
The world size must be kept consistent when using this checkpointer.
"""

def __init__(self, cfg: TrainConfig, thread_count: Optional[int] = None, use_shared_mem_impl: bool = False):
super().__init__(cfg, thread_count)
self.use_shared_mem_impl = use_shared_mem_impl

def save_checkpoint(
self,
dir: PathOrStr,
@@ -1147,20 +1151,25 @@ def _rebuild_from_type_v2_monkey(func, new_type, args, state):
finally:
torch._tensor._rebuild_from_type_v2 = original_rebuild_from_type_v2

-def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
+def _unshard_using_shared_memory(
+    self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None
+):
"""
-The current unsharding implementation consists of:
+This unsharding implementation consists of:
1. Loading each shard on a separate process and copying their sharded tensors to shared memory.
2. Loading 1 shard on the main process as a base unsharded object.
3. Using the sharded tensors in shared memory to populate the base unsharded object.
-This implementation replaced a prior implementation that instead loaded
+This implementation is an alternative to a prior implementation that instead loaded
all shards using threads, because that implementation turned out to
be extremely slow (e.g. 6+ hours) sometimes when the world size was 1024.
The current implementation is slower than the old one in many scenarios,
but is significantly faster in the above mentioned case (e.g. 30 minutes)
if there are enough CPUs.
We keep the other implementation since this one can be more unreliable,
likely due to its dependence on a large amount of shared memory.
"""

input_dir = Path(input_dir)
@@ -1211,6 +1220,132 @@ def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Option
log.info("Unsharding from %d shards ...", world_size)
return self._unshard_using_sharded_mem(state, world_size, device, input_dir)

def _unshard(self, input_dir: PathOrStr, device: torch.device, skip_keys: Optional[Set[str]] = None):
if self.use_shared_mem_impl:
return self._unshard_using_shared_memory(input_dir, device, skip_keys)

input_dir = Path(input_dir)
skip_keys = skip_keys or set()

with self._patch_sharded_tensor_load():
# We load in threads because it's faster.
executor = ThreadPoolExecutor()
shards_dict = {}
for shard_name in input_dir.glob("rank*.pt"):
log.info("Loading %s ...", shard_name)
shard_number = int(shard_name.name[4:-3]) # shard names look like "rankXX.pt"
shards_dict[shard_number] = executor.submit(torch.load, shard_name, map_location="cpu")
shards = [None] * len(shards_dict)
for rank, shard_future in shards_dict.items():
shard = shard_future.result()
for key in skip_keys:
if key in shard:
del shard[key]
shards[rank] = shard
assert all(shard is not None for shard in shards)
executor.shutdown()
del shards_dict

log.info("Unsharding from %d shards ...", len(shards))

unsharded_state_dict = self._unshard_object(shards, device=device)
# At this point in time we need 2x memory :-(
del shards

return unsharded_state_dict

def _unshard_object(self, os: List[Any], device: torch.device) -> Any:
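# Merge the per-rank objects recursively: containers are walked key-by-key / element-by-element,
# ShardedTensors are reassembled via _gather, and any other leaf must be identical across ranks.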
rank0_item = os[0]
assert all(type(o) is type(rank0_item) for o in os)
if isinstance(rank0_item, str):
assert all(o == rank0_item for o in os)
return rank0_item
elif isinstance(rank0_item, (list, tuple, set)):
assert all(len(o) == len(rank0_item) for o in os)
return rank0_item.__class__(self._unshard_object(o, device=device) for o in zip(*os))
elif isinstance(rank0_item, dict):
assert all(o.keys() == rank0_item.keys() for o in os)
return {key: self._unshard_object([o[key] for o in os], device=device) for key in rank0_item.keys()}
elif isinstance(rank0_item, ShardedTensor):
return self._gather(os, device=device)
else:
assert all(self._objects_are_equal(o, rank0_item) for o in os)
return rank0_item

def _gather(self, shards: List[ShardedTensor], device: torch.device) -> torch.Tensor:
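# Reassemble one full tensor from the ShardedTensor copies collected from every rank's shard file.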
world_size = len(shards)
shard0_md = shards[0].metadata()
# Make sure all shards agree on the metadata
assert all(shard.metadata() == shard0_md for shard in shards)
# Make sure the nth shard expects to be the nth shard.
assert all(
shard_md.placement.rank() == rank # type: ignore
for rank, shard_md in enumerate(shard0_md.shards_metadata)
)

def shard_size(shard_md):
return reduce((lambda x, y: x * y), shard_md.shard_sizes) # type: ignore[attr-defined]

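# Record each shard's (owning rank, offset into that rank's flattened buffer) and accumulate
# the total flattened size contributed by each rank.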
rank_sizes = [0 for _ in range(world_size)]
max_rank_size = 0
shard_placement: Dict[ShardMetadata, Tuple[int, int]] = {}
for shard_md in shard0_md.shards_metadata:
shard_rank = cast(_remote_device, shard_md.placement).rank()
assert shard_rank is not None

shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank])
rank_sizes[shard_rank] += shard_size(shard_md)
max_rank_size = max(max_rank_size, rank_sizes[shard_rank])

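# One flat CPU buffer per rank, each sized to the largest rank's total shard payload.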
gather_list: List[torch.Tensor] = [torch.empty((max_rank_size,)) for _ in range(world_size)]

datas = []
with torch.no_grad():
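# Flatten each rank's local shards into a single buffer at the offsets recorded above.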
for shard in shards:
data = torch.empty(max_rank_size)

for local_shard in shard.local_shards():
src = local_shard.tensor.flatten()
shard_offset = shard_placement[local_shard.metadata][1]
data[shard_offset : shard_offset + src.numel()].copy_(src)

datas.append(data)

# torch.gather in a nutshell
for rank, data in enumerate(datas):
gather_list[rank].copy_(data)

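# Stitch the full tensor back together: view each rank's flat slice with its shard shape and
# copy it into the region of `out` selected by successive narrow() calls.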
full_size = shard0_md.size
out = torch.empty(*full_size, dtype=shard0_md.tensor_properties.dtype, device=device)
dims = len(full_size)
for shard_md in shard0_md.shards_metadata:
rank, rank_offset = shard_placement[shard_md]
tensor = gather_list[rank]
tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)]
tensor = tensor.view(shard_md.shard_sizes)

out_narrow_view = out
for dim in range(dims):
out_narrow_view = out_narrow_view.narrow(
dim,
shard_md.shard_offsets[dim],
shard_md.shard_sizes[dim],
)

out_narrow_view.copy_(tensor)

return out
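The offset arithmetic in `_gather` is easiest to see on a toy example. The following standalone sketch (illustration only, not part of this commit) copies two hypothetical row-shards into a full tensor using the same narrow-then-copy pattern:

```python
import torch

# Full parameter of shape (4, 6), reassembled from two row-shards.
out = torch.empty(4, 6)
shards = [
    (torch.zeros(2, 6), (0, 0)),  # (shard data, shard_offsets): rank 0 owns rows 0-1
    (torch.ones(2, 6), (2, 0)),   # rank 1 owns rows 2-3
]
for data, offsets in shards:
    view = out
    for dim, (offset, size) in enumerate(zip(offsets, data.shape)):
        # narrow() returns a view, so copying into it writes directly into `out`.
        view = view.narrow(dim, offset, size)
    view.copy_(data)

assert torch.equal(out[:2], torch.zeros(2, 6))
assert torch.equal(out[2:], torch.ones(2, 6))
```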

def _objects_are_equal(self, a: Any, b: Any) -> bool:
if type(a) is not type(b):
return False
if isinstance(a, np.ndarray):
return np.array_equal(a, b)
elif isinstance(a, torch.Tensor):
return torch.equal(a, b)
else:
return a == b


@dataclass
class _LocalShardedCheckpointerMetadata(BaseConfig):
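The shared-memory implementation itself (`_unshard_using_sharded_mem`) is not shown in this diff. As a rough, hypothetical illustration of the pattern its docstring describes — worker processes load shards and copy their tensors into preallocated shared-memory buffers that the parent process then merges — here is a heavily simplified sketch; the function names, the flat per-key layout, and the final concatenation are assumptions, not OLMo code:

```python
import torch
import torch.multiprocessing as mp


def _load_shard_into_shared(rank: int, shard_path: str, shared: dict) -> None:
    # Child process: load one shard on CPU and copy each tensor into the
    # shared-memory buffer reserved for this (key, rank) pair.
    shard = torch.load(shard_path, map_location="cpu")
    for key, tensor in shard.items():
        shared[key][rank].copy_(tensor.flatten())


def unshard_via_shared_memory(shard_paths: list, numels: dict) -> dict:
    # Parent process: preallocate flat buffers and move their storage into shared
    # memory so child processes can write into them directly.
    shared = {
        key: [torch.empty(numel).share_memory_() for _ in shard_paths]
        for key, numel in numels.items()
    }
    workers = [
        mp.Process(target=_load_shard_into_shared, args=(rank, path, shared))
        for rank, path in enumerate(shard_paths)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    # The real checkpointer merges pieces according to their ShardedTensor metadata;
    # plain concatenation stands in for that placement-aware merge here.
    return {key: torch.cat(buffers) for key, buffers in shared.items()}
```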
12 changes: 11 additions & 1 deletion scripts/unshard.py
@@ -22,6 +22,7 @@ def main(
sharded_checkpoint_type: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy,
model_only: bool = False,
safe_tensors: bool = False,
use_shared_mem_impl: bool = False,
) -> None:
if isinstance(input_dir, str):
input_dir = Path(input_dir)
@@ -32,7 +33,7 @@
config = TrainConfig.load(input_dir / "config.yaml", validate_paths=False)
checkpointer: Checkpointer
if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-checkpointer = TorchLegacyShardedCheckpointer(config)
+checkpointer = TorchLegacyShardedCheckpointer(config, use_shared_mem_impl=use_shared_mem_impl)
elif sharded_checkpoint_type == ShardedCheckpointerType.local:
checkpointer = LocalShardedCheckpointer(config)
else:
@@ -99,6 +100,14 @@ def main(
"--safe-tensors",
action="store_true",
)
parser.add_argument(
"--use-legacy-shared-mem-impl",
action="store_true",
help="""This is ignored if type is not torch_legacy. For legacy sharded checkpoints,
use the shared memory implementation. This has high CPU, RAM and shared
memory requirements but can be significantly faster when the world size
is large (e.g. 1024).""",
)
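# Example invocation (positional arguments are assumed from the rest of the script and
# not shown in this diff):
#   python scripts/unshard.py <sharded_checkpoint_dir> <output_dir> --type torch_legacy --use-legacy-shared-mem-impl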
args = parser.parse_args()

logging.basicConfig(level=logging.INFO)
Expand All @@ -108,4 +117,5 @@ def main(
sharded_checkpoint_type=args.type,
model_only=args.model_only,
safe_tensors=args.safe_tensors,
use_shared_mem_impl=args.use_legacy_shared_mem_impl,
)
