From 8174f81312a7350fd0a5cd308b1c5c8dba66ce38 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 24 Oct 2024 18:09:38 +0200
Subject: [PATCH] basic xml (mostly copy pasted from text)

---
 docs/source/nlp_load.mdx                      | 12 +++-
 .../package_reference/loading_methods.mdx     |  6 ++
 src/datasets/packaged_modules/__init__.py     |  3 +
 src/datasets/packaged_modules/xml/__init__.py |  0
 src/datasets/packaged_modules/xml/xml.py      | 68 +++++++++++++++++++
 5 files changed, 88 insertions(+), 1 deletion(-)
 create mode 100644 src/datasets/packaged_modules/xml/__init__.py
 create mode 100644 src/datasets/packaged_modules/xml/xml.py

diff --git a/docs/source/nlp_load.mdx b/docs/source/nlp_load.mdx
index 5cfe5d31e99..dae074ae3fc 100644
--- a/docs/source/nlp_load.mdx
+++ b/docs/source/nlp_load.mdx
@@ -33,4 +33,14 @@ To load remote text files via HTTP, pass the URLs instead:
 
 ```py
 >>> dataset = load_dataset("text", data_files="https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt")
-```
\ No newline at end of file
+```
+
+To load XML data you can use the "xml" loader, which is equivalent to "text" with sample_by="document":
+
+```py
+>>> from datasets import load_dataset
+>>> dataset = load_dataset("xml", data_files={"train": ["my_xml_1.xml", "my_xml_2.xml"], "test": "my_xml_file.xml"})
+
+# Load from a directory
+>>> dataset = load_dataset("xml", data_dir="path/to/xml/dataset")
+```
diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx
index b17cbed8a3b..29d94584220 100644
--- a/docs/source/package_reference/loading_methods.mdx
+++ b/docs/source/package_reference/loading_methods.mdx
@@ -49,6 +49,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")
 
 [[autodoc]] datasets.packaged_modules.json.Json
 
+### XML
+
+[[autodoc]] datasets.packaged_modules.xml.XmlConfig
+
+[[autodoc]] datasets.packaged_modules.xml.Xml
+
 ### Parquet
 
 [[autodoc]] datasets.packaged_modules.parquet.ParquetConfig
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 6a23170db5e..7598d0213ac 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -15,6 +15,7 @@
 from .sql import sql
 from .text import text
 from .webdataset import webdataset
+from .xml import xml
 
 
 def _hash_python_lines(lines: List[str]) -> str:
@@ -41,6 +42,7 @@ def _hash_python_lines(lines: List[str]) -> str:
     "imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
     "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
     "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
+    "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
 }
 
 # get importable module names and hash for caching
@@ -69,6 +71,7 @@ def _hash_python_lines(lines: List[str]) -> str:
     ".arrow": ("arrow", {}),
     ".txt": ("text", {}),
     ".tar": ("webdataset", {}),
+    ".xml": ("xml", {}),
 }
 _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
diff --git a/src/datasets/packaged_modules/xml/__init__.py b/src/datasets/packaged_modules/xml/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/xml/xml.py b/src/datasets/packaged_modules/xml/xml.py
new file mode 100644
index 00000000000..d5009b4dd6a
--- /dev/null
+++ b/src/datasets/packaged_modules/xml/xml.py
@@ -0,0 +1,68 @@
+import itertools
+from dataclasses import dataclass
+from typing import Optional
+
+import pyarrow as pa
+
+import datasets
+from datasets.features.features import require_storage_cast
+from datasets.table import table_cast
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+@dataclass
+class XmlConfig(datasets.BuilderConfig):
+    """BuilderConfig for xml files."""
+
+    features: Optional[datasets.Features] = None
+    encoding: str = "utf-8"
+    encoding_errors: Optional[str] = None
+
+
+class Xml(datasets.ArrowBasedBuilder):
+    BUILDER_CONFIG_CLASS = XmlConfig
+
+    def _info(self):
+        return datasets.DatasetInfo(features=self.config.features)
+
+    def _split_generators(self, dl_manager):
+        """The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].
+
+        If str or List[str], then the dataset returns only the 'train' split.
+        If dict, then keys should be from the `datasets.Split` enum.
+        """
+        if not self.config.data_files:
+            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+        dl_manager.download_config.extract_on_the_fly = True
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        splits = []
+        for split_name, files in data_files.items():
+            if isinstance(files, str):
+                files = [files]
+            files = [dl_manager.iter_files(file) for file in files]
+            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+        return splits
+
+    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
+        if self.config.features is not None:
+            schema = self.config.features.arrow_schema
+            if all(not require_storage_cast(feature) for feature in self.config.features.values()):
+                # cheaper cast
+                pa_table = pa_table.cast(schema)
+            else:
+                # more expensive cast; allows str <-> int/float or str to Audio for example
+                pa_table = table_cast(pa_table, schema)
+            return pa_table
+        else:
+            return pa_table.cast(pa.schema({"xml": pa.string()}))
+
+    def _generate_tables(self, files):
+        pa_table_names = list(self.config.features) if self.config.features is not None else ["xml"]
+        for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
+            # open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
+            with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
+                xml = f.read()
+            pa_table = pa.Table.from_arrays([pa.array([xml])], names=pa_table_names)
+            yield file_idx, self._cast_table(pa_table)
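
A minimal usage sketch of what this patch enables: each `.xml` file is read verbatim into one row with a single `xml` string column, so any real parsing happens downstream, e.g. with `Dataset.map`. The file name `books.xml` and the `<title>` tags below are hypothetical, for illustration only:

```py
>>> import xml.etree.ElementTree as ET
>>> from datasets import load_dataset

>>> # One row per file; the raw document lands in the "xml" column
>>> dataset = load_dataset("xml", data_files="books.xml", split="train")

>>> # Hypothetical post-processing: pull every <title> text out of each document
>>> def extract_titles(example):
...     root = ET.fromstring(example["xml"])
...     return {"titles": [el.text for el in root.iter("title")]}

>>> dataset = dataset.map(extract_titles)
```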