Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add method update1 to dj.Table #763

Merged
merged 8 commits into from
May 8, 2020
246 changes: 138 additions & 108 deletions datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,40 @@ def _log(self):
def external(self):
return self.connection.schemas[self.database].external

def update1(self, row):
"""
Update an existing entry in the table.
Caution: Updates are not part of the DataJoint data manipulation model. For strict data integrity,
use delete and insert.
:param row: a dict containing the primary key and the attributes to update.
Setting an attribute value to None will reset it to the default value (if any)
The primary key attributes must always be provided.
Examples:
>>> table.update1({'id': 1, 'value': 3}) # update value in record with id=1
>>> table.update1({'id': 1, 'value': None}) # reset value to default
"""
# argument validations
if not isinstance(row, collections.abc.Mapping):
raise DataJointError('The argument of update1 must be dict-like.')
if not set(row).issuperset(self.primary_key):
raise DataJointError('The argument of update1 must supply all primary key values.')
try:
raise DataJointError('Attribute `%s` not found.' % next(k for k in row if k not in self.heading.names))
except StopIteration:
pass # ok
if len(self.restriction):
raise DataJointError('Update cannot be applied to a restricted table.')
key = {k: row[k] for k in self.primary_key}
if len(self & key) != 1:
raise DataJointError('Update entry must exist.')
# UPDATE query
row = [self.__make_placeholder(k, v) for k, v in row.items() if k not in self.primary_key]
query = "UPDATE {table} SET {assignments} WHERE {where}".format(
table=self.full_table_name,
assignments=",".join('`%s`=%s' % r[:2] for r in row),
where=self._make_condition(key))
self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))

