Skip to content

Commit

Permalink
Add file::Options as an optional argument to ArrayRecordReader and Ar…
Browse files Browse the repository at this point in the history
…rayRecordWriter

PiperOrigin-RevId: 686520599
  • Loading branch information
ArrayRecord Team authored and copybara-github committed Oct 16, 2024
1 parent ee89b74 commit 9c1edd9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ pybind_extension(
name = "array_record_module",
srcs = ["array_record_module.cc"],
deps = [
"//file/base:options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//cpp:array_record_reader",
"//cpp:array_record_writer",
"//cpp:thread_pool",
"//third_party/protobuf",
"@com_google_riegeli//riegeli/bytes:fd_reader",
"@com_google_riegeli//riegeli/bytes:fd_writer",
],
Expand Down
17 changes: 12 additions & 5 deletions python/array_record_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ limitations under the License.
#include <utility>
#include <vector>

#include "file/base/options.pb.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "cpp/array_record_reader.h"
#include "cpp/array_record_writer.h"
#include "cpp/thread_pool.h"
#include "third_party/protobuf/text_format.h"
#include "pybind11/gil.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
Expand All @@ -44,7 +46,8 @@ PYBIND11_MODULE(array_record_module, m) {
array_record::ArrayRecordReader<std::unique_ptr<riegeli::Reader>>;

py::class_<ArrayRecordWriter>(m, "ArrayRecordWriter")
.def(py::init([](const std::string& path, const std::string& options) {
.def(py::init([](const std::string& path, const std::string& options,
const file::Options& file_options) {
auto status_or_option =
array_record::ArrayRecordWriterBase::Options::FromString(
options);
Expand All @@ -66,7 +69,8 @@ PYBIND11_MODULE(array_record_module, m) {
return ArrayRecordWriter(std::move(file_writer),
status_or_option.value());
}),
py::arg("path"), py::arg("options") = "")
py::arg("path"), py::arg("options") = "",
py::arg("file_options") = "")
.def("ok", &ArrayRecordWriter::ok)
.def("close",
[](ArrayRecordWriter& writer) {
Expand All @@ -84,10 +88,10 @@ PYBIND11_MODULE(array_record_module, m) {
throw std::runtime_error(std::string(writer.status().message()));
}
});

py::class_<ArrayRecordReader>(m, "ArrayRecordReader")
.def(py::init([](const std::string& path, const std::string& options,
const std::optional<int64_t> file_reader_buffer_size) {
const std::optional<int64_t> file_reader_buffer_size,
const std::string& file_options_str) {
auto status_or_option =
array_record::ArrayRecordReaderBase::Options::FromString(
options);
Expand Down Expand Up @@ -115,7 +119,8 @@ PYBIND11_MODULE(array_record_module, m) {
array_record::ArrayRecordGlobalPool());
}),
py::arg("path"), py::arg("options") = "",
py::arg("file_reader_buffer_size") = std::nullopt, R"(
py::arg("file_reader_buffer_size") = std::nullopt,
py::arg("file_options_str") = "", R"(
ArrayRecordReader for fast sequential or random access.
Args:
Expand All @@ -124,6 +129,8 @@ PYBIND11_MODULE(array_record_module, m) {
file_reader_buffer_size: Optional size of the buffer (in bytes)
for the underlying file (Riegeli) reader. The default buffer
size is 1 MiB.
file_options_str: Optional file::Options textproto to use for the underlying
file (Riegeli) reader.
options ::= option? ("," option?)*
option ::=
Expand Down
25 changes: 25 additions & 0 deletions python/array_record_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,30 @@ def test_writer_options(self):
"group_size:42,transpose:false,pad_to_block_boundary:false,zstd:3,"
"window_log:20,max_parallelism:1")

def test_write_read_with_file_options(self):
writer = ArrayRecordWriter(self.test_file, "", "priority:200")
test_strs = [b"abc", b"def", b"ghi"]
for s in test_strs:
writer.write(s)
writer.close()
reader = ArrayRecordReader(
self.test_file,
"readahead_buffer_size:0,max_parallelism:0",
None,
"priority:200",
)
num_strs = len(test_strs)
self.assertEqual(reader.num_records(), num_strs)
self.assertEqual(reader.record_index(), 0)
for gt in test_strs:
result = reader.read()
self.assertEqual(result, gt)
self.assertRaises(IndexError, reader.read)
reader.seek(0)
self.assertEqual(reader.record_index(), 0)
self.assertEqual(reader.read(), test_strs[0])
self.assertEqual(reader.record_index(), 1)


if __name__ == "__main__":
absltest.main()

0 comments on commit 9c1edd9

Please sign in to comment.