Skip to content

Commit

Permalink
feat: Adds .arrow support
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Oct 7, 2024
1 parent 4fff80a commit 3a284a5
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tools/vendor_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import sys
import tempfile
from functools import cached_property, partial
from pathlib import Path
from typing import Any, Callable, ClassVar, Literal
Expand All @@ -29,11 +30,11 @@
_OLD_SOURCE_TAG = "v1.29.0" # 5 years ago
_CURRENT_SOURCE_TAG = "v2.9.0"

ExtSupported: TypeAlias = Literal[".csv", ".json", ".tsv"]
ExtSupported: TypeAlias = Literal[".csv", ".json", ".tsv", ".arrow"]


def is_ext_supported(suffix: str) -> TypeIs[ExtSupported]:
return suffix in {".csv", ".json", ".tsv"}
return suffix in {".csv", ".json", ".tsv", ".arrow"}


def _py_to_js(s: str, /):
Expand All @@ -49,6 +50,7 @@ class Dataset:
".csv": pl.read_csv,
".json": pl.read_json,
".tsv": partial(pl.read_csv, separator="\t"),
".arrow": partial(pl.read_ipc, use_pyarrow=True),
}

def __init__(self, name: str, /, base_url: str) -> None:
Expand All @@ -63,9 +65,10 @@ def __init__(self, name: str, /, base_url: str) -> None:
self.url: str = f"{base_url}{file_name}"

def __call__(self, **kwds: Any) -> pl.DataFrame:
with urlopen(self.url) as f:
fn = self.read_fn[self.extension]
content = fn(f, **kwds)
fn = self.read_fn[self.extension]
with tempfile.NamedTemporaryFile() as tmp, urlopen(self.url) as f:
tmp.write(f.read())
content = fn(tmp, **kwds)
return content

def __repr__(self) -> str:
Expand Down

0 comments on commit 3a284a5

Please sign in to comment.