Skip to content

Commit

Permalink
Merge pull request #776 from ixcat/issue-666
Browse files Browse the repository at this point in the history
datajoint/table.py: smarter dataframe conversion (#666)
  • Loading branch information
eywalker authored May 15, 2020
2 parents d54c250 + 18beaf2 commit 518b882
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
6 changes: 5 additions & 1 deletion datajoint/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,11 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
"""

if isinstance(rows, pandas.DataFrame):
rows = rows.to_records()
# drop 'extra' synthetic index for 1-field index case -
# frames with more advanced indices should be prepared by user.
rows = rows.reset_index(
drop=len(rows.index.names) == 1 and not rows.index.names[0]
).to_records(index=False)

# prohibit direct inserts into auto-populated tables
if not allow_direct_insert and not getattr(self, '_allow_insert', True): # allow_insert is only used in AutoPopulate
Expand Down
17 changes: 16 additions & 1 deletion tests/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def test_insert_select(self):
'real_id', 'date_of_birth', 'subject_notes', subject_id='subject_id+1000', species='"human"'))
assert_equal(len(self.subject), 2*original_length)

def test_insert_pandas(self):
def test_insert_pandas_roundtrip(self):
''' ensure fetched frames can be inserted '''
schema.TTest2.delete()
n = len(schema.TTest())
assert_true(n > 0)
Expand All @@ -113,6 +114,20 @@ def test_insert_pandas(self):
schema.TTest2.insert(df)
assert_equal(len(schema.TTest2()), n)

def test_insert_pandas_userframe(self):
'''
ensure simple user-created frames (1 field, non-custom index)
can be inserted without extra index adjustment
'''
schema.TTest2.delete()
n = len(schema.TTest())
assert_true(n > 0)
df = pandas.DataFrame(schema.TTest.fetch())
assert_true(isinstance(df, pandas.DataFrame))
assert_equal(len(df), n)
schema.TTest2.insert(df)
assert_equal(len(schema.TTest2()), n)

@raises(dj.DataJointError)
def test_insert_select_ignore_extra_fields0(self):
""" need ignore extra fields for insert select """
Expand Down

0 comments on commit 518b882

Please sign in to comment.