Skip to content

Commit

Permalink
Add the Session.virtualfile_from_stringio method to allow StringIO input for certain functions/methods (#3326)
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Sep 18, 2024
1 parent a592ade commit 890626d
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 5 deletions.
1 change: 1 addition & 0 deletions doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -317,5 +317,6 @@ Low level access (these are mostly used by the :mod:`pygmt.clib` package):
clib.Session.get_libgmt_func
clib.Session.virtualfile_from_data
clib.Session.virtualfile_from_grid
clib.Session.virtualfile_from_stringio
clib.Session.virtualfile_from_matrix
clib.Session.virtualfile_from_vectors
106 changes: 104 additions & 2 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import contextlib
import ctypes as ctp
import io
import pathlib
import sys
import warnings
Expand Down Expand Up @@ -60,6 +61,7 @@
"GMT_IS_PLP", # items could be any one of POINT, LINE, or POLY
"GMT_IS_SURFACE", # items are 2-D grid
"GMT_IS_VOLUME", # items are 3-D grid
"GMT_IS_TEXT", # Text strings which triggers ASCII text reading
]

METHODS = [
Expand All @@ -70,6 +72,11 @@
# Direction names. "GMT_IN" is the direction passed to open_virtualfile when
# feeding data into GMT; presumably directions are validated against this list
# (TODO confirm against the caller).
DIRECTIONS = ["GMT_IN", "GMT_OUT"]

# Names accepted as the base value of the ``mode`` argument of create_data.
MODES = ["GMT_CONTAINER_ONLY", "GMT_IS_OUTPUT"]
# Modifiers that create_data accepts in addition to a base mode (passed as
# ``valid_modifiers`` when parsing ``mode``), e.g.
# "GMT_CONTAINER_ONLY|GMT_WITH_STRINGS".
MODE_MODIFIERS = [
    "GMT_GRID_IS_CARTESIAN",
    "GMT_GRID_IS_GEO",
    "GMT_WITH_STRINGS",
]

# Names accepted for the ``registration`` argument of create_data.
REGISTRATIONS = ["GMT_GRID_PIXEL_REG", "GMT_GRID_NODE_REG"]


Expand Down Expand Up @@ -728,7 +735,7 @@ def create_data(
mode_int = self._parse_constant(
mode,
valid=MODES,
valid_modifiers=["GMT_GRID_IS_CARTESIAN", "GMT_GRID_IS_GEO"],
valid_modifiers=MODE_MODIFIERS,
)
geometry_int = self._parse_constant(geometry, valid=GEOMETRIES)
registration_int = self._parse_constant(registration, valid=REGISTRATIONS)
Expand Down Expand Up @@ -1603,6 +1610,100 @@ def virtualfile_from_grid(self, grid):
with self.open_virtualfile(*args) as vfile:
yield vfile

@contextlib.contextmanager
def virtualfile_from_stringio(self, stringio: io.StringIO):
    r"""
    Store a :class:`io.StringIO` object in a virtual file.

    Store the contents of a :class:`io.StringIO` object in a GMT_DATASET container
    and create a virtual file to pass to a GMT module.

    For simplicity, currently we make the following assumptions about the
    StringIO object:

    - ``"#"`` indicates a comment line.
    - ``">"`` indicates a segment header.

    Parameters
    ----------
    stringio
        The :class:`io.StringIO` object containing the data to be stored in the
        virtual file.

    Yields
    ------
    fname
        The name of the virtual file.

    Raises
    ------
    ValueError
        If the :class:`io.StringIO` object contains no data records (i.e., it is
        empty or consists only of comment lines and/or segment headers).

    Examples
    --------
    >>> import io
    >>> from pygmt.clib import Session
    >>> # A StringIO object containing legend specifications
    >>> stringio = io.StringIO(
    ...     "# Comment\n"
    ...     "H 24p Legend\n"
    ...     "N 2\n"
    ...     "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
    ... )
    >>> with Session() as lib:
    ...     with lib.virtualfile_from_stringio(stringio) as fin:
    ...         lib.virtualfile_to_dataset(vfname=fin, output_type="pandas")
                                                 0
    0                                 H 24p Legend
    1                                          N 2
    2  S 0.1i c 0.15i p300/12 0.25p 0.3i My circle
    """
    # Parse the io.StringIO object into segments. Each segment is a dict holding
    # an optional header string and the list of its data records (text lines).
    segments = []
    current_segment = {"header": "", "data": []}
    for line in stringio.getvalue().splitlines():
        if line.startswith("#"):  # Skip comments
            continue
        if line.startswith(">"):  # Segment header
            if current_segment["data"]:  # If we have data, start a new segment
                segments.append(current_segment)
                current_segment = {"header": "", "data": []}
            current_segment["header"] = line.strip(">").lstrip()
        else:
            current_segment["data"].append(line)  # type: ignore[attr-defined]
    if current_segment["data"]:  # Add the last segment if it has data
        segments.append(current_segment)

    # Without any data record there is nothing to store. Fail early with a clear
    # message instead of letting max() below raise an opaque "empty sequence"
    # ValueError (the exception type is kept the same on purpose).
    if not segments:
        msg = (
            "The io.StringIO object contains no data records to store in a "
            "virtual file."
        )
        raise ValueError(msg)

    # One table with one or more segments.
    # n_rows is the maximum number of rows/records for all segments.
    # n_columns is the number of numeric data columns, so it's 0 here.
    n_tables = 1
    n_segments = len(segments)
    n_rows = max(len(segment["data"]) for segment in segments)
    n_columns = 0

    # Create the GMT_DATASET container
    family, geometry = "GMT_IS_DATASET", "GMT_IS_TEXT"
    dataset = self.create_data(
        family,
        geometry,
        mode="GMT_CONTAINER_ONLY|GMT_WITH_STRINGS",
        dim=[n_tables, n_segments, n_rows, n_columns],
    )
    dataset = ctp.cast(dataset, ctp.POINTER(_GMT_DATASET))
    table = dataset.contents.table[0].contents
    # Attach the Python-side header and text records to each segment.
    for i, segment in enumerate(segments):
        seg = table.segment[i].contents
        if segment["header"]:
            seg.header = segment["header"].encode()  # type: ignore[attr-defined]
        seg.text = strings_to_ctypes_array(segment["data"])

    with self.open_virtualfile(family, geometry, "GMT_IN", dataset) as vfile:
        try:
            yield vfile
        finally:
            # Must set the pointers to None to avoid double freeing the memory.
            # Maybe upstream bug.
            for i in range(n_segments):
                seg = table.segment[i].contents
                seg.header = None
                seg.text = None


def virtualfile_in( # noqa: PLR0912
self,
check_kind=None,
Expand Down Expand Up @@ -1696,6 +1797,7 @@ def virtualfile_in( # noqa: PLR0912
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
"image": tempfile_from_image,
"stringio": self.virtualfile_from_stringio,
# Note: virtualfile_from_matrix is not used because a matrix can be
# converted to vectors instead, and using vectors allows for better
# handling of string type inputs (e.g. for datetime data types)
Expand All @@ -1704,7 +1806,7 @@ def virtualfile_in( # noqa: PLR0912
}[kind]

# Ensure the data is an iterable (Python list or tuple)
if kind in {"geojson", "grid", "image", "file", "arg"}:
if kind in {"geojson", "grid", "image", "file", "arg", "stringio"}:
if kind == "image" and data.dtype != "uint8":
msg = (
f"Input image has dtype: {data.dtype} which is unsupported, "
Expand Down
16 changes: 13 additions & 3 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Utilities and common tasks for wrapping the GMT modules.
"""

import io
import os
import pathlib
import shutil
Expand Down Expand Up @@ -188,8 +189,10 @@ def _check_encoding(

def data_kind(
data: Any = None, required: bool = True
) -> Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]:
"""
) -> Literal[
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]:
r"""
Check the kind of data that is provided to a module.
The ``data`` argument can be in any type, but only following types are supported:
Expand Down Expand Up @@ -222,6 +225,7 @@ def data_kind(
>>> import numpy as np
>>> import xarray as xr
>>> import pathlib
>>> import io
>>> data_kind(data=None)
'vectors'
>>> data_kind(data=np.arange(10).reshape((5, 2)))
Expand All @@ -240,8 +244,12 @@ def data_kind(
'grid'
>>> data_kind(data=xr.DataArray(np.random.rand(3, 4, 5)))
'image'
>>> data_kind(data=io.StringIO("TEXT1\nTEXT23\n"))
'stringio'
"""
kind: Literal["arg", "file", "geojson", "grid", "image", "matrix", "vectors"]
kind: Literal[
"arg", "file", "geojson", "grid", "image", "matrix", "stringio", "vectors"
]
if isinstance(data, str | pathlib.PurePath) or (
isinstance(data, list | tuple)
and all(isinstance(_file, str | pathlib.PurePath) for _file in data)
Expand All @@ -250,6 +258,8 @@ def data_kind(
kind = "file"
elif isinstance(data, bool | int | float) or (data is None and not required):
kind = "arg"
elif isinstance(data, io.StringIO):
kind = "stringio"
elif isinstance(data, xr.DataArray):
kind = "image" if len(data.dims) == 3 else "grid"
elif hasattr(data, "__geo_interface__"):
Expand Down
104 changes: 104 additions & 0 deletions pygmt/tests/test_clib_virtualfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test the C API functions related to virtual files.
"""

import io
from importlib.util import find_spec
from itertools import product
from pathlib import Path
Expand Down Expand Up @@ -407,3 +408,106 @@ def test_inquire_virtualfile():
]:
with lib.open_virtualfile(family, geometry, "GMT_OUT", None) as vfile:
assert lib.inquire_virtualfile(vfile) == lib[family]


class TestVirtualfileFromStringIO:
    """
    Tests for the Session.virtualfile_from_stringio method.
    """

    def _stringio_to_dataset(self, data: io.StringIO):
        """
        Round-trip a StringIO object through GMT and return the result as text.

        The helper:

        1. Creates a virtual file from the input StringIO object.
        2. Passes it to the ``read`` module, which copies it into an output
           virtual file.
        3. Reads the output virtual file back as a GMT_DATASET object.
        4. Collects each segment's header and trailing text and returns them
           joined into a single string.
        """
        with clib.Session() as lib:
            with (
                lib.virtualfile_from_stringio(data) as vintbl,
                lib.virtualfile_out(kind="dataset") as vouttbl,
            ):
                lib.call_module("read", args=[vintbl, vouttbl, "-Td"])
                dataset = lib.read_virtualfile(vouttbl, kind="dataset").contents

                # Build the text while the session and virtual files are still
                # open, since the records live in memory owned by GMT.
                lines = []
                tbl = dataset.table[0].contents
                for seg_ptr in tbl.segment[: tbl.n_segments]:
                    segment = seg_ptr.contents
                    if segment.header:
                        lines.append(f"> {segment.header.decode()}")
                    else:
                        lines.append(">")
                    lines.extend(np.char.decode(segment.text[: segment.n_rows]))
                return "\n".join(lines) + "\n"

    def test_virtualfile_from_stringio(self):
        """
        Records without any segment header end up in one header-less segment.
        """
        source = io.StringIO(
            "# Comment\n"
            "H 24p Legend\n"
            "N 2\n"
            "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
        )
        result = self._stringio_to_dataset(source)
        assert result == (
            ">\n"
            "H 24p Legend\n"
            "N 2\n"
            "S 0.1i c 0.15i p300/12 0.25p 0.3i My circle\n"
        )

    def test_one_segment(self):
        """
        A single segment keeps its header and all of its records.
        """
        source = io.StringIO(
            "# Comment\n"
            "> Segment 1\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FGHIJK LMN OPQ\n"
            "RSTUVWXYZ\n"
        )
        result = self._stringio_to_dataset(source)
        assert result == (
            "> Segment 1\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FGHIJK LMN OPQ\n"
            "RSTUVWXYZ\n"
        )

    def test_multiple_segments(self):
        """
        Several segments are preserved in order; comment lines are dropped.
        """
        source = io.StringIO(
            "# Comment line 1\n"
            "# Comment line 2\n"
            "> Segment 1\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FG\n"
            "# Comment line 3\n"
            "> Segment 2\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FG\n"
        )
        result = self._stringio_to_dataset(source)
        assert result == (
            "> Segment 1\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FG\n"
            "> Segment 2\n"
            "1 2 3 ABC\n"
            "4 5 DE\n"
            "6 7 8 9 FG\n"
        )


0 comments on commit 890626d

Please sign in to comment.