
Commit fcaa9f2

add repo_id info
lhoestq committed Sep 29, 2023
1 parent 0cc77d7 commit fcaa9f2
Showing 3 changed files with 16 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/datasets/arrow_dataset.py
@@ -221,6 +221,10 @@ def task_templates(self):
     def version(self):
         return self._info.version
 
+    @property
+    def repo_id(self) -> str:
+        return self._info.repo_id
+
 
 class TensorflowDatasetMixin:
     _TF_DATASET_REFS = set()
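The new property on DatasetInfoMixin surfaces the Hub repository a dataset was built from, read straight from its DatasetInfo. A minimal usage sketch (the value is only populated when the builder recorded a repo_id; it stays None for datasets created locally):

from datasets import Dataset, load_dataset

ds = load_dataset("squad", split="train")
print(ds.repo_id)  # e.g. "squad" if the builder recorded a repo_id, else None

local_ds = Dataset.from_dict({"text": ["hello"]})
print(local_ds.repo_id)  # None: no Hub repository is associated with this dataset
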
6 changes: 4 additions & 2 deletions src/datasets/builder.py
@@ -390,6 +390,8 @@ def __init__(
         # update info with user specified infos
         if features is not None:
             self.info.features = features
+        if repo_id is not None:
+            self.info.repo_id = repo_id
 
         # Prepare data dirs:
         # cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing)
@@ -417,7 +419,7 @@ def __init__(
                 if len(os.listdir(self._cache_dir)) > 0:
                     if os.path.exists(path_join(self._cache_dir, config.DATASET_INFO_FILENAME)):
                         logger.info("Overwrite dataset info from restored data version if exists.")
-                        self.info = DatasetInfo.from_directory(self._cache_dir)
+                        self.info.update(DatasetInfo.from_directory(self._cache_dir))
                     else:  # dir exists but no data, remove the empty dir as data aren't available anymore
                         logger.warning(
                             f"Old caching folder {self._cache_dir} for dataset {self.dataset_name} exists but no data were found. Removing it. "
@@ -882,7 +884,7 @@ def download_and_prepare(
             logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
             # We need to update the info in case some splits were added in the meantime
             # for example when calling load_dataset from multiple workers.
-            self.info = self._load_info()
+            self.info.update(self._load_info())
             self.download_post_processing_resources(dl_manager)
             return
 
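The first hunk records a repo_id passed to the builder onto its DatasetInfo; the other two switch from replacing self.info to self.info.update(...), so fields already set at init time, including the new repo_id, are not wiped out when a cached or restored info is loaded. A small sketch of that update semantics, with a hypothetical repository name (DatasetInfo.update ignores None fields of the other info by default):

from datasets import DatasetInfo

info = DatasetInfo(repo_id="user/my-dataset")                 # set when the builder is created
cached = DatasetInfo(description="info restored from cache")  # no repo_id stored on disk

info.update(cached)      # merge instead of overwrite
print(info.repo_id)      # still "user/my-dataset"
print(info.description)  # "info restored from cache"
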
8 changes: 8 additions & 0 deletions src/datasets/info.py
@@ -153,6 +153,7 @@ class DatasetInfo:
     dataset_name: Optional[str] = None  # for packaged builders, to be different from builder_name
     config_name: Optional[str] = None
     version: Optional[Union[str, Version]] = None
+    repo_id: Optional[str] = None
     # Set later by `download_and_prepare`
     splits: Optional[dict] = None
     download_checksums: Optional[dict] = None
@@ -282,6 +283,12 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]):
         supervised_keys = None
         task_templates = None
 
+        repo_ids = {dset_info.repo_id for dset_info in dataset_infos}
+        if len(repo_ids) == 1 and next(iter(repo_ids)) is not None:
+            repo_id = next(iter(repo_ids))
+        else:
+            repo_id = None
+
         # Find common task templates across all dataset infos
         all_task_templates = [info.task_templates for info in dataset_infos if info.task_templates is not None]
         if len(all_task_templates) > 1:
@@ -299,6 +306,7 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]):
             features=features,
             supervised_keys=supervised_keys,
             task_templates=task_templates,
+            repo_id=repo_id,
         )
 
     @classmethod
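With repo_id now a DatasetInfo field, from_merge keeps it only when every merged info agrees on a single non-None value and otherwise drops it to None. A short sketch with hypothetical repository names:

from datasets import DatasetInfo

a = DatasetInfo(repo_id="user/dataset")
b = DatasetInfo(repo_id="user/dataset")
c = DatasetInfo(repo_id="user/other-dataset")

print(DatasetInfo.from_merge([a, b]).repo_id)  # "user/dataset"
print(DatasetInfo.from_merge([a, c]).repo_id)  # None, the repo_ids disagree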
