diff --git a/hoomd/GSDDequeWriter.cc b/hoomd/GSDDequeWriter.cc index dae062c87d..04f76892d1 100644 --- a/hoomd/GSDDequeWriter.cc +++ b/hoomd/GSDDequeWriter.cc @@ -14,11 +14,13 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr sysdef, int queue_size, std::string mode, bool write_at_init, + bool clear_whole_buffer_after_dump, uint64_t timestep) : GSDDumpWriter(sysdef, trigger, fname, group, mode), m_queue_size(queue_size) { setLogWriter(logger); bool file_empty = true; + m_clear_whole_buffer_after_dump = true; #ifdef ENABLE_MPI if (m_sysdef->isDomainDecomposed()) { @@ -42,9 +44,10 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr sysdef, else { analyze(timestep); - dump(); + dump(0, -1); } } + setClearWholeBufferAfterDump(clear_whole_buffer_after_dump); } void GSDDequeWriter::analyze(uint64_t timestep) @@ -59,14 +62,42 @@ void GSDDequeWriter::analyze(uint64_t timestep) } } -void GSDDequeWriter::dump() +void GSDDequeWriter::dump(long int start, long int end) { - for (auto i {static_cast(m_frame_queue.size()) - 1}; i >= 0; --i) + auto buffer_length = static_cast(m_frame_queue.size()); + if (end > buffer_length) + { + throw std::runtime_error("Burst.dump's end index is out of range."); + } + if (start < 0 || start > buffer_length) + { + throw std::runtime_error("Burst.dump's start index is out of range."); + } + long int iterator_start, iterator_end; + if (end < 0) + { + iterator_end = buffer_length - start; + iterator_start = 0; + } + else + { + iterator_end = buffer_length - start; + iterator_start = buffer_length - end; + } + for (auto i = iterator_end - 1; i >= iterator_start; --i) { write(m_frame_queue[i], m_log_queue[i]); } - m_frame_queue.clear(); - m_log_queue.clear(); + if (m_clear_whole_buffer_after_dump) + { + m_frame_queue.clear(); + m_log_queue.clear(); + } + else + { + m_frame_queue.erase(m_frame_queue.begin() + iterator_start, m_frame_queue.end()); + m_log_queue.erase(m_log_queue.begin() + iterator_start, m_log_queue.end()); + } } int GSDDequeWriter::getMaxQueueSize() const @@ -93,6 +124,16 @@ void GSDDequeWriter::setMaxQueueSize(int new_max_size) } } +bool GSDDequeWriter::getClearWholeBufferAfterDump() const + { + return m_clear_whole_buffer_after_dump; + } + +void GSDDequeWriter::setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump) + { + m_clear_whole_buffer_after_dump = clear_whole_buffer_after_dump; + } + namespace detail { void export_GSDDequeWriter(pybind11::module& m) @@ -108,10 +149,14 @@ void export_GSDDequeWriter(pybind11::module& m) int, std::string, bool, + bool, uint64_t>()) .def_property("max_burst_size", &GSDDequeWriter::getMaxQueueSize, &GSDDequeWriter::setMaxQueueSize) + .def_property("clear_whole_buffer_after_dump", + &GSDDequeWriter::getClearWholeBufferAfterDump, + &GSDDequeWriter::setClearWholeBufferAfterDump) .def("__len__", &GSDDequeWriter::getCurrentQueueSize) .def("dump", &GSDDequeWriter::dump); } diff --git a/hoomd/GSDDequeWriter.h b/hoomd/GSDDequeWriter.h index 5221eefdde..d22432bf2d 100644 --- a/hoomd/GSDDequeWriter.h +++ b/hoomd/GSDDequeWriter.h @@ -26,20 +26,24 @@ class PYBIND11_EXPORT GSDDequeWriter : public GSDDumpWriter int queue_size, std::string mode, bool write_on_init, + bool clear_whole_buffer_after_dump, uint64_t timestep); ~GSDDequeWriter() = default; void analyze(uint64_t timestep) override; - void dump(); + void dump(long int start, long int end); int getMaxQueueSize() const; void setMaxQueueSize(int new_max_size); + bool getClearWholeBufferAfterDump() const; + void setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump); size_t getCurrentQueueSize() const; protected: int m_queue_size; + bool m_clear_whole_buffer_after_dump; std::deque m_frame_queue; std::deque m_log_queue; }; diff --git a/hoomd/md/pytest/test_burst_writer.py b/hoomd/md/pytest/test_burst_writer.py index e9163985ab..ae0f2de699 100644 --- a/hoomd/md/pytest/test_burst_writer.py +++ b/hoomd/md/pytest/test_burst_writer.py @@ -145,7 +145,9 @@ def test_len(sim, tmp_path): assert len(burst_writer) == 0 -def test_burst_dump(sim, tmp_path): +@pytest.mark.parametrize("start, end", [(0, -1), (0, 0), (0, 1), (0, 2), (1, 1), + (2, 2), (1, 2), (1, -1), (2, -1)]) +def test_burst_dump(sim, tmp_path, start, end): filename = tmp_path / "temporary_test_file.gsd" burst_trigger = hoomd.trigger.Periodic(period=2, phase=1) @@ -164,11 +166,59 @@ def test_burst_dump(sim, tmp_path): # First frame is always written assert len(traj) == 1 + burst_writer.dump(start=start, end=end) + burst_writer.flush() + dumped_frames = [3, 5, 7] + if sim.device.communicator.rank == 0: + if end == -1: + end = len(dumped_frames) + with gsd.hoomd.open(name=filename, mode='r') as traj: + assert [frame.configuration.step for frame in traj + ] == [0] + dumped_frames[start:end] + + +@pytest.mark.parametrize("clear_entire_buffer", [True, False]) +def test_burst_dump_with_clear_buffer(sim, tmp_path, clear_entire_buffer): + filename = tmp_path / "temporary_test_file.gsd" + start_frame = 1 + end_frame = 3 + burst_trigger = hoomd.trigger.Periodic(period=2, phase=1) + burst_writer = hoomd.write.Burst( + trigger=burst_trigger, + filename=filename, + mode='wb', + dynamic=['property', 'momentum'], + max_burst_size=4, + write_at_start=True, + clear_whole_buffer_after_dump=clear_entire_buffer) + sim.operations.writers.append(burst_writer) + sim.run(12) + burst_writer.flush() + if sim.device.communicator.rank == 0: + assert Path(filename).exists() + with gsd.hoomd.open(filename, "r") as traj: + # First frame is always written + assert len(traj) == 1 + + burst_writer.dump(start_frame, end_frame) + burst_writer.flush() + dumped_frames = [0, 7, 9] + if sim.device.communicator.rank == 0: + with gsd.hoomd.open(name=filename, mode='r') as traj: + print([frame.configuration.step for frame in traj]) + assert [frame.configuration.step for frame in traj] == dumped_frames + + sim.run(4) burst_writer.dump() burst_writer.flush() + if clear_entire_buffer: + dumped_frames += [13, 15] + else: + dumped_frames += [11, 13, 15] if sim.device.communicator.rank == 0: with gsd.hoomd.open(name=filename, mode='r') as traj: - assert [frame.configuration.step for frame in traj] == [0, 3, 5, 7] + print([frame.configuration.step for frame in traj]) + assert [frame.configuration.step for frame in traj] == dumped_frames def test_burst_max_size(sim, tmp_path): @@ -242,3 +292,38 @@ def test_write_burst_log(sim, tmp_path): with gsd.hoomd.open(name=filename, mode='r') as traj: for frame, sim_ke in zip(traj[1:], kinetic_energies): assert frame.log[key] == sim_ke + + +@pytest.mark.parametrize("clear_entire_buffer", [True, False]) +def test_burst_dump_empty_buffer(sim, tmp_path, clear_entire_buffer): + filename = tmp_path / "temporary_test_file.gsd" + burst_trigger = hoomd.trigger.Periodic(period=2, phase=1) + burst_writer = hoomd.write.Burst( + trigger=burst_trigger, + filename=filename, + mode='wb', + dynamic=['property', 'momentum'], + max_burst_size=3, + write_at_start=True, + clear_whole_buffer_after_dump=clear_entire_buffer) + sim.operations.writers.append(burst_writer) + sim.run(8) + burst_writer.flush() + if sim.device.communicator.rank == 0: + assert Path(filename).exists() + with gsd.hoomd.open(filename, "r") as traj: + # First frame is always written + assert len(traj) == 1 + + burst_writer.dump(1, 2) + burst_writer.flush() + if sim.device.communicator.rank == 0: + with gsd.hoomd.open(name=filename, mode='r') as traj: + assert len(traj) == 2 + + sim.run(4) + burst_writer.dump() + burst_writer.flush() + if sim.device.communicator.rank == 0: + with gsd.hoomd.open(name=filename, mode='r') as traj: + assert len(traj) == (4 if clear_entire_buffer else 5) diff --git a/hoomd/write/gsd_burst.py b/hoomd/write/gsd_burst.py index a982596dc8..94d4c0c09e 100644 --- a/hoomd/write/gsd_burst.py +++ b/hoomd/write/gsd_burst.py @@ -39,6 +39,10 @@ class Burst(GSD): write_at_start (bool): When ``True`` **and** the file does not exist or has 0 frames: write one frame with the current state of the system when `hoomd.Simulation.run` is called. Defaults to ``False``. + clear_whole_buffer_after_dump (bool): When ``True`` the buffer is + emptied after calling `dump` each time. When ``False``, `dump` + removes frames from the buffer unil the ``end`` index. Defaults + to ``True``. Warning: `Burst` errors when attempting to create a file or writing to one with @@ -80,6 +84,16 @@ class Burst(GSD): .. code-block:: python write_at_start = burst.write_at_start + + clear_whole_buffer_after_dump (bool): When ``True`` the buffer is + emptied after calling `dump` each time. When ``False``, `dump` + removes frames from the buffer unil the ``end`` index. + + .. rubric:: Example: + + .. code-block:: python + + burst.clear_buffer_after_dump = False """ def __init__(self, @@ -90,7 +104,8 @@ def __init__(self, dynamic=None, logger=None, max_burst_size=-1, - write_at_start=False): + write_at_start=False, + clear_whole_buffer_after_dump=True): super().__init__(trigger=trigger, filename=filename, filter=filter, @@ -102,24 +117,29 @@ def __init__(self, ParameterDict(max_burst_size=int, write_at_start=bool)) self._param_dict.update({ "max_burst_size": max_burst_size, - "write_at_start": write_at_start + "write_at_start": write_at_start, + "clear_whole_buffer_after_dump": clear_whole_buffer_after_dump }) def _attach_hook(self): sim = self._simulation - self._cpp_obj = _hoomd.GSDDequeWriter(sim.state._cpp_sys_def, - self.trigger, self.filename, - sim.state._get_group(self.filter), - self.logger, self.max_burst_size, - self.mode, self.write_at_start, - sim.timestep) + self._cpp_obj = _hoomd.GSDDequeWriter( + sim.state._cpp_sys_def, self.trigger, self.filename, + sim.state._get_group(self.filter), self.logger, self.max_burst_size, + self.mode, self.write_at_start, self.clear_whole_buffer_after_dump, + sim.timestep) - def dump(self): - """Write all currently stored frames to the file and empties the buffer. + def dump(self, start=0, end=-1): + """Write stored frames in range to the file and empties the buffer. This method alllows for custom writing of frames at user specified conditions. + Args: + start (int): The first frame to write. Defaults to 0. + end (int): The last frame to write. + Defaults to -1 (last frame). + .. rubric:: Example: .. code-block:: python @@ -127,7 +147,7 @@ def dump(self): burst.dump() """ if self._attached: - self._cpp_obj.dump() + self._cpp_obj.dump(start, end) def __len__(self): """Get the current length of the internal frame buffer.