diff --git a/trx/tests/test_io.py b/trx/tests/test_io.py index 5a9dd8f..2f748d8 100644 --- a/trx/tests/test_io.py +++ b/trx/tests/test_io.py @@ -1,6 +1,7 @@ #! /usr/bin/env python3 # -*- coding: utf-8 -*- +from copy import deepcopy import os import pytest @@ -13,6 +14,7 @@ except ImportError: dipy_available = False +import trx.trx_file_memmap as tmm from trx.trx_file_memmap import TrxFile from trx.io import load, save, get_trx_tmpdir from trx.fetcher import (get_testing_files_dict, @@ -77,3 +79,27 @@ def test_multi_load_save_rasmm(path): obj = load(out_path, os.path.join(dir, 'gs.nii')) assert_allclose(obj.streamlines._data, coord, rtol=1e-04, atol=1e-06) + + +@pytest.mark.parametrize("path", [("gs.trx"), ("gs_fldr.trx")]) +@pytest.mark.skipif(not dipy_available, reason='Dipy is not installed.') +def test_close_tmp_file(path): + dir = os.path.join(get_home(), 'gold_standard') + path = os.path.join(dir, path) + + trx = tmm.load(path) + tmp_dir = deepcopy(trx._uncompressed_folder_handle) + sft = trx.to_sft() + trx.close() + + coord_rasmm = np.loadtxt(os.path.join(get_home(), 'gold_standard', + 'gs_rasmm_space.txt')) + coord_vox = np.loadtxt(os.path.join(get_home(), 'gold_standard', + 'gs_vox_space.txt')) + + # The folder trx representation does not need tmp files + if os.path.isfile(path): + assert not os.path.isdir(tmp_dir.name) + assert_allclose(sft.streamlines._data, coord_rasmm, rtol=1e-04, atol=1e-06) + sft.to_vox() + assert_allclose(sft.streamlines._data, coord_vox, rtol=1e-04, atol=1e-06)