Skip to content

Commit

Permalink
Memfix (#106)
Browse files Browse the repository at this point in the history
* Bump version

* Add garbage collection to predict_step

* Clear predictions in predict_loop

* Codesyle fix

* Reorganize pred clearing

* Move gc collect to writer

---------

Co-authored-by: GitHub Action <[email protected]>
  • Loading branch information
surajpaib and actions-user committed Feb 3, 2024
1 parent ca9794e commit 2a3a451
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lighter/callbacks/writer/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, Union

import gc
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -105,7 +106,6 @@ def on_predict_batch_end(
If the IDs are not provided, it generates global unique IDs based on the prediction count.
Finally, it writes the predictions using the specified writer.
"""

# If the IDs are not provided, generate global unique IDs based on the prediction count. DDP supported.
if outputs["id"] is None:
batch_size = len(outputs["pred"])
Expand All @@ -115,3 +115,7 @@ def on_predict_batch_end(

for id, pred in zip(outputs["id"], outputs["pred"]):
self.write(tensor=pred, id=id)

# Clear the predictions to save CPU memory. https://github.com/Lightning-AI/pytorch-lightning/issues/15656
trainer.predict_loop._predictions = [[] for _ in range(trainer.predict_loop.num_dataloaders)]
gc.collect()

0 comments on commit 2a3a451

Please sign in to comment.