Skip to content

Commit

Permalink
Merge pull request #1870 from glotzerlab/feat_burst_dump_range
Browse files Browse the repository at this point in the history
Enable dumping only part of the buffer for burst writer.
  • Loading branch information
joaander committed Aug 27, 2024
2 parents 38d6e6a + bb6c734 commit bcc61f0
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 19 deletions.
55 changes: 50 additions & 5 deletions hoomd/GSDDequeWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
int queue_size,
std::string mode,
bool write_at_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep)
: GSDDumpWriter(sysdef, trigger, fname, group, mode), m_queue_size(queue_size)
{
setLogWriter(logger);
bool file_empty = true;
m_clear_whole_buffer_after_dump = true;
#ifdef ENABLE_MPI
if (m_sysdef->isDomainDecomposed())
{
Expand All @@ -42,9 +44,10 @@ GSDDequeWriter::GSDDequeWriter(std::shared_ptr<SystemDefinition> sysdef,
else
{
analyze(timestep);
dump();
dump(0, -1);
}
}
setClearWholeBufferAfterDump(clear_whole_buffer_after_dump);
}

void GSDDequeWriter::analyze(uint64_t timestep)
Expand All @@ -59,14 +62,42 @@ void GSDDequeWriter::analyze(uint64_t timestep)
}
}

void GSDDequeWriter::dump()
void GSDDequeWriter::dump(long int start, long int end)
{
for (auto i {static_cast<long int>(m_frame_queue.size()) - 1}; i >= 0; --i)
auto buffer_length = static_cast<long int>(m_frame_queue.size());
if (end > buffer_length)
{
throw std::runtime_error("Burst.dump's end index is out of range.");
}
if (start < 0 || start > buffer_length)
{
throw std::runtime_error("Burst.dump's start index is out of range.");
}
long int iterator_start, iterator_end;
if (end < 0)
{
iterator_end = buffer_length - start;
iterator_start = 0;
}
else
{
iterator_end = buffer_length - start;
iterator_start = buffer_length - end;
}
for (auto i = iterator_end - 1; i >= iterator_start; --i)
{
write(m_frame_queue[i], m_log_queue[i]);
}
m_frame_queue.clear();
m_log_queue.clear();
if (m_clear_whole_buffer_after_dump)
{
m_frame_queue.clear();
m_log_queue.clear();
}
else
{
m_frame_queue.erase(m_frame_queue.begin() + iterator_start, m_frame_queue.end());
m_log_queue.erase(m_log_queue.begin() + iterator_start, m_log_queue.end());
}
}

int GSDDequeWriter::getMaxQueueSize() const
Expand All @@ -93,6 +124,16 @@ void GSDDequeWriter::setMaxQueueSize(int new_max_size)
}
}

bool GSDDequeWriter::getClearWholeBufferAfterDump() const
{
return m_clear_whole_buffer_after_dump;
}

void GSDDequeWriter::setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump)
{
m_clear_whole_buffer_after_dump = clear_whole_buffer_after_dump;
}

namespace detail
{
void export_GSDDequeWriter(pybind11::module& m)
Expand All @@ -108,10 +149,14 @@ void export_GSDDequeWriter(pybind11::module& m)
int,
std::string,
bool,
bool,
uint64_t>())
.def_property("max_burst_size",
&GSDDequeWriter::getMaxQueueSize,
&GSDDequeWriter::setMaxQueueSize)
.def_property("clear_whole_buffer_after_dump",
&GSDDequeWriter::getClearWholeBufferAfterDump,
&GSDDequeWriter::setClearWholeBufferAfterDump)
.def("__len__", &GSDDequeWriter::getCurrentQueueSize)
.def("dump", &GSDDequeWriter::dump);
}
Expand Down
6 changes: 5 additions & 1 deletion hoomd/GSDDequeWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,24 @@ class PYBIND11_EXPORT GSDDequeWriter : public GSDDumpWriter
int queue_size,
std::string mode,
bool write_on_init,
bool clear_whole_buffer_after_dump,
uint64_t timestep);
~GSDDequeWriter() = default;

void analyze(uint64_t timestep) override;

void dump();
void dump(long int start, long int end);

int getMaxQueueSize() const;
void setMaxQueueSize(int new_max_size);
bool getClearWholeBufferAfterDump() const;
void setClearWholeBufferAfterDump(bool clear_whole_buffer_after_dump);

size_t getCurrentQueueSize() const;

protected:
int m_queue_size;
bool m_clear_whole_buffer_after_dump;
std::deque<GSDDumpWriter::GSDFrame> m_frame_queue;
std::deque<pybind11::dict> m_log_queue;
};
Expand Down
89 changes: 87 additions & 2 deletions hoomd/md/pytest/test_burst_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def test_len(sim, tmp_path):
assert len(burst_writer) == 0


def test_burst_dump(sim, tmp_path):
@pytest.mark.parametrize("start, end", [(0, -1), (0, 0), (0, 1), (0, 2), (1, 1),
(2, 2), (1, 2), (1, -1), (2, -1)])
def test_burst_dump(sim, tmp_path, start, end):
filename = tmp_path / "temporary_test_file.gsd"

burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
Expand All @@ -164,11 +166,59 @@ def test_burst_dump(sim, tmp_path):
# First frame is always written
assert len(traj) == 1

burst_writer.dump(start=start, end=end)
burst_writer.flush()
dumped_frames = [3, 5, 7]
if sim.device.communicator.rank == 0:
if end == -1:
end = len(dumped_frames)
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert [frame.configuration.step for frame in traj
] == [0] + dumped_frames[start:end]


