Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move array storage settings to config #1468

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The ASDF Standard is at v1.6.0
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Drop support for ASDF-in-FITS. [#1288]
- Add ``all_array_storage``, ``all_array_compression`` and
``all_array_compression_kwargs`` to ``asdf.config.AsdfConfig`` [#1468]

2.15.0 (unreleased)
-------------------
Expand Down
70 changes: 70 additions & 0 deletions asdf/_tests/test_array_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,26 @@ def test_update_expand_tree(tmp_path):
assert_array_equal(ff.tree["arrays"][1], my_array2)


def test_update_all_external(tmp_path):
fn = tmp_path / "test.asdf"

my_array = np.arange(64) * 1
my_array2 = np.arange(64) * 2
tree = {"arrays": [my_array, my_array2]}

af = asdf.AsdfFile(tree)
af.write_to(fn)

with asdf.config.config_context() as cfg:
cfg.array_inline_threshold = 10
cfg.all_array_storage = "external"
with asdf.open(fn, mode="rw") as af:
af.update()

assert "test0000.asdf" in os.listdir(tmp_path)
assert "test0001.asdf" in os.listdir(tmp_path)


def _get_update_tree():
return {"arrays": [np.arange(64) * 1, np.arange(64) * 2, np.arange(64) * 3]}

Expand Down Expand Up @@ -830,3 +850,53 @@ def test_block_allocation_on_validate():
assert len(list(af._blocks.blocks)) == 1
af.validate()
assert len(list(af._blocks.blocks)) == 1


@pytest.mark.parametrize("all_array_storage", ["internal", "external", "inline"])
@pytest.mark.parametrize("all_array_compression", [None, "", "zlib", "bzp2", "lz4", "input"])
@pytest.mark.parametrize("compression_kwargs", [None, {}])
def test_write_to_update_storage_options(tmp_path, all_array_storage, all_array_compression, compression_kwargs):
if all_array_compression == "bzp2" and compression_kwargs is not None:
compression_kwargs = {"compresslevel": 1}

def assert_result(ff, arr):
if all_array_storage == "external":
assert "test0000.asdf" in os.listdir(tmp_path)
else:
assert "test0000.asdf" not in os.listdir(tmp_path)
if all_array_storage == "internal":
assert len(ff._blocks._internal_blocks) == 1
else:
assert len(ff._blocks._internal_blocks) == 0
blk = ff._blocks[arr]

target_compression = all_array_compression or None
assert blk._output_compression == target_compression

target_compression_kwargs = compression_kwargs or {}
assert blk._output_compression_kwargs == target_compression_kwargs

arr1 = np.ones((8, 8))
tree = {"array": arr1}
fn = tmp_path / "test.asdf"

ff1 = asdf.AsdfFile(tree)
# first check write_to
ff1.write_to(
fn,
all_array_storage=all_array_storage,
all_array_compression=all_array_compression,
compression_kwargs=compression_kwargs,
)
assert_result(ff1, arr1)

# then reuse the file to check update
with asdf.open(fn, mode="rw") as ff2:
arr2 = np.ones((8, 8)) * 42
ff2["array"] = arr2
ff2.update(
all_array_storage=all_array_storage,
all_array_compression=all_array_compression,
compression_kwargs=compression_kwargs,
)
assert_result(ff2, arr2)
6 changes: 6 additions & 0 deletions asdf/_tests/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,15 @@ def compressors(self):
def test_compression_with_extension(tmp_path):
tree = _get_large_tree()

with pytest.raises(ValueError, match="Supported compression types are"), config_context() as cfg:
cfg.all_array_compression = "lzma"

with config_context() as config:
config.add_extension(LzmaExtension())

with config_context() as cfg:
cfg.all_array_compression = "lzma"

with pytest.raises(lzma.LZMAError, match=r"Invalid or unsupported options"):
_roundtrip(tmp_path, tree, "lzma", write_options={"compression_kwargs": {"preset": 9000}})
fn = _roundtrip(tmp_path, tree, "lzma", write_options={"compression_kwargs": {"preset": 6}})
Expand Down
33 changes: 33 additions & 0 deletions asdf/_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,39 @@ def test_array_inline_threshold():
assert get_config().array_inline_threshold is None


def test_all_array_storage():
with asdf.config_context() as config:
assert config.all_array_storage == asdf.config.DEFAULT_ALL_ARRAY_STORAGE
config.all_array_storage = "internal"
assert get_config().all_array_storage == "internal"
config.all_array_storage = None
assert get_config().all_array_storage is None
with pytest.raises(ValueError, match=r"Invalid value for all_array_storage"):
config.all_array_storage = "foo"


def test_all_array_compression():
with asdf.config_context() as config:
assert config.all_array_compression == asdf.config.DEFAULT_ALL_ARRAY_COMPRESSION
config.all_array_compression = "zlib"
assert get_config().all_array_compression == "zlib"
config.all_array_compression = None
assert get_config().all_array_compression is None
with pytest.raises(ValueError, match=r"Supported compression types are"):
config.all_array_compression = "foo"


def test_all_array_compression_kwargs():
with asdf.config_context() as config:
assert config.all_array_compression_kwargs == asdf.config.DEFAULT_ALL_ARRAY_COMPRESSION_KWARGS
config.all_array_compression_kwargs = {}
assert get_config().all_array_compression_kwargs == {}
config.all_array_compression_kwargs = None
assert get_config().all_array_compression_kwargs is None
with pytest.raises(ValueError, match=r"Invalid value for all_array_compression_kwargs"):
config.all_array_compression_kwargs = "foo"


def test_resource_mappings():
with asdf.config_context() as config:
core_mappings = get_json_schema_resource_mappings() + asdf_standard.integration.get_resource_mappings()
Expand Down
52 changes: 26 additions & 26 deletions asdf/asdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,16 +1092,7 @@ def _tree_finalizer(tagged_tree):
padding = util.calculate_padding(fd.tell(), pad_blocks, fd.block_size)
fd.fast_forward(padding)

def _pre_write(self, fd, all_array_storage, all_array_compression, compression_kwargs=None):
if all_array_storage not in (None, "internal", "external", "inline"):
msg = f"Invalid value for all_array_storage: '{all_array_storage}'"
raise ValueError(msg)

self._all_array_storage = all_array_storage

self._all_array_compression = all_array_compression
self._all_array_compression_kwargs = compression_kwargs

def _pre_write(self, fd):
if len(self._tree):
self._run_hook("pre_write")

Expand Down Expand Up @@ -1132,12 +1123,6 @@ def _post_write(self, fd):

def update(
self,
all_array_storage=None,
all_array_compression="input",
pad_blocks=False,
include_block_index=True,
version=None,
compression_kwargs=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1196,7 +1181,17 @@ def update(
in ``asdf.get_config().array_inline_threshold``.
"""

pad_blocks = kwargs.pop("pad_blocks", False)
include_block_index = kwargs.pop("include_block_index", True)
version = kwargs.pop("version", None)

with config_context() as config:
if "all_array_storage" in kwargs:
config.all_array_storage = kwargs.pop("all_array_storage")
if "all_array_compression" in kwargs:
config.all_array_compression = kwargs.pop("all_array_compression")
if "compression_kwargs" in kwargs:
config.all_array_compression_kwargs = kwargs.pop("compression_kwargs")
_handle_deprecated_kwargs(config, kwargs)

fd = self._fd
Expand All @@ -1216,10 +1211,10 @@ def update(
if version is not None:
self.version = version

if all_array_storage == "external":
if config.all_array_storage == "external":
# If the file is fully exploded, there's no benefit to
# update, so just use write_to()
self.write_to(fd, all_array_storage=all_array_storage)
self.write_to(fd)
fd.truncate()
return

Expand All @@ -1233,7 +1228,7 @@ def update(
if fd.can_memmap():
fd.flush_memmap()

self._pre_write(fd, all_array_storage, all_array_compression, compression_kwargs=compression_kwargs)
self._pre_write(fd)

try:
fd.seek(0)
Expand Down Expand Up @@ -1280,12 +1275,6 @@ def update(
def write_to(
self,
fd,
all_array_storage=None,
all_array_compression="input",
pad_blocks=False,
include_block_index=True,
version=None,
compression_kwargs=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -1355,7 +1344,18 @@ def write_to(
``asdf.get_config().array_inline_threshold``.

"""

pad_blocks = kwargs.pop("pad_blocks", False)
include_block_index = kwargs.pop("include_block_index", True)
version = kwargs.pop("version", None)

with config_context() as config:
if "all_array_storage" in kwargs:
config.all_array_storage = kwargs.pop("all_array_storage")
if "all_array_compression" in kwargs:
config.all_array_compression = kwargs.pop("all_array_compression")
if "compression_kwargs" in kwargs:
config.all_array_compression_kwargs = kwargs.pop("compression_kwargs")
_handle_deprecated_kwargs(config, kwargs)

if version is not None:
Expand All @@ -1367,7 +1367,7 @@ def write_to(
# attribute of the AsdfFile.
if self._uri is None:
self._uri = fd.uri
self._pre_write(fd, all_array_storage, all_array_compression, compression_kwargs=compression_kwargs)
self._pre_write(fd)

try:
self._serial_write(fd, pad_blocks, include_block_index)
Expand Down
19 changes: 10 additions & 9 deletions asdf/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def write_external_blocks(self, uri, pad_blocks=False):
blk._array_storage = "internal"
asdffile._blocks.add(blk)
blk._used = True
asdffile.write_to(subfd, pad_blocks=pad_blocks)
asdffile.write_to(subfd, pad_blocks=pad_blocks, all_array_storage="internal")

def write_block_index(self, fd, ctx):
"""
Expand Down Expand Up @@ -567,13 +567,14 @@ def _find_used_blocks(self, tree, ctx):
if getattr(block, "_used", 0) == 0 and block not in reserved_blocks:
self.remove(block)

def _handle_global_block_settings(self, ctx, block):
all_array_storage = getattr(ctx, "_all_array_storage", None)
def _handle_global_block_settings(self, block):
cfg = get_config()
all_array_storage = cfg.all_array_storage
if all_array_storage:
self.set_array_storage(block, all_array_storage)

all_array_compression = getattr(ctx, "_all_array_compression", "input")
all_array_compression_kwargs = getattr(ctx, "_all_array_compression_kwargs", {})
all_array_compression = cfg.all_array_compression
all_array_compression_kwargs = cfg.all_array_compression_kwargs
# Only override block compression algorithm if it wasn't explicitly set
# by AsdfFile.set_array_compression.
if all_array_compression != "input":
Expand Down Expand Up @@ -601,7 +602,7 @@ def finalize(self, ctx):
self._find_used_blocks(ctx.tree, ctx)

for block in list(self.blocks):
self._handle_global_block_settings(ctx, block)
self._handle_global_block_settings(block)

def get_block(self, source):
"""
Expand Down Expand Up @@ -714,7 +715,7 @@ def get_source(self, block):
msg = "block not found."
raise ValueError(msg)

def find_or_create_block_for_array(self, arr, ctx):
def find_or_create_block_for_array(self, arr):
"""
For a given array, looks for an existing block containing its
underlying data. If not found, adds a new block to the block
Expand Down Expand Up @@ -743,7 +744,7 @@ def find_or_create_block_for_array(self, arr, ctx):

block = Block(base)
self.add(block)
self._handle_global_block_settings(ctx, block)
self._handle_global_block_settings(block)

return block

Expand Down Expand Up @@ -787,7 +788,7 @@ def get_output_compression_extensions(self):
return ext

def __getitem__(self, arr):
return self.find_or_create_block_for_array(arr, object())
return self.find_or_create_block_for_array(arr)

def close(self):
for block in self.blocks:
Expand Down
Loading