From 8181573f5312a865b79a837f8ced93d02530e7cf Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Thu, 16 Apr 2020 23:42:09 -0500
Subject: [PATCH 1/6] add method update1 to dj.Table

---
 datajoint/table.py       | 246 ++++++++++++++++++++++-----------------
 datajoint/user_tables.py |   2 +-
 tests/test_attach.py     |   1 -
 tests/test_update1.py    |  65 +++++++++++
 4 files changed, 204 insertions(+), 110 deletions(-)
 create mode 100644 tests/test_update1.py

diff --git a/datajoint/table.py b/datajoint/table.py
index e9bacd437..aaf728376 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -166,6 +166,40 @@ def _log(self):
     def external(self):
         return self.connection.schemas[self.database].external
 
+    def update1(self, row):
+        """
+        Update an existing entry in the table.
+        Caution: Updates are not part of the DataJoint data manipulation model. For strict data integrity,
+        use delete and insert.
+        :param row: a dict containing the primary key and the attributes to update.
+        Setting an attribute value to None will reset it to the default value (if any)
+        The primary key attributes must always be provided.
+        Examples:
+        >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
+        >>> table.update1({'id': 1, 'value': None})  # reset value to default
+        """
+        # argument validations
+        if not isinstance(row, collections.abc.Mapping):
+            raise DataJointError('The argument of update must be dict-like')
+        if not set(row).issuperset(self.primary_key):
+            raise DataJointError('The argument of update must supply all primary key values')
+        try:
+            raise DataJointError('Attribute `%s` not found.' % next(k for k in row if k not in self.heading.names))
+        except StopIteration:
+            pass  # ok
+        if len(self.restriction):
+            raise DataJointError('Update cannot be applied to a restricted table.')
+        key = {k: row[k] for k in self.primary_key}
+        if len(self & key) != 1:
+            raise DataJointError('Update entry must exist.')
+        # UPDATE query
+        row = [self.__make_placeholder(k, v) for k, v in row.items() if k not in self.primary_key]
+        query = "UPDATE {table} SET {assignments} WHERE {where}".format(
+            table=self.full_table_name,
+            assignments=",".join('`%s`=%s' % r[:2] for r in row),
+            where=self._make_condition(key))
+        self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))
+
     def insert1(self, row, **kwargs):
         """
         Insert one data record or one Mapping (like a dict).
@@ -201,7 +235,6 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
                 'Inserts into an auto-populated table can only be done inside its make method during a populate call.'
                 ' To override, set keyword argument allow_direct_insert=True.')
 
-        heading = self.heading
         if inspect.isclass(rows) and issubclass(rows, QueryExpression):  # instantiate if a class
             rows = rows()
         if isinstance(rows, QueryExpression):
             try:
                 raise DataJointError(
                     "Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
-                    next(name for name in rows.heading if name not in heading))
+                    next(name for name in rows.heading if name not in self.heading))
             except StopIteration:
                 pass
-            fields = list(name for name in rows.heading if name in heading)
+            fields = list(name for name in rows.heading if name in self.heading)
             query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format(
                 command='REPLACE' if replace else 'INSERT',
                 fields='`' + '`,`'.join(fields) + '`',
@@ -225,111 +258,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
             self.connection.query(query)
             return
 
-        if heading.attributes is None:
-            logger.warning('Could not access table {table}'.format(table=self.full_table_name))
-            return
-
-        field_list = None  # ensures that all rows have the same attributes in the same order as the first row.
-
-        def make_row_to_insert(row):
-            """
-            :param row: A tuple to insert
-            :return: a dict with fields 'names', 'placeholders', 'values'
-            """
-
-            def make_placeholder(name, value):
-                """
-                For a given attribute `name` with `value`, return its processed value or value placeholder
-                as a string to be included in the query and the value, if any, to be submitted for
-                processing by mysql API.
-                :param name: name of attribute to be inserted
-                :param value: value of attribute to be inserted
-                """
-                if ignore_extra_fields and name not in heading:
-                    return None
-                attr = heading[name]
-                if attr.adapter:
-                    value = attr.adapter.put(value)
-                if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
-                    # set default value
-                    placeholder, value = 'DEFAULT', None
-                else:  # not NULL
-                    placeholder = '%s'
-                    if attr.uuid:
-                        if not isinstance(value, uuid.UUID):
-                            try:
-                                value = uuid.UUID(value)
-                            except (AttributeError, ValueError):
-                                raise DataJointError(
-                                    'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None
-                        value = value.bytes
-                    elif attr.is_blob:
-                        value = blob.pack(value)
-                        value = self.external[attr.store].put(value).bytes if attr.is_external else value
-                    elif attr.is_attachment:
-                        attachment_path = Path(value)
-                        if attr.is_external:
-                            # value is hash of contents
-                            value = self.external[attr.store].upload_attachment(attachment_path).bytes
-                        else:
-                            # value is filename + contents
-                            value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
-                    elif attr.is_filepath:
-                        value = self.external[attr.store].upload_filepath(value).bytes
-                    elif attr.numeric:
-                        value = str(int(value) if isinstance(value, bool) else value)
-                return name, placeholder, value
-
-            def check_fields(fields):
-                """
-                Validates that all items in `fields` are valid attributes in the heading
-                :param fields: field names of a tuple
-                """
-                if field_list is None:
-                    if not ignore_extra_fields:
-                        for field in fields:
-                            if field not in heading:
-                                raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
-                elif set(field_list) != set(fields).intersection(heading.names):
-                    raise DataJointError('Attempt to insert rows with different fields')
-
-            if isinstance(row, np.void):  # np.array
-                check_fields(row.dtype.fields)
-                attributes = [make_placeholder(name, row[name])
-                              for name in heading if name in row.dtype.fields]
-            elif isinstance(row, collections.abc.Mapping):  # dict-based
-                check_fields(row)
-                attributes = [make_placeholder(name, row[name]) for name in heading if name in row]
-            else:  # positional
-                try:
-                    if len(row) != len(heading):
-                        raise DataJointError(
-                            'Invalid insert argument. Incorrect number of attributes: '
-                            '{given} given; {expected} expected'.format(
-                                given=len(row), expected=len(heading)))
-                except TypeError:
-                    raise DataJointError('Datatype %s cannot be inserted' % type(row))
-                else:
-                    attributes = [make_placeholder(name, value) for name, value in zip(heading, row)]
-            if ignore_extra_fields:
-                attributes = [a for a in attributes if a is not None]
-
-            assert len(attributes), 'Empty tuple'
-            row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
-            nonlocal field_list
-            if field_list is None:
-                # first row sets the composition of the field list
-                field_list = row_to_insert['names']
-            else:
-                # reorder attributes in row_to_insert to match field_list
-                order = list(row_to_insert['names'].index(field) for field in field_list)
-                row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
-                row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
-                row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)
-
-            return row_to_insert
-
-        rows = list(make_row_to_insert(row) for row in rows)
+        field_list = [] # collects the field list from first row (passed by reference)
+        rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
         if rows:
             try:
                 query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
@@ -598,6 +528,106 @@ def _update(self, attrname, value=None):
                             where_clause=self.where_clause)
         self.connection.query(command, args=(value, ) if value is not None else ())
 
+    # --- private helper functions ----
+    def __make_placeholder(self, name, value, ignore_extra_fields=False):
+        """
+        For a given attribute `name` with `value`, return its processed value or value placeholder
+        as a string to be included in the query and the value, if any, to be submitted for
+        processing by mysql API.
+        :param name: name of attribute to be inserted
+        :param value: value of attribute to be inserted
+        """
+        if ignore_extra_fields and name not in self.heading:
+            return None
+        attr = self.heading[name]
+        if attr.adapter:
+            value = attr.adapter.put(value)
+        if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
+            # set default value
+            placeholder, value = 'DEFAULT', None
+        else:  # not NULL
+            placeholder = '%s'
+            if attr.uuid:
+                if not isinstance(value, uuid.UUID):
+                    try:
+                        value = uuid.UUID(value)
+                    except (AttributeError, ValueError):
+                        raise DataJointError(
+                            'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None
+                value = value.bytes
+            elif attr.is_blob:
+                value = blob.pack(value)
+                value = self.external[attr.store].put(value).bytes if attr.is_external else value
+            elif attr.is_attachment:
+                attachment_path = Path(value)
+                if attr.is_external:
+                    # value is hash of contents
+                    value = self.external[attr.store].upload_attachment(attachment_path).bytes
+                else:
+                    # value is filename + contents
+                    value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
+            elif attr.is_filepath:
+                value = self.external[attr.store].upload_filepath(value).bytes
+            elif attr.numeric:
+                value = str(int(value) if isinstance(value, bool) else value)
+        return name, placeholder, value
+
+    def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
+        """
+        Helper function for insert and update
+        :param row: A tuple to insert
+        :return: a dict with fields 'names', 'placeholders', 'values'
+        """
+
+        def check_fields(fields):
+            """
+            Validates that all items in `fields` are valid attributes in the heading
+            :param fields: field names of a tuple
+            """
+            if not field_list:
+                if not ignore_extra_fields:
+                    for field in fields:
+                        if field not in self.heading:
+                            raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
+            elif set(field_list) != set(fields).intersection(self.heading.names):
+                raise DataJointError('Attempt to insert rows with different fields')
+
+        if isinstance(row, np.void):  # np.array
+            check_fields(row.dtype.fields)
+            attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
+                          for name in self.heading if name in row.dtype.fields]
+        elif isinstance(row, collections.abc.Mapping):  # dict-based
+            check_fields(row)
+            attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
+                          for name in self.heading if name in row]
+        else:  # positional
+            try:
+                if len(row) != len(self.heading):
+                    raise DataJointError(
+                        'Invalid insert argument. Incorrect number of attributes: '
+                        '{given} given; {expected} expected'.format(
+                            given=len(row), expected=len(self.heading)))
+            except TypeError:
+                raise DataJointError('Datatype %s cannot be inserted' % type(row))
+            else:
+                attributes = [self.__make_placeholder(name, value, ignore_extra_fields)
+                              for name, value in zip(self.heading, row)]
+        if ignore_extra_fields:
+            attributes = [a for a in attributes if a is not None]
+
+        assert len(attributes), 'Empty tuple'
+        row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
+        if not field_list:
+            # first row sets the composition of the field list
+            field_list.extend(row_to_insert['names'])
+        else:
+            # reorder attributes in row_to_insert to match field_list
+            order = list(row_to_insert['names'].index(field) for field in field_list)
+            row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
+            row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
+            row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)
+        return row_to_insert
+
 
 def lookup_class_name(name, context, depth=3):
     """
diff --git a/datajoint/user_tables.py b/datajoint/user_tables.py
index 3942264b5..5613ebb3c 100644
--- a/datajoint/user_tables.py
+++ b/datajoint/user_tables.py
@@ -14,7 +14,7 @@
 supported_class_attrs = {
     'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
     'fetch', 'fetch1', 'head', 'tail',
-    'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
+    'insert', 'insert1', 'update1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
 
 
 class OrderedClass(type):
diff --git a/tests/test_attach.py b/tests/test_attach.py
index db2f2f1c7..ebe866f0c 100644
--- a/tests/test_attach.py
+++ b/tests/test_attach.py
@@ -1,5 +1,4 @@
 from nose.tools import assert_true, assert_equal, assert_not_equal
-from numpy.testing import assert_array_equal
 import tempfile
 from pathlib import Path
 import os
diff --git a/tests/test_update1.py b/tests/test_update1.py
new file mode 100644
index 000000000..0d1b333cb
--- /dev/null
+++ b/tests/test_update1.py
@@ -0,0 +1,65 @@
+from nose.tools import assert_true, assert_equal, raises
+import os
+import numpy as np
+from pathlib import Path
+import tempfile
+import datajoint as dj
+from . import PREFIX, CONN_INFO
+
+schema = dj.Schema(PREFIX + '_update1', connection=dj.conn(**CONN_INFO))
+
+dj.config['stores']['update1_store'] = dict(
+    stage=tempfile.mkdtemp(),
+    protocol='file',
+    location=tempfile.mkdtemp())
+
+scratch_folder = tempfile.mkdtemp()
+
+@schema
+class Thing(dj.Manual):
+    definition = """
+    thing : int
+    ---
+    number=0 : int
+    frac : float
+    picture = null : attach@update1_store
+    params = null : longblob
+    timestamp = CURRENT_TIMESTAMP : datetime
+    """
+
+
+def test_update1():
+    """test normal updates"""
+
+    # CHECK 1
+    key = dict(thing=1)
+    Thing.insert1(dict(key, frac=0.5))
+    check1 = Thing.fetch1()
+
+    # CHECK 2 -- some updates
+    Thing.update1(dict(key, number=3, frac=30))
+    attach_file = Path(scratch_folder, 'attach1.dat')
+    buffer1 = os.urandom(100)
+    attach_file.write_bytes(buffer1)
+    Thing.update1(dict(key, picture=attach_file))
+    Thing.update1(dict(key, timestamp="2020-01-01 10:00:00"))
+    check2 = Thing.fetch1(download_path=scratch_folder)
+    buffer2 = Path(check2['picture']).read_bytes()
+
+    # CHECK 3
+    Thing.update1(dict(key, timestamp=None, picture=None, params=np.random.randn(3, 3)))  # reset to default
+    check3 = Thing.fetch1()
+
+    assert_true(check1['number'] == 0 and check1['picture'] is None and check1['params'] is None)
+    assert_true(check2['number'] == 0 and check1['picture'] is None and check2['params'] is None)
+    assert_true(check3['timestamp'] > check2['timestamp'])
+    assert_equal(buffer1, buffer2)
+
+    print(check1, check2)
+
+
+
+
+
+

From a8a06cd9e6ab48a64cfc8f3879d0e55664383f7c Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 17 Apr 2020 00:01:53 -0500
Subject: [PATCH 2/6] fix test for update1

---
 tests/test_update1.py | 26 ++++++++++++++------------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/tests/test_update1.py b/tests/test_update1.py
index 0d1b333cb..37b60f5b0 100644
--- a/tests/test_update1.py
+++ b/tests/test_update1.py
@@ -27,7 +27,6 @@ class Thing(dj.Manual):
     timestamp = CURRENT_TIMESTAMP : datetime
     """
 
-
 def test_update1():
     """test normal updates"""
 
     # CHECK 1
     key = dict(thing=1)
     Thing.insert1(dict(key, frac=0.5))
     check1 = Thing.fetch1()
 
     # CHECK 2 -- some updates
     Thing.update1(dict(key, number=3, frac=30))
     attach_file = Path(scratch_folder, 'attach1.dat')
     buffer1 = os.urandom(100)
     attach_file.write_bytes(buffer1)
     Thing.update1(dict(key, picture=attach_file))
     Thing.update1(dict(key, timestamp="2020-01-01 10:00:00"))
     check2 = Thing.fetch1(download_path=scratch_folder)
     buffer2 = Path(check2['picture']).read_bytes()
 
     # CHECK 3
-    Thing.update1(dict(key, timestamp=None, picture=None, params=np.random.randn(3, 3)))  # reset to default
+    Thing.update1(dict(key, number=None, timestamp=None, picture=None, params=np.random.randn(3, 3)))
     check3 = Thing.fetch1()
 
-    assert_true(check1['number'] == 0 and check1['picture'] is None and check1['params'] is None)
-    assert_true(check2['number'] == 0 and check1['picture'] is None and check2['params'] is None)
-    assert_true(check3['timestamp'] > check2['timestamp'])
-    assert_equal(buffer1, buffer2)
-
-    print(check1, check2)
-
-
-
-
+    assert_true(check1['number'] == 0 and
+                check1['picture'] is None and
+                check1['params'] is None)
+    assert_true(check2['number'] == 3 and
+                check2['frac'] == 30.0 and
+                check2['picture'] is not None and
+                check2['params'] is None)
+    assert_true(check3['number'] == 0 and
+                check3['frac'] == 30.0 and
+                check3['picture'] is None and
+                isinstance(check3['params'], np.ndarray))
+    assert_true(check3['timestamp'] > check2['timestamp'])
+    assert_equal(buffer1, buffer2)

From 039801415417c5ea64414ce1c658c44b83e5c568 Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 17 Apr 2020 00:15:29 -0500
Subject: [PATCH 3/6] add more tests for update1

---
 tests/test_reconnection.py |  9 ++-------
 tests/test_update1.py      | 23 +++++++++++++++++++++--
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/tests/test_reconnection.py b/tests/test_reconnection.py
index 22ebdd4d6..4ddef39f7 100644
--- a/tests/test_reconnection.py
+++ b/tests/test_reconnection.py
@@ -2,12 +2,10 @@
 """
 Collection of test cases to test connection module.
 """
 
-from nose.tools import assert_true, assert_false, assert_equal, raises
+from nose.tools import assert_true, assert_false, raises
 import datajoint as dj
-import numpy as np
 from datajoint import DataJointError
-from . import CONN_INFO, PREFIX
-
+from . import CONN_INFO
 
 class TestReconnect:
     """
     test reconnection
     """
 
     def setup(self):
         self.conn = dj.conn(reset=True, **CONN_INFO)
 
-
     def test_close(self):
         assert_true(self.conn.is_connected, "Connection should be alive")
         self.conn.close()
         assert_false(self.conn.is_connected, "Connection should now be closed")
 
-
     def test_reconnect(self):
         assert_true(self.conn.is_connected, "Connection should be alive")
         self.conn.close()
         self.conn.query('SHOW DATABASES;', reconnect=True).fetchall()
         assert_true(self.conn.is_connected, "Connection should be alive")
 
-
     @raises(DataJointError)
     def test_reconnect_throws_error_in_transaction(self):
         assert_true(self.conn.is_connected, "Connection should be alive")
diff --git a/tests/test_update1.py b/tests/test_update1.py
index 37b60f5b0..964d4f7e2 100644
--- a/tests/test_update1.py
+++ b/tests/test_update1.py
@@ -5,6 +5,7 @@
 import tempfile
 import datajoint as dj
 from . import PREFIX, CONN_INFO
+from datajoint import DataJointError
 
 schema = dj.Schema(PREFIX + '_update1', connection=dj.conn(**CONN_INFO))
 
@@ -27,10 +28,11 @@ class Thing(dj.Manual):
     timestamp = CURRENT_TIMESTAMP : datetime
     """
 
+
 def test_update1():
     """test normal updates"""
 
-    # CHECK 1
+    # CHECK 1 -- initial insert
     key = dict(thing=1)
     Thing.insert1(dict(key, frac=0.5))
     check1 = Thing.fetch1()
@@ -45,7 +47,7 @@ def test_update1():
     check2 = Thing.fetch1(download_path=scratch_folder)
     buffer2 = Path(check2['picture']).read_bytes()
 
-    # CHECK 3
+    # CHECK 3 -- reset to default values using None
     Thing.update1(dict(key, number=None, timestamp=None, picture=None, params=np.random.randn(3, 3)))
     check3 = Thing.fetch1()
+
+
+@raises(DataJointError)
+def test_update1_nonexistent():
+    Thing.update1(dict(thing=100, frac=0.5))  # updating a non-existent entry
+
+
+@raises(DataJointError)
+def test_update1_noprimary():
+    Thing.update1(dict(number=None))  # missing primary key
+
+
+@raises(DataJointError)
+def test_update1_misspelled_attribute():
+    key = dict(thing=17)
+    Thing.insert1(dict(key, frac=1.5))
+    Thing.update1(dict(key, numer=3))  # misspelled attribute

From 19abc9f8dbd6f116fb0b3c5d506a87e170cb7e77 Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 17 Apr 2020 00:20:19 -0500
Subject: [PATCH 4/6] minor error message corrections.

---
 datajoint/table.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datajoint/table.py b/datajoint/table.py
index aaf728376..3aad7c4c0 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -180,9 +180,9 @@ def update1(self, row):
         """
         # argument validations
         if not isinstance(row, collections.abc.Mapping):
-            raise DataJointError('The argument of update must be dict-like')
+            raise DataJointError('The argument of update1 must be dict-like.')
         if not set(row).issuperset(self.primary_key):
-            raise DataJointError('The argument of update must supply all primary key values')
+            raise DataJointError('The argument of update1 must supply all primary key values.')
         try:
             raise DataJointError('Attribute `%s` not found.' % next(k for k in row if k not in self.heading.names))
         except StopIteration:

From 862f527b0d1ccf965e34696c40b2a6c2c70d6444 Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 17 Apr 2020 08:59:55 -0500
Subject: [PATCH 5/6] minor style

---
 datajoint/table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datajoint/table.py b/datajoint/table.py
index 3aad7c4c0..5942b24e8 100644
--- a/datajoint/table.py
+++ b/datajoint/table.py
@@ -258,7 +258,7 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
             self.connection.query(query)
             return
 
-        field_list = [] # collects the field list from first row (passed by reference)
+        field_list = []  # collects the field list from first row (passed by reference)
         rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
         if rows:
             try:

From 42c1836c73cb45d0db1848c2c91c36eb7aef5a31 Mon Sep 17 00:00:00 2001
From: Dimitri Yatsenko
Date: Fri, 24 Apr 2020 19:33:17 -0500
Subject: [PATCH 6/6] add a filepath test to test_update1

---
 tests/test_fetch_same.py |  1 +
 tests/test_update1.py    | 44 ++++++++++++++++++++++++++++++++--------
 2 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/tests/test_fetch_same.py b/tests/test_fetch_same.py
index 96cbfcfcb..f00624efb 100644
--- a/tests/test_fetch_same.py
+++ b/tests/test_fetch_same.py
@@ -9,6 +9,7 @@
 import datajoint as dj
 
 schema = dj.Schema(PREFIX + '_fetch_same', connection=dj.conn(**CONN_INFO))
+dj.config['enable_python_native_blobs'] = True
 
 
 @schema
diff --git a/tests/test_update1.py b/tests/test_update1.py
index 964d4f7e2..1ce4e8eda 100644
--- a/tests/test_update1.py
+++ b/tests/test_update1.py
@@ -1,4 +1,4 @@
-from nose.tools import assert_true, assert_equal, raises
+from nose.tools import assert_true, assert_false, assert_equal, raises
 import os
 import numpy as np
 from pathlib import Path
@@ -9,13 +9,20 @@
 
 schema = dj.Schema(PREFIX + '_update1', connection=dj.conn(**CONN_INFO))
 
-dj.config['stores']['update1_store'] = dict(
+dj.config['stores']['update_store'] = dict(
+    protocol='file',
+    location=tempfile.mkdtemp())
+
+dj.config['stores']['update_repo'] = dict(
     stage=tempfile.mkdtemp(),
     protocol='file',
     location=tempfile.mkdtemp())
 
+
 scratch_folder = tempfile.mkdtemp()
 
+dj.errors._switch_filepath_types(True)
+
 @schema
 class Thing(dj.Manual):
     definition = """
@@ -23,8 +30,9 @@ class Thing(dj.Manual):
     ---
     number=0 : int
     frac : float
-    picture = null : attach@update1_store
+    picture = null : attach@update_store
     params = null : longblob
+    img_file = null: filepath@update_repo
     timestamp = CURRENT_TIMESTAMP : datetime
     """
 
 
 def test_update1():
     """test normal updates"""
+    dj.errors._switch_filepath_types(True)
 
     # CHECK 1 -- initial insert
     key = dict(thing=1)
     Thing.insert1(dict(key, frac=0.5))
     check1 = Thing.fetch1()
 
     # CHECK 2 -- some updates
-    Thing.update1(dict(key, number=3, frac=30))
+    # numbers and datetimes
+    Thing.update1(dict(key, number=3, frac=30, timestamp="2020-01-01 10:00:00"))
+    # attachment
     attach_file = Path(scratch_folder, 'attach1.dat')
     buffer1 = os.urandom(100)
     attach_file.write_bytes(buffer1)
     Thing.update1(dict(key, picture=attach_file))
-    Thing.update1(dict(key, timestamp="2020-01-01 10:00:00"))
+    attach_file.unlink()
+    assert_false(attach_file.is_file())
+
+    # filepath
+    stage_path = dj.config['stores']['update_repo']['stage']
+    relpath, filename = 'one/two/three', 'picture.dat'
+    managed_file = Path(stage_path, relpath, filename)
+    managed_file.parent.mkdir(parents=True, exist_ok=True)
+    original_file_data = os.urandom(3000)
+    with managed_file.open('wb') as f:
+        f.write(original_file_data)
+    Thing.update1(dict(key, img_file=managed_file))
+    managed_file.unlink()
+    assert_false(managed_file.is_file())
+
     check2 = Thing.fetch1(download_path=scratch_folder)
-    buffer2 = Path(check2['picture']).read_bytes()
+    buffer2 = Path(check2['picture']).read_bytes()  # read attachment
+    final_file_data = managed_file.read_bytes()  # read filepath
 
     # CHECK 3 -- reset to default values using None
-    Thing.update1(dict(key, number=None, timestamp=None, picture=None, params=np.random.randn(3, 3)))
+    Thing.update1(dict(key, number=None, timestamp=None, picture=None, img_file=None, params=np.random.randn(3, 3)))
     check3 = Thing.fetch1()
 
     assert_true(check1['number'] == 0 and
                 check1['picture'] is None and
                 check1['params'] is None)
     assert_true(check2['number'] == 3 and
                 check2['frac'] == 30.0 and
                 check2['picture'] is not None and
-                check2['params'] is None)
+                check2['params'] is None and buffer1 == buffer2)
     assert_true(check3['number'] == 0 and
                 check3['frac'] == 30.0 and
                 check3['picture'] is None and
+                check3['img_file'] is None and
                 isinstance(check3['params'], np.ndarray))
     assert_true(check3['timestamp'] > check2['timestamp'])
     assert_equal(buffer1, buffer2)
+    assert_equal(original_file_data, final_file_data)
 
 
 @raises(DataJointError)
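
A minimal usage sketch of the new method, distilled from the update1 docstring and the tests in these patches. The schema name and table definition below are illustrative stand-ins, not part of the patch series:

    import datajoint as dj

    schema = dj.Schema('tutorial_update1')  # hypothetical schema name

    @schema
    class Thing(dj.Manual):
        definition = """
        thing : int
        ---
        number=0 : int
        frac : float
        """

    Thing.insert1(dict(thing=1, frac=0.5))

    # update non-key attributes of the row identified by its primary key
    Thing.update1(dict(thing=1, number=3, frac=30))

    # None resets an attribute to its declared default (number -> 0);
    # it is rendered as `number`=DEFAULT in the generated UPDATE statement
    Thing.update1(dict(thing=1, number=None))

    # each of these raises DataJointError under the validations added in PATCH 1/6:
    # Thing.update1(dict(thing=100, frac=0.5))  # entry does not exist
    # Thing.update1(dict(number=3))             # primary key not supplied
    # Thing.update1(dict(thing=1, numer=3))     # misspelled attribute

By design, update1 sidesteps DataJoint's delete-and-insert integrity model, so it requires that exactly one matching row already exist and refuses to run on a restricted table.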