diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index 266f689c..edecf389 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -271,7 +271,6 @@ def __init__(
             self.mp.start()
             self.cpu_offload_state_dict = None
             self.staging = False
-            self.staging_state_dict = None
             self.staging_id = None
             self.staging_stream = torch.cuda.Stream()
         else:
@@ -384,7 +383,7 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
         if self.cpu_offload_state_dict is None:
             logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}.")
             self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True
+                state_dict, pin_memory=True, share_memory=True
             )
 
         logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}.")
@@ -395,7 +394,6 @@ def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
                 non_blocking=True,
             )
             self.staging = True
-            self.staging_state_dict = state_dict
             self.staging_id = checkpoint_id
 
     def save(self, curr_step: int, force: bool = False) -> None:
@@ -435,12 +433,19 @@ def maybe_wait_for_staging(self) -> None:
             and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
             and self.staging
         ):
-            logger.debug(f"Waiting for staging, {time.monotonic()=:.2f}.")
-            self.staging_stream.synchronize()
-            logger.debug(
-                f"Sending the state dict to the background process, {time.monotonic()=:.2f}."
-            )
-            self.mp_queue_send.put((self.staging_state_dict, self.staging_id))
+            if not self.staging_stream.query():
+                self.staging_stream.synchronize()
+
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpoint staging,
+            # but we need a more thorough investigation before switching to
+            # this method.
+            # self.my_thread = threading.Thread(target=sync_func).start()
+            sync_func()
             self.staging = False
 
     def load(self, step: int = -1) -> bool:
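
Below is a minimal, self-contained sketch of the staging pattern the diff above moves to: allocate one pinned, shared-memory CPU copy of the state dict, stage the GPU tensors into it on a side CUDA stream with non-blocking copies, and hand the same shared buffer (rather than a fresh copy) to a background writer process once the stream has drained. This is an illustration only, not torchtitan's CheckpointManager: checkpoint_writer, the queue wiring, and the "step-100" id are made-up stand-ins, and _create_cpu_state_dict / _copy_state_dict are private helpers from torch.distributed._state_dict_utils whose signatures may change between releases.

import torch
import torch.multiprocessing as mp
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

def checkpoint_writer(queue):
    # Background process: drain (state_dict, checkpoint_id) pairs and persist
    # them. Because the tensors live in shared memory, the queue transfer
    # moves metadata, not the tensor data itself.
    while True:
        item = queue.get()
        if item is None:
            break
        state_dict, checkpoint_id = item
        torch.save(state_dict, f"{checkpoint_id}.pt")

if __name__ == "__main__":
    # Assumes a CUDA device is available.
    ctx = mp.get_context("spawn")
    queue = ctx.Queue()
    writer = ctx.Process(target=checkpoint_writer, args=(queue,))
    writer.start()

    gpu_state = {"w": torch.randn(1024, 1024, device="cuda")}

    # Allocate the staging buffer once: pin_memory=True makes the D2H copy
    # genuinely asynchronous; share_memory=True lets the writer process read
    # the buffer without another copy (this mirrors the change in the diff).
    cpu_state = _create_cpu_state_dict(gpu_state, pin_memory=True, share_memory=True)

    # Stage on a side stream so training can keep running on the default stream.
    staging_stream = torch.cuda.Stream()
    with torch.cuda.stream(staging_stream):
        cpu_state = _copy_state_dict(gpu_state, cpu_state, non_blocking=True)

    # ... training continues while the copy is in flight ...

    # Before handing the buffer off (or reusing it), make sure staging is done.
    # query() lets us skip the synchronize when the copy has already finished.
    if not staging_stream.query():
        staging_stream.synchronize()
    queue.put((cpu_state, "step-100"))

    queue.put(None)
    writer.join()

One design point the diff relies on: once the staging buffer is in shared memory, the background process can consume cpu_offload_state_dict directly, so the old approach of sending staging_state_dict, and the attribute holding that extra reference, can be dropped.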