Skip to content

Commit

Permalink
Fix #154 . Add copy/deepcopy methods for DataChunk and add tests (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
oruebel authored Sep 27, 2019
1 parent 68c5ac9 commit c551de2
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/hdmf/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,17 @@ def __getattr__(self, attr):
"""Delegate retrival of attributes to the data in self.data"""
return getattr(self.data, attr)

def __copy__(self):
newobj = DataChunk(data=self.data,
selection=self.selection)
return newobj

def __deepcopy__(self, memo):
result = DataChunk(data=copy.deepcopy(self.data),
selection=copy.deepcopy(self.selection))
memo[id(self)] = result
return result

def astype(self, dtype):
"""Get a new DataChunk with the self.data converted to the given type"""
return DataChunk(data=self.data.astype(dtype),
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/utils_test/test_core_DataChunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import unittest2 as unittest

from hdmf.data_utils import DataChunk
import numpy as np
from copy import copy, deepcopy


class DataChunkTests(unittest.TestCase):

def setUp(self):
pass

def tearDown(self):
pass

def test_datachunk_copy(self):
obj = DataChunk(data=np.arange(3), selection=np.s_[0:3])
obj_copy = copy(obj)
self.assertNotEqual(id(obj), id(obj_copy))
self.assertEqual(id(obj.data), id(obj_copy.data))
self.assertEqual(id(obj.selection), id(obj_copy.selection))

def test_datachunk_deepcopy(self):
obj = DataChunk(data=np.arange(3), selection=np.s_[0:3])
obj_copy = deepcopy(obj)
self.assertNotEqual(id(obj), id(obj_copy))
self.assertNotEqual(id(obj.data), id(obj_copy.data))
self.assertNotEqual(id(obj.selection), id(obj_copy.selection))

def test_datachunk_astype(self):
obj = DataChunk(data=np.arange(3), selection=np.s_[0:3])
newtype = np.dtype('int16')
obj_astype = obj.astype(newtype)
self.assertNotEqual(id(obj), id(obj_astype))
self.assertEqual(obj_astype.dtype, np.dtype(newtype))


if __name__ == '__main__':
unittest.main()

0 comments on commit c551de2

Please sign in to comment.