def insert1(self, row, **kwargs):
"""
Insert one data record or one Mapping (like a dict).
Expand Down Expand Up @@ -201,7 +235,6 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
'Inserts into an auto-populated table can only done inside its make method during a populate call.'
' To override, set keyword argument allow_direct_insert=True.')

heading = self.heading
if inspect.isclass(rows) and issubclass(rows, QueryExpression): # instantiate if a class
rows = rows()
if isinstance(rows, QueryExpression):
Expand All @@ -210,10 +243,10 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
try:
raise DataJointError(
"Attribute %s not found. To ignore extra attributes in insert, set ignore_extra_fields=True." %
next(name for name in rows.heading if name not in heading))
next(name for name in rows.heading if name not in self.heading))
except StopIteration:
pass
fields = list(name for name in rows.heading if name in heading)
fields = list(name for name in rows.heading if name in self.heading)
query = '{command} INTO {table} ({fields}) {select}{duplicate}'.format(
command='REPLACE' if replace else 'INSERT',
fields='`' + '`,`'.join(fields) + '`',
Expand All @@ -225,111 +258,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
self.connection.query(query)
return

if heading.attributes is None:
logger.warning('Could not access table {table}'.format(table=self.full_table_name))
return

field_list = None # ensures that all rows have the same attributes in the same order as the first row.

def make_row_to_insert(row):
"""
:param row: A tuple to insert
:return: a dict with fields 'names', 'placeholders', 'values'
"""

def make_placeholder(name, value):
"""
For a given attribute `name` with `value`, return its processed value or value placeholder
as a string to be included in the query and the value, if any, to be submitted for
processing by mysql API.
:param name: name of attribute to be inserted
:param value: value of attribute to be inserted
"""
if ignore_extra_fields and name not in heading:
return None
attr = heading[name]
if attr.adapter:
value = attr.adapter.put(value)
if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
# set default value
placeholder, value = 'DEFAULT', None
else: # not NULL
placeholder = '%s'
if attr.uuid:
if not isinstance(value, uuid.UUID):
try:
value = uuid.UUID(value)
except (AttributeError, ValueError):
raise DataJointError(
'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None
value = value.bytes
elif attr.is_blob:
value = blob.pack(value)
value = self.external[attr.store].put(value).bytes if attr.is_external else value
elif attr.is_attachment:
attachment_path = Path(value)
if attr.is_external:
# value is hash of contents
value = self.external[attr.store].upload_attachment(attachment_path).bytes
else:
# value is filename + contents
value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
elif attr.is_filepath:
value = self.external[attr.store].upload_filepath(value).bytes
elif attr.numeric:
value = str(int(value) if isinstance(value, bool) else value)
return name, placeholder, value

def check_fields(fields):
"""
Validates that all items in `fields` are valid attributes in the heading
:param fields: field names of a tuple
"""
if field_list is None:
if not ignore_extra_fields:
for field in fields:
if field not in heading:
raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
elif set(field_list) != set(fields).intersection(heading.names):
raise DataJointError('Attempt to insert rows with different fields')

if isinstance(row, np.void): # np.array
check_fields(row.dtype.fields)
attributes = [make_placeholder(name, row[name])
for name in heading if name in row.dtype.fields]
elif isinstance(row, collections.abc.Mapping): # dict-based
check_fields(row)
attributes = [make_placeholder(name, row[name]) for name in heading if name in row]
else: # positional
try:
if len(row) != len(heading):
raise DataJointError(
'Invalid insert argument. Incorrect number of attributes: '
'{given} given; {expected} expected'.format(
given=len(row), expected=len(heading)))
except TypeError:
raise DataJointError('Datatype %s cannot be inserted' % type(row))
else:
attributes = [make_placeholder(name, value) for name, value in zip(heading, row)]
if ignore_extra_fields:
attributes = [a for a in attributes if a is not None]

assert len(attributes), 'Empty tuple'
row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
nonlocal field_list
if field_list is None:
# first row sets the composition of the field list
field_list = row_to_insert['names']
else:
# reorder attributes in row_to_insert to match field_list
order = list(row_to_insert['names'].index(field) for field in field_list)
row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)

return row_to_insert

rows = list(make_row_to_insert(row) for row in rows)
field_list = [] # collects the field list from first row (passed by reference)
rows = list(self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows)
if rows:
try:
query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
Expand Down Expand Up @@ -598,6 +528,106 @@ def _update(self, attrname, value=None):
where_clause=self.where_clause)
self.connection.query(command, args=(value, ) if value is not None else ())

# --- private helper functions ----
def __make_placeholder(self, name, value, ignore_extra_fields=False):
"""
For a given attribute `name` with `value`, return its processed value or value placeholder
as a string to be included in the query and the value, if any, to be submitted for
processing by mysql API.
:param name: name of attribute to be inserted
:param value: value of attribute to be inserted
"""
if ignore_extra_fields and name not in self.heading:
return None
attr = self.heading[name]
if attr.adapter:
value = attr.adapter.put(value)
if value is None or (attr.numeric and (value == '' or np.isnan(np.float(value)))):
# set default value
placeholder, value = 'DEFAULT', None
else: # not NULL
placeholder = '%s'
if attr.uuid:
if not isinstance(value, uuid.UUID):
try:
value = uuid.UUID(value)
except (AttributeError, ValueError):
raise DataJointError(
'badly formed UUID value {v} for attribute `{n}`'.format(v=value, n=name)) from None
value = value.bytes
elif attr.is_blob:
value = blob.pack(value)
value = self.external[attr.store].put(value).bytes if attr.is_external else value
elif attr.is_attachment:
attachment_path = Path(value)
if attr.is_external:
# value is hash of contents
value = self.external[attr.store].upload_attachment(attachment_path).bytes
else:
# value is filename + contents
value = str.encode(attachment_path.name) + b'\0' + attachment_path.read_bytes()
elif attr.is_filepath:
value = self.external[attr.store].upload_filepath(value).bytes
elif attr.numeric:
value = str(int(value) if isinstance(value, bool) else value)
return name, placeholder, value

def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
"""
Helper function for insert and update
:param row: A tuple to insert
:return: a dict with fields 'names', 'placeholders', 'values'
"""

def check_fields(fields):
"""
Validates that all items in `fields` are valid attributes in the heading
:param fields: field names of a tuple
"""
if not field_list:
if not ignore_extra_fields:
for field in fields:
if field not in self.heading:
raise KeyError(u'`{0:s}` is not in the table heading'.format(field))
elif set(field_list) != set(fields).intersection(self.heading.names):
raise DataJointError('Attempt to insert rows with different fields')

if isinstance(row, np.void): # np.array
check_fields(row.dtype.fields)
attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
for name in self.heading if name in row.dtype.fields]
elif isinstance(row, collections.abc.Mapping): # dict-based
check_fields(row)
attributes = [self.__make_placeholder(name, row[name], ignore_extra_fields)
for name in self.heading if name in row]
else: # positional
try:
if len(row) != len(self.heading):
raise DataJointError(
'Invalid insert argument. Incorrect number of attributes: '
'{given} given; {expected} expected'.format(
given=len(row), expected=len(self.heading)))
except TypeError:
raise DataJointError('Datatype %s cannot be inserted' % type(row))
else:
attributes = [self.__make_placeholder(name, value, ignore_extra_fields)
for name, value in zip(self.heading, row)]
if ignore_extra_fields:
attributes = [a for a in attributes if a is not None]

assert len(attributes), 'Empty tuple'
row_to_insert = dict(zip(('names', 'placeholders', 'values'), zip(*attributes)))
if not field_list:
# first row sets the composition of the field list
field_list.extend(row_to_insert['names'])
else:
# reorder attributes in row_to_insert to match field_list
order = list(row_to_insert['names'].index(field) for field in field_list)
row_to_insert['names'] = list(row_to_insert['names'][i] for i in order)
row_to_insert['placeholders'] = list(row_to_insert['placeholders'][i] for i in order)
row_to_insert['values'] = list(row_to_insert['values'][i] for i in order)
return row_to_insert


def lookup_class_name(name, context, depth=3):
"""
Expand Down
2 changes: 1 addition & 1 deletion datajoint/user_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
supported_class_attrs = {
'key_source', 'describe', 'alter', 'heading', 'populate', 'progress', 'primary_key', 'proj', 'aggr',
'fetch', 'fetch1', 'head', 'tail',
'insert', 'insert1', 'drop', 'drop_quick', 'delete', 'delete_quick'}
'insert', 'insert1', 'update1', 'drop', 'drop_quick', 'delete', 'delete_quick'}


