Skip to content

Commit

Permalink
[BUGFIX] Handle settings on read-only filesystem (#465)
Browse files Browse the repository at this point in the history
* Handle settings on read-only filesystems

* Cleanup
  • Loading branch information
sdatkinson authored Sep 17, 2024
1 parent 17273de commit 26fdad7
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 12 deletions.
13 changes: 6 additions & 7 deletions nam/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,6 @@ def _get_configs(
batch_size: int,
fit_mrstft: bool,
):

data_config = _get_data_config(
input_version=input_version,
input_path=input_path,
Expand Down Expand Up @@ -1564,13 +1563,13 @@ def validate_data(
for split in Split:
try:
init_dataset(data_config, split)
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=True, msg=None)
)
pytorch_data_split_validation_dict[
split.value
] = _PyTorchDataSplitValidation(passed=True, msg=None)
except DataError as e:
pytorch_data_split_validation_dict[split.value] = (
_PyTorchDataSplitValidation(passed=False, msg=str(e))
)
pytorch_data_split_validation_dict[
split.value
] = _PyTorchDataSplitValidation(passed=False, msg=str(e))
pytorch_data_validation = _PyTorchDataValidation(
passed=all(v.passed for v in pytorch_data_split_validation_dict.values()),
**pytorch_data_split_validation_dict,
Expand Down
35 changes: 30 additions & 5 deletions nam/train/gui/_resources/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,38 @@ def _get_settings() -> dict:
"""
Make sure that ./settings.json exists; if it does, then read it. If not, empty dict.
"""

if not _SETTINGS_JSON_PATH.exists():
_write_settings({})
with open(_SETTINGS_JSON_PATH, "r") as fp:
return json.load(fp)
return dict()
else:
with open(_SETTINGS_JSON_PATH, "r") as fp:
return json.load(fp)


class _WriteSettings(object):
def __init__(self):
self._oserror = False

def __call__(self, *args, **kwargs):
if self._oserror:
return
# Try-catch for Issue 448
try:
return _write_settings_unsafe(*args, **kwargs)
except OSError as e:
if "Read-only filesystem" in str(e):
print(
"Failed to write settings--NAM appears to be installed to a "
"read-only filesystem. This is discouraged; consider installing to "
"a location with user-level access."
)
self._oserror = True
else:
raise e


_write_settings = _WriteSettings()


def _write_settings(obj: dict):
def _write_settings_unsafe(obj: dict):
with open(_SETTINGS_JSON_PATH, "w") as fp:
json.dump(obj, fp, indent=4)
3 changes: 3 additions & 0 deletions tests/test_nam/test_train/test_gui/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# File: __init__.py
# Created Date: Tuesday September 17th 2024
# Author: Steven Atkinson ([email protected])
File renamed without changes.
Empty file.
44 changes: 44 additions & 0 deletions tests/test_nam/test_train/test_gui/test_resources/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# File: test_resources.py
# Created Date: Tuesday September 17th 2024
# Author: Steven Atkinson ([email protected])

from contextlib import contextmanager
from pathlib import Path

import pytest

from nam.train.gui._resources import settings


class TestReadOnly(object):
"""
Issue 448
"""

@pytest.mark.parametrize("path_key", tuple(pk for pk in settings.PathKey))
def test_get_last_path(self, path_key: settings.PathKey):
with self._mock_read_only():
last_path = settings.get_last_path(path_key)
assert last_path is None or isinstance(last_path, Path)

@pytest.mark.parametrize("path_key", tuple(pk for pk in settings.PathKey))
def test_set_last_path(self, path_key: settings.PathKey):
path = Path(__file__).parent / Path("dummy.txt")
with self._mock_read_only():
settings.set_last_path(path_key, path)

@contextmanager
def _mock_read_only(self):
def write_settings(*args, **kwargs):
raise OSError("Read-only filesystem")

try:
tmp = settings._write_settings_unsafe
settings._write_settings_unsafe = write_settings
yield
finally:
settings._write_settings_unsafe = tmp


if __name__ == "__main__":
pytest.main()

0 comments on commit 26fdad7

Please sign in to comment.