@pytest.mark.parametrize("clear_entire_buffer", [True, False])
def test_burst_dump_with_clear_buffer(sim, tmp_path, clear_entire_buffer):
filename = tmp_path / "temporary_test_file.gsd"
start_frame = 1
end_frame = 3
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=4,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(12)
burst_writer.flush()
if sim.device.communicator.rank == 0:
assert Path(filename).exists()
with gsd.hoomd.open(filename, "r") as traj:
# First frame is always written
assert len(traj) == 1

burst_writer.dump(start_frame, end_frame)
burst_writer.flush()
dumped_frames = [0, 7, 9]
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
print([frame.configuration.step for frame in traj])
assert [frame.configuration.step for frame in traj] == dumped_frames

sim.run(4)
burst_writer.dump()
burst_writer.flush()
if clear_entire_buffer:
dumped_frames += [13, 15]
else:
dumped_frames += [11, 13, 15]
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert [frame.configuration.step for frame in traj] == [0, 3, 5, 7]
print([frame.configuration.step for frame in traj])
assert [frame.configuration.step for frame in traj] == dumped_frames


def test_burst_max_size(sim, tmp_path):
Expand Down Expand Up @@ -242,3 +292,38 @@ def test_write_burst_log(sim, tmp_path):
with gsd.hoomd.open(name=filename, mode='r') as traj:
for frame, sim_ke in zip(traj[1:], kinetic_energies):
assert frame.log[key] == sim_ke


@pytest.mark.parametrize("clear_entire_buffer", [True, False])
def test_burst_dump_empty_buffer(sim, tmp_path, clear_entire_buffer):
filename = tmp_path / "temporary_test_file.gsd"
burst_trigger = hoomd.trigger.Periodic(period=2, phase=1)
burst_writer = hoomd.write.Burst(
trigger=burst_trigger,
filename=filename,
mode='wb',
dynamic=['property', 'momentum'],
max_burst_size=3,
write_at_start=True,
clear_whole_buffer_after_dump=clear_entire_buffer)
sim.operations.writers.append(burst_writer)
sim.run(8)
burst_writer.flush()
if sim.device.communicator.rank == 0:
assert Path(filename).exists()
with gsd.hoomd.open(filename, "r") as traj:
# First frame is always written
assert len(traj) == 1

burst_writer.dump(1, 2)
burst_writer.flush()
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert len(traj) == 2

sim.run(4)
burst_writer.dump()
burst_writer.flush()
if sim.device.communicator.rank == 0:
with gsd.hoomd.open(name=filename, mode='r') as traj:
assert len(traj) == (4 if clear_entire_buffer else 5)
42 changes: 31 additions & 11 deletions hoomd/write/gsd_burst.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ class Burst(GSD):
write_at_start (bool): When ``True`` **and** the file does not exist or
has 0 frames: write one frame with the current state of the system
when `hoomd.Simulation.run` is called. Defaults to ``False``.
clear_whole_buffer_after_dump (bool): When ``True`` the buffer is
emptied after calling `dump` each time. When ``False``, `dump`
removes frames from the buffer unil the ``end`` index. Defaults
to ``True``.
Warning:
`Burst` errors when attempting to create a file or writing to one with
Expand Down Expand Up @@ -80,6 +84,16 @@ class Burst(GSD):
.. code-block:: python
write_at_start = burst.write_at_start
clear_whole_buffer_after_dump (bool): When ``True`` the buffer is
emptied after calling `dump` each time. When ``False``, `dump`
removes frames from the buffer unil the ``end`` index.
.. rubric:: Example:
.. code-block:: python
burst.clear_buffer_after_dump = False
"""

def __init__(self,
Expand All @@ -90,7 +104,8 @@ def __init__(self,
dynamic=None,
logger=None,
max_burst_size=-1,
write_at_start=False):
write_at_start=False,
clear_whole_buffer_after_dump=True):
super().__init__(trigger=trigger,
filename=filename,
filter=filter,
Expand All @@ -102,32 +117,37 @@ def __init__(self,
ParameterDict(max_burst_size=int, write_at_start=bool))
self._param_dict.update({
"max_burst_size": max_burst_size,
"write_at_start": write_at_start
"write_at_start": write_at_start,
"clear_whole_buffer_after_dump": clear_whole_buffer_after_dump
})

def _attach_hook(self):
sim = self._simulation
self._cpp_obj = _hoomd.GSDDequeWriter(sim.state._cpp_sys_def,
self.trigger, self.filename,
sim.state._get_group(self.filter),
self.logger, self.max_burst_size,
self.mode, self.write_at_start,
sim.timestep)
self._cpp_obj = _hoomd.GSDDequeWriter(
sim.state._cpp_sys_def, self.trigger, self.filename,
sim.state._get_group(self.filter), self.logger, self.max_burst_size,
self.mode, self.write_at_start, self.clear_whole_buffer_after_dump,
sim.timestep)

def dump(self):
"""Write all currently stored frames to the file and empties the buffer.
def dump(self, start=0, end=-1):
"""Write stored frames in range to the file and empties the buffer.
This method alllows for custom writing of frames at user specified
conditions.
Args:
start (int): The first frame to write. Defaults to 0.
end (int): The last frame to write.
Defaults to -1 (last frame).
.. rubric:: Example:
.. code-block:: python
burst.dump()
"""
if self._attached:
self._cpp_obj.dump()
self._cpp_obj.dump(start, end)

def __len__(self):
"""Get the current length of the internal frame buffer.
Expand Down

0 comments on commit bcc61f0

Please sign in to comment.