class OrderedClass(type):
Expand Down
1 change: 0 additions & 1 deletion tests/test_attach.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from nose.tools import assert_true, assert_equal, assert_not_equal
from numpy.testing import assert_array_equal
import tempfile
from pathlib import Path
import os
Expand Down
1 change: 1 addition & 0 deletions tests/test_fetch_same.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import datajoint as dj

schema = dj.Schema(PREFIX + '_fetch_same', connection=dj.conn(**CONN_INFO))
dj.config['enable_python_native_blobs'] = True


@schema
Expand Down
9 changes: 2 additions & 7 deletions tests/test_reconnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
Collection of test cases to test connection module.
"""

from nose.tools import assert_true, assert_false, assert_equal, raises
from nose.tools import assert_true, assert_false, raises
import datajoint as dj
import numpy as np
from datajoint import DataJointError
from . import CONN_INFO, PREFIX

from . import CONN_INFO


class TestReconnect:
Expand All @@ -18,20 +16,17 @@ class TestReconnect:
def setup(self):
self.conn = dj.conn(reset=True, **CONN_INFO)


def test_close(self):
assert_true(self.conn.is_connected, "Connection should be alive")
self.conn.close()
assert_false(self.conn.is_connected, "Connection should now be closed")


def test_reconnect(self):
assert_true(self.conn.is_connected, "Connection should be alive")
self.conn.close()
self.conn.query('SHOW DATABASES;', reconnect=True).fetchall()
assert_true(self.conn.is_connected, "Connection should be alive")


@raises(DataJointError)
def test_reconnect_throws_error_in_transaction(self):
assert_true(self.conn.is_connected, "Connection should be alive")
Expand Down
Loading