vae dataloader eats itself on batch 1, no idea why
neggles committed Jan 24, 2024
1 parent 668e455 commit eae37ed
Showing 6 changed files with 139 additions and 77 deletions.
12 changes: 11 additions & 1 deletion .vscode/launch.json
@@ -9,10 +9,20 @@
"type": "python",
"request": "launch",
"cwd": "${workspaceFolder}",
"module": "neurosis.trainer.sdxl",
"module": "neurosis.trainer.cli",
"args": ["fit", "--config", "./configs/sdxl/wdxl-test.yml"],
"justMyCode": true,
"subProcess": true
},
{
"name": "neurosis VAE test",
"type": "python",
"request": "launch",
"cwd": "${workspaceFolder}",
"module": "neurosis.trainer.cli",
"args": ["fit", "--config", "configs/wdxl-vae-test.yaml"],
"justMyCode": true,
"subProcess": true
}
]
}
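
For reference, this debug profile just drives the trainer's CLI module; assuming the package is installed in the active environment, the equivalent shell invocation would be `python -m neurosis.trainer.cli fit --config configs/wdxl-vae-test.yaml`.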
17 changes: 7 additions & 10 deletions src/neurosis/dataset/mongo/dataset.py
@@ -113,20 +113,18 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
         }
 
     def refresh_clients(self):
-        """Helper func to replace the current clients with new ones"""
-        self.client = self.settings.new_client()
-
-        # detect forks and reset fsspec
+        """Helper func to replace the current clients with new ones post-fork etc."""
         pid = getpid()
-        fs_pid = self.fs._pid if self.fs is not None else None
+        if self.client is None or self.pid != pid:
+            self.client = self.settings.new_client()
+            self.pid = pid
 
-        if (self.pid != pid) or (fs_pid != pid) or self.fs is None:
-            logger.info(f"loader PID {pid} detected fork, resetting fsspec clients")
+        if self.fs is None or self.fs._pid != pid:
+            logger.debug(f"Loader detected fork, new PID {pid} - resetting fsspec clients")
             import fsspec
 
             fsspec.asyn.reset_lock()
             self.fs = S3FileSystem(**self.s3fs_kwargs, skip_instance_cache=True)
-            self.pid = pid
 
     @property
     def collection(self) -> MongoCollection:
@@ -322,7 +320,6 @@ def setup(self, stage: str):
         self.sampler = AspectBucketSampler(self.dataset)
 
         logger.info(f"Refreshing dataset clients for {stage}")
-        self.dataset.fs = None
         self.dataset.refresh_clients()
 
     def train_dataloader(self):
@@ -339,6 +336,6 @@ def train_dataloader(self):


 def mongo_worker_init(worker_id: int = -1):
-    logger.info(f"Worker {worker_id} initializing")
+    logger.debug(f"Worker {worker_id} initializing")
     clear_fsspec()
     set_s3fs_opts()
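
The refresh_clients() rewrite above is the core of this commit: a DataLoader worker inherits the parent's MongoDB client and s3fs filesystem, whose sockets and event loops are not fork-safe, so each process compares a stored PID against its own and rebuilds the handles lazily. A minimal sketch of that idiom; ForkAwareClient and its factory are hypothetical stand-ins, not types from this repo:

```python
# Minimal sketch of the fork-detection idiom, with placeholder types.
import os


class ForkAwareClient:
    """Lazily (re)builds a resource so it always belongs to the current process.

    With the default fork start method on Linux, DataLoader workers inherit
    the parent's objects; clients that own sockets or event loops must be
    rebuilt in the child, and comparing a stored PID against os.getpid()
    detects that we crossed a fork.
    """

    def __init__(self, factory):
        self._factory = factory  # zero-arg callable that builds a fresh client
        self._client = None
        self._pid = None

    def get(self):
        pid = os.getpid()
        if self._client is None or self._pid != pid:
            # First use in this process, or a fork happened: rebuild.
            self._client = self._factory()
            self._pid = pid
        return self._client


# usage: each process lazily gets its own connection object
conn = ForkAwareClient(lambda: object())  # stand-in for MongoClient / S3FileSystem
handle = conn.get()
```

refresh_clients() inlines this check twice (Mongo client, then the fsspec S3 filesystem) and additionally calls fsspec.asyn.reset_lock(), since fsspec keeps module-level asyncio state that is likewise invalid after a fork.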
39 changes: 24 additions & 15 deletions src/neurosis/dataset/mongo/nobucket.py
@@ -14,8 +14,8 @@
 from pymongoarrow.api import find_pandas_all
 from pymongoarrow.schema import Schema
 from s3fs import S3FileSystem
-from torch import Tensor
-from torch.utils.data import DataLoader
+from torch import Generator, Tensor
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 
 from neurosis.dataset.base import NoBucketDataset
 from neurosis.dataset.mongo.settings import MongoSettings, get_mongo_settings
@@ -104,20 +104,18 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
         }
 
     def refresh_clients(self):
-        """Helper func to replace the current clients with new ones"""
-        self.client = self.settings.new_client()
-
-        # detect forks and reset fsspec
+        """Helper func to replace the current clients with new ones post-fork etc."""
         pid = getpid()
-        fs_pid = self.fs._pid if self.fs is not None else None
+        if self.client is None or self.pid != pid:
+            self.client = self.settings.new_client()
+            self.pid = pid
 
-        if (self.pid != pid) or (fs_pid != pid) or self.fs is None:
-            logger.info(f"loader PID {pid} detected fork, resetting fsspec clients")
+        if self.fs is None or self.fs._pid != pid:
+            logger.debug(f"Loader detected fork, new PID {pid} - resetting fsspec clients")
             import fsspec
 
             fsspec.asyn.reset_lock()
             self.fs = S3FileSystem(**self.s3fs_kwargs, skip_instance_cache=True)
-            self.pid = pid
 
     @property
     def collection(self) -> MongoCollection:
@@ -211,6 +209,7 @@ def __init__(
         s3_bucket: Optional[str] = None,
         s3fs_kwargs: dict = {},
         pma_schema: Optional[Schema] = None,
+        seed: Optional[int] = None,
         num_workers: int = 0,
         prefetch_factor: int = 2,
         pin_memory: bool = True,
@@ -236,24 +235,34 @@ def __init__(
             s3fs_kwargs=s3fs_kwargs,
             pma_schema=pma_schema,
         )
+        self.seed = seed
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.prefetch_factor = prefetch_factor
         self.drop_last = drop_last
 
+    @property
+    def batch_size(self):
+        return self.dataset.batch_size
+
     def prepare_data(self) -> None:
-        self.dataset.fs = None
-        self.dataset.refresh_clients()
+        pass
 
     def setup(self, stage: str):
         logger.info(f"Refreshing dataset clients for {stage}")
-        self.dataset.fs = None
         self.dataset.refresh_clients()
 
     def train_dataloader(self):
+        if self.seed is not None:
+            logger.info(f"Setting seed {self.seed} for train dataloader")
+            generator = Generator().manual_seed(self.seed)
+            sampler = RandomSampler(self.dataset, generator=generator)
+        else:
+            sampler = SequentialSampler(self.dataset)
+
         return DataLoader(
             self.dataset,
-            batch_size=self.dataset.batch_size,
+            batch_sampler=BatchSampler(sampler, self.batch_size, self.drop_last),
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             prefetch_factor=self.prefetch_factor,


 def mongo_worker_init(worker_id: int = -1):
-    logger.info(f"Worker {worker_id} initializing")
+    logger.debug(f"Worker {worker_id} initializing")
     clear_fsspec()
     set_s3fs_opts()
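
The train_dataloader() change swaps DataLoader's plain batch_size argument for an explicit BatchSampler so the shuffle order can be seeded and reproduced. A standalone sketch of the same construction on a toy dataset (the TensorDataset and seed value here are illustrative, not from the repo):

```python
# Seeding a RandomSampler via a torch.Generator fixes the shuffle order;
# wrapping it in a BatchSampler takes over batching from the DataLoader.
import torch
from torch import Generator
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(10).float())  # toy stand-in dataset

generator = Generator().manual_seed(42)            # fixed seed -> fixed order
sampler = RandomSampler(dataset, generator=generator)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)

# batch_sampler replaces batch_size/shuffle/sampler/drop_last on the DataLoader
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch)  # identical batch order on every run with the same seed
```

Note that DataLoader treats batch_sampler as mutually exclusive with batch_size, shuffle, sampler, and drop_last, which is why those responsibilities move into the sampler objects and why the diff drops the batch_size keyword.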
37 changes: 27 additions & 10 deletions src/neurosis/dataset/mongo/nocaption.py
@@ -13,8 +13,8 @@
 from pymongoarrow.api import find_pandas_all
 from pymongoarrow.schema import Schema
 from s3fs import S3FileSystem
-from torch import Tensor
-from torch.utils.data import DataLoader
+from torch import Generator, Tensor
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 
 from neurosis.dataset.base import NoBucketDataset
 from neurosis.dataset.mongo.settings import MongoSettings, get_mongo_settings
@@ -40,6 +40,7 @@ def __init__(
         pma_schema: Optional[Schema] = None,
         retries: int = 3,
         retry_delay: int = 5,
+        seed: Optional[int] = None,
         **kwargs,
     ):
         super().__init__(resolution, batch_size, **kwargs)
@@ -62,6 +63,7 @@

         self.retries = retries
         self.retry_delay = retry_delay
+        self.seed = seed
 
         # load meta
         logger.debug(
@@ -78,6 +80,7 @@ def __len__(self):

     def __getitem__(self, index: int) -> dict[str, Tensor]:
         if self._first_getitem:
+            logger.debug(f"First __getitem__ (idx {index}) - refreshing clients")
             self.refresh_clients()
             self._first_getitem = False

@@ -91,18 +94,18 @@ def __getitem__(self, index: int) -> dict[str, Tensor]:
         }
 
     def refresh_clients(self):
-        """Helper func to replace the current clients with new ones"""
-        self.client = self.settings.new_client()
-
-        # detect forks and reset fsspec
+        """Helper func to replace the current clients with new ones post-fork etc."""
         pid = getpid()
+        if self.client is None or self.pid != pid:
+            self.client = self.settings.new_client()
+            self.pid = pid
 
         if self.fs is None or self.fs._pid != pid:
-            logger.info(f"loader PID {pid} detected fork, resetting fsspec clients")
+            logger.debug(f"Loader detected fork, new PID {pid} - resetting fsspec clients")
             import fsspec
 
             fsspec.asyn.reset_lock()
             self.fs = S3FileSystem(**self.s3fs_kwargs, skip_instance_cache=True)
-            self.pid = pid
 
     @property
     def collection(self) -> MongoCollection:
@@ -166,6 +169,7 @@ def __init__(
         s3_bucket: Optional[str] = None,
         s3fs_kwargs: dict = {},
         pma_schema: Optional[Schema] = None,
+        seed: Optional[int] = None,
         num_workers: int = 0,
         prefetch_factor: int = 2,
         pin_memory: bool = True,
@@ -185,11 +189,16 @@ def __init__(
             s3fs_kwargs=s3fs_kwargs,
             pma_schema=pma_schema,
         )
+        self.seed = seed
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.prefetch_factor = prefetch_factor
         self.drop_last = drop_last
 
+    @property
+    def batch_size(self):
+        return self.dataset.batch_size
+
     def prepare_data(self) -> None:
         pass
 
@@ -198,18 +207,26 @@ def setup(self, stage: str):
         self.dataset.refresh_clients()
 
     def train_dataloader(self):
+        if self.seed is not None:
+            logger.info(f"Setting seed {self.seed} for train dataloader")
+            generator = Generator().manual_seed(self.seed)
+            sampler = RandomSampler(self.dataset, generator=generator)
+        else:
+            sampler = SequentialSampler(self.dataset)
+
         return DataLoader(
             self.dataset,
-            batch_size=self.dataset.batch_size,
+            batch_sampler=BatchSampler(sampler, self.batch_size, self.drop_last),
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             prefetch_factor=self.prefetch_factor,
             persistent_workers=True,
             worker_init_fn=mongo_worker_init,
+            timeout=60.0,
         )
 
 
 def mongo_worker_init(worker_id: int = -1):
-    logger.info(f"Worker {worker_id} initializing")
+    logger.debug(f"Worker {worker_id} initializing")
     clear_fsspec()
     set_s3fs_opts()
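
nocaption.py also keeps the refresh lazy: clients are only (re)built on the first __getitem__ inside each worker, so nothing fork-unsafe is created in the parent and then inherited. A minimal sketch of that shape, with a hypothetical LazyClientDataset standing in for the real Mongo-backed dataset:

```python
# Sketch of the lazy first-__getitem__ refresh; LazyClientDataset and
# make_client are illustrative stand-ins, not names from this repo.
from torch.utils.data import Dataset


class LazyClientDataset(Dataset):
    def __init__(self, make_client):
        self._make_client = make_client  # a factory pickles cleanly; a live client may not
        self._client = None
        self._first_getitem = True

    def __len__(self):
        return 8

    def __getitem__(self, index):
        if self._first_getitem:
            # Runs inside the worker process, after any fork, so the client
            # is created in the process that will actually use it.
            self._client = self._make_client()
            self._first_getitem = False
        # real code would fetch the sample via self._client here
        return index
```

The added timeout=60.0 looks like a guard for the batch-1 hang named in the commit title: with a positive timeout, a worker that never delivers its first batch raises after a minute instead of stalling the run indefinitely.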
(Diffs for the remaining two changed files did not load on this page.)
