Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the node repository API backend agnostic #2506

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions .ci/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,24 @@ def validate_cached(cached_calcs):
valid = False

if isinstance(calc, CalcJobNode):
if 'raw_input' not in calc.repository._get_folder_pathsubfolder.get_content_list():
print("Cached calculation <{}> does not have a 'raw_input' folder".format(calc.pk))
original_calc = load_node(calc.get_extra('_aiida_cached_from'))
files_original = original_calc.list_object_names()
files_cached = calc.list_object_names()

if not files_cached:
print("Cached calculation <{}> does not have any raw inputs files".format(calc.pk))
print_report(calc.pk)
valid = False
original_calc = load_node(calc.get_extra('_aiida_cached_from'))
if 'raw_input' not in original_calc.repository._get_folder_pathsubfolder.get_content_list():
print("Original calculation <{}> does not have a 'raw_input' folder after being cached from."
if not files_original:
print("Original calculation <{}> does not have any raw inputs files after being cached from."
.format(original_calc.pk))
valid = False

if set(files_original) != set(files_cached):
print("different raw input files [{}] vs [{}] for original<{}> and cached<{}> calculation".format(
set(files_original), set(files_cached), original_calc.pk, calc.pk))
valid = False

return valid


Expand Down
12 changes: 0 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,21 +146,9 @@
aiida/orm/nodes/data/array/kpoints.py|
aiida/orm/nodes/data/array/projection.py|
aiida/orm/nodes/data/array/xy.py|
aiida/orm/nodes/data/base.py|
aiida/orm/nodes/data/bool.py|
aiida/orm/nodes/data/error.py|
aiida/orm/nodes/data/float.py|
aiida/orm/nodes/data/folder.py|
aiida/orm/nodes/data/frozendict.py|
aiida/orm/nodes/data/int.py|
aiida/orm/nodes/data/list.py|
aiida/orm/nodes/data/code.py|
aiida/orm/nodes/data/numeric.py|
aiida/orm/nodes/data/orbital.py|
aiida/orm/nodes/data/parameter.py|
aiida/orm/nodes/data/remote.py|
aiida/orm/nodes/data/singlefile.py|
aiida/orm/nodes/data/str.py|
aiida/orm/nodes/data/structure.py|
aiida/orm/nodes/data/upf.py|
aiida/orm/nodes/process/calculation/calcjob.py|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def delete_trajectory_symbols_array(apps, _):
for t_pk in trajectories_pk:
trajectory = load_node(t_pk)
modifier.del_value_for_node(DbNode.objects.get(pk=trajectory.pk), 'array|symbols')
# Remove the .npy file (using delete_array raises ModificationNotAllowed error)
trajectory.repository._get_folder_pathsubfolder.remove_path('symbols.npy') # pylint: disable=protected-access
trajectory.delete_object('symbols.npy', force=True)


def create_trajectory_symbols_array(apps, _):
Expand All @@ -63,10 +62,11 @@ def create_trajectory_symbols_array(apps, _):
trajectory = load_node(t_pk)
symbols = numpy.array(trajectory.get_attribute('symbols'))
# Save the .npy file (using set_array raises ModificationNotAllowed error)
with tempfile.NamedTemporaryFile() as _file:
numpy.save(_file, symbols)
_file.flush()
trajectory.repository._get_folder_pathsubfolder.insert_path(_file.name, 'symbols.npy') # pylint: disable=protected-access
with tempfile.NamedTemporaryFile() as handle:
numpy.save(handle, symbols)
handle.flush()
handle.seek(0)
trajectory.put_object_from_filelike(handle, 'symbols.npy')
modifier.set_value_for_node(DbNode.objects.get(pk=trajectory.pk), 'array|symbols', list(symbols.shape))


Expand Down
10 changes: 5 additions & 5 deletions aiida/backends/djsite/db/subtests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def setUpBeforeMigration(self):
self.n_int_duplicates = 4

node_bool = Bool(True)
node_bool.repository.add_path(handle.name, self.file_name)
node_bool.put_object_from_file(handle.name, self.file_name)
node_bool.store()

node_int = Int(1)
node_int.repository.add_path(handle.name, self.file_name)
node_int.put_object_from_file(handle.name, self.file_name)
node_int.store()

self.nodes_boolean.append(node_bool)
Expand All @@ -125,14 +125,14 @@ def setUpBeforeMigration(self):
for i in range(self.n_bool_duplicates):
node = Bool(True)
node.backend_entity.dbmodel.uuid = node_bool.uuid
node.repository.add_path(handle.name, self.file_name)
node.put_object_from_file(handle.name, self.file_name)
node.store()
self.nodes_boolean.append(node)

for i in range(self.n_int_duplicates):
node = Int(1)
node.backend_entity.dbmodel.uuid = node_int.uuid
node.repository.add_path(handle.name, self.file_name)
node.put_object_from_file(handle.name, self.file_name)
node.store()
self.nodes_integer.append(node)

Expand Down Expand Up @@ -161,7 +161,7 @@ def test_deduplicated_uuids(self):
self.assertEqual(len(set(uuids_integer)), len(nodes_integer))

for node in nodes_boolean:
with node.repository._get_folder_pathsubfolder.open(self.file_name) as handle:
with node.open(self.file_name) as handle:
content = handle.read()
self.assertEqual(content, self.file_content)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def upgrade():
for t in trajectories:
del t.attributes['array|symbols']
flag_modified(t, 'attributes')
# Remove the .npy file (using delete_array raises ModificationNotAllowed error)
load_node(pk=t.id).repository._get_folder_pathsubfolder.remove_path('symbols.npy') # pylint: disable=protected-access
load_node(pk=t.id).delete_object('symbols.npy', force=True)
session.add(t)
session.commit()
session.close()
Expand All @@ -52,10 +51,11 @@ def downgrade():
for t in trajectories:
symbols = numpy.array(t.get_attribute('symbols'))
# Save the .npy file (using set_array raises ModificationNotAllowed error)
with tempfile.NamedTemporaryFile() as _file:
numpy.save(_file, symbols)
_file.flush()
load_node(pk=t.id).repository.insert_path(_file.name, 'symbols.npy')
with tempfile.NamedTemporaryFile() as handle:
numpy.save(handle, symbols)
handle.flush()
handle.seek(0)
load_node(pk=t.id).put_object_from_filelike(handle, 'symbols.npy')
t.attributes['array|symbols'] = list(symbols.shape)
flag_modified(t, 'attributes')
session.add(t)
Expand Down
1 change: 1 addition & 0 deletions aiida/backends/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
'orm.utils.calcjob': ['aiida.backends.tests.orm.utils.test_calcjob'],
'orm.utils.node': ['aiida.backends.tests.orm.utils.test_node'],
'orm.utils.loaders': ['aiida.backends.tests.orm.utils.test_loaders'],
'orm.utils.repository': ['aiida.backends.tests.orm.utils.test_repository'],
'work.calcfunctions': ['aiida.backends.tests.work.test_calcfunctions'],
'work.class_loader': ['aiida.backends.tests.work.test_class_loader'],
'work.daemon': ['aiida.backends.tests.work.test_daemon'],
Expand Down
2 changes: 1 addition & 1 deletion aiida/backends/tests/cmdline/commands/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def create_cif_data(cls):
filename = fhandle.name
fhandle.write(cls.valid_sample_cif_str)
fhandle.flush()
a = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'})
a = CifData(filepath=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'})
a.store()

g_ne = Group(label='non_empty_group')
Expand Down
2 changes: 1 addition & 1 deletion aiida/backends/tests/cmdline/commands/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ def create_cif_data(cls):
filename = fhandle.name
fhandle.write(cls.valid_sample_cif_str)
fhandle.flush()
a = CifData(file=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'})
a = CifData(filepath=filename, source={'version': '1234', 'db_name': 'COD', 'id': '0000001'})
a.store()

g_ne = orm.Group(label='non_empty_group')
Expand Down
2 changes: 1 addition & 1 deletion aiida/backends/tests/cmdline/commands/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ def wf():
# Verify that the node has the correct function name and content
self.assertTrue(isinstance(node, WorkFunctionNode))
self.assertEqual(node.function_name, 'wf')
self.assertEqual(open(node.function_source_file, 'r').read(), script_content)
self.assertEqual(node.get_function_source_code(), script_content)
63 changes: 56 additions & 7 deletions aiida/backends/tests/common/test_folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,18 @@
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import sys

import io
import os
import sys
import shutil
import tempfile
import unittest

import six

from aiida.common.folders import Folder


def fs_encoding_is_utf8():
"""
Expand All @@ -39,10 +47,6 @@ def test_unicode(cls):
Check that there are no exceptions raised when
using unicode folders.
"""
from aiida.common.folders import Folder
import os
import tempfile

tmpsource = tempfile.mkdtemp()
tmpdest = tempfile.mkdtemp()
with io.open(os.path.join(tmpsource, "sąžininga"), 'w', encoding='utf8') as fhandle:
Expand All @@ -61,8 +65,53 @@ def test_get_abs_path_without_limit(self):
"""
Check that the absolute path function can get an absolute path
"""
from aiida.common.folders import Folder

folder = Folder('/tmp')
# Should not raise any exception
self.assertEqual(folder.get_abs_path('test_file.txt'), '/tmp/test_file.txt')

@staticmethod
@unittest.skipUnless(six.PY2, 'test is only for python2')
def test_create_file_from_filelike_py2():
"""Test `aiida.common.folders.Folder.create_file_from_filelike` for python 2."""
unicode_string = u'unicode_string'
byte_string = 'byte_string'

try:
tempdir = tempfile.mkdtemp()
folder = Folder(tempdir)

# Passing a stream with matching file mode should work ofcourse
folder.create_file_from_filelike(six.StringIO(unicode_string), 'random.dat', mode='w', encoding='utf-8')
folder.create_file_from_filelike(six.StringIO(byte_string), 'random.dat', mode='wb', encoding=None)

# For python 2 the `create_file_from_filelike` should be able to deal with incoherent arguments, such as
# the examples below where a unicode string is passed with a binary mode, or a byte stream in unicode mode.
folder.create_file_from_filelike(six.StringIO(unicode_string), 'random.dat', mode='wb', encoding=None)
folder.create_file_from_filelike(six.StringIO(byte_string), 'random.dat', mode='w', encoding='utf-8')

finally:
shutil.rmtree(tempdir)

@unittest.skipUnless(six.PY3, 'test is only for python3')
def test_create_file_from_filelike_py3(self):
"""Test `aiida.common.folders.Folder.create_file_from_filelike` for python 3."""
unicode_string = 'unicode_string'
byte_string = b'byte_string'

try:
tempdir = tempfile.mkdtemp()
folder = Folder(tempdir)

folder.create_file_from_filelike(six.StringIO(unicode_string), 'random.dat', mode='w', encoding='utf-8')
folder.create_file_from_filelike(six.BytesIO(byte_string), 'random.dat', mode='wb', encoding=None)

# For python three we make no exceptions, if you pass a unicode stream with binary mode, one should expect
# a TypeError. Same for the inverse case of wanting to write in unicode mode but passing a byte stream
with self.assertRaises(TypeError):
folder.create_file_from_filelike(six.StringIO(unicode_string), 'random.dat', mode='wb')

with self.assertRaises(TypeError):
folder.create_file_from_filelike(six.BytesIO(byte_string), 'random.dat', mode='w')

finally:
shutil.rmtree(tempdir)
2 changes: 1 addition & 1 deletion aiida/backends/tests/orm/node/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def setUp(self):
def test_repository_garbage_collection(self):
"""Verify that the repository sandbox folder is cleaned after the node instance is garbage collected."""
node = Data()
dirpath = node.repository.folder.abspath
dirpath = node._repository._get_temp_folder().abspath # pylint: disable=protected-access

self.assertTrue(os.path.isdir(dirpath))
del node
Expand Down
116 changes: 116 additions & 0 deletions aiida/backends/tests/orm/utils/test_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the `Repository` utility class."""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import io
import os
import shutil
import tempfile

from aiida.backends.testbase import AiidaTestCase
from aiida.orm import Node


class TestRepository(AiidaTestCase):
"""Tests for the node `Repository` utility class."""

def setUp(self):
"""Create a dummy file tree."""
self.tempdir = tempfile.mkdtemp()
self.tree = {
'subdir': {
'a.txt': u'Content of file A\nWith some newlines',
'b.txt': u'Content of file B without newline',
},
'c.txt': u'Content of file C\n',
}

self.create_file_tree(self.tempdir, self.tree)

def tearDown(self):
shutil.rmtree(self.tempdir)

def create_file_tree(self, directory, tree):
"""Create a file tree in the given directory.

:param directory: the absolute path of the directory into which to create the tree
:param tree: a dictionary representing the tree structure
"""
for key, value in tree.items():
if isinstance(value, dict):
subdir = os.path.join(directory, key)
os.makedirs(subdir)
self.create_file_tree(subdir, value)
else:
with io.open(os.path.join(directory, key), 'w', encoding='utf8') as handle:
handle.write(value)

def get_file_content(self, key):
"""Get the content of a file for a given key.

:param key: the nested key of the file to retrieve
:return: the content of the file
"""
parts = key.split(os.sep)
content = self.tree
for part in parts:
content = content[part]

return content

def test_put_object_from_filelike(self):
"""Test the `put_object_from_filelike` method."""
key = os.path.join('subdir', 'a.txt')
filepath = os.path.join(self.tempdir, key)
content = self.get_file_content(key)

with io.open(filepath, 'r') as handle:
node = Node()
node.put_object_from_filelike(handle, key)
self.assertEqual(node.get_object_content(key), content)

def test_put_object_from_file(self):
"""Test the `put_object_from_file` method."""
key = os.path.join('subdir', 'a.txt')
filepath = os.path.join(self.tempdir, key)
content = self.get_file_content(key)

node = Node()
node.put_object_from_file(filepath, key)
self.assertEqual(node.get_object_content(key), content)

def test_put_object_from_tree(self):
"""Test the `put_object_from_tree` method."""
basepath = ''
node = Node()
node.put_object_from_tree(self.tempdir, basepath)

key = os.path.join('subdir', 'a.txt')
content = self.get_file_content(key)
self.assertEqual(node.get_object_content(key), content)

basepath = 'base'
node = Node()
node.put_object_from_tree(self.tempdir, basepath)

key = os.path.join(basepath, 'subdir', 'a.txt')
content = self.get_file_content(os.path.join('subdir', 'a.txt'))
self.assertEqual(node.get_object_content(key), content)

basepath = 'base/further/nested'
node = Node()
node.put_object_from_tree(self.tempdir, basepath)

key = os.path.join(basepath, 'subdir', 'a.txt')
content = self.get_file_content(os.path.join('subdir', 'a.txt'))
self.assertEqual(node.get_object_content(key), content)
Loading