diff --git a/cap/modules/deposit/api.py b/cap/modules/deposit/api.py index 44d065d7f5..4bf09f5275 100644 --- a/cap/modules/deposit/api.py +++ b/cap/modules/deposit/api.py @@ -233,6 +233,17 @@ def _publish_edited(self): record = super(CAPDeposit, self)._publish_edited() record._add_deposit_permissions(record, record.id) + with db.session.begin_nested(): + if self.files and record.files.bucket is not None: + # Unlock the record bucket + record.files.bucket.locked = False + # Lock the deposit's files bucket + self.files.bucket.locked = True + # Update the record files with the deposit files + update_record_files(self.files, record.files) + # lock the record bucket after update + record.files.bucket.locked = True + if record["_experiment"]: record._add_experiment_permissions(record, record.id) @@ -862,6 +873,23 @@ def check_data( return data +def update_record_files(deposit_files, record_files): + for key in deposit_files.keys: + deposit_file = deposit_files.__getitem__(key) + deposit_version_id = deposit_file.get_version() + + try: + record_file = record_files.__getitem__(key) + if record_file.get('version_id', None) == deposit_version_id: + continue + except KeyError: + record_file = None + + with deposit_file.obj.file.storage().open() as file_stream: + record_files[key] = file_stream + record_files.flush() + + def has_changed(error, current, new): error_path = get_error_path(error) current_version = get_val_from_path(current, error_path) or None diff --git a/tests/integration/deposits/test_edit_published_deposit.py b/tests/integration/deposits/test_edit_published_deposit.py index bb140e116f..7297f12963 100644 --- a/tests/integration/deposits/test_edit_published_deposit.py +++ b/tests/integration/deposits/test_edit_published_deposit.py @@ -24,6 +24,8 @@ # or submit itself to any jurisdiction. """Integration tests for record edit.""" +from io import BytesIO + ########################################### # api/deposits/{pid}/actions/edit [POST] @@ -130,3 +132,26 @@ def test_edit_record(client, create_deposit, users, auth_headers_for_superuser): .format(depid) } } + + +def test_edit_record_with_uploading_new_files(client, users, auth_headers_for_user, create_deposit): + owner = users['cms_user'] + deposit = create_deposit(owner, 'test-analysis-v0.0.1') + deposit.files['file_1.txt'] = BytesIO(b'Hello world!') + pid = deposit['_deposit']['id'] + + client.post('/deposits/{}/actions/publish'.format(pid), + headers=auth_headers_for_user(owner)) + + client.post('/deposits/{}/actions/edit'.format(pid), + headers=auth_headers_for_user(owner)) + + bucket = deposit.files.bucket + client.put('/files/{}/file_2.txt'.format(bucket), + input_stream=BytesIO(b'Hello brave new world!'), + headers=auth_headers_for_user(owner)) + + resp = client.post('/deposits/{}/actions/publish'.format(pid), + headers=auth_headers_for_user(owner)) + + assert len(resp.json['files']) == 2