diff --git a/python/BUILD b/python/BUILD index 58ed6ce..cfe87ff 100644 --- a/python/BUILD +++ b/python/BUILD @@ -10,12 +10,14 @@ pybind_extension( name = "array_record_module", srcs = ["array_record_module.cc"], deps = [ + "//file/base:options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "//cpp:array_record_reader", "//cpp:array_record_writer", "//cpp:thread_pool", + "//third_party/protobuf", "@com_google_riegeli//riegeli/bytes:fd_reader", "@com_google_riegeli//riegeli/bytes:fd_writer", ], diff --git a/python/array_record_module.cc b/python/array_record_module.cc index ab14893..fda36f6 100644 --- a/python/array_record_module.cc +++ b/python/array_record_module.cc @@ -22,12 +22,14 @@ limitations under the License. #include #include +#include "file/base/options.pb.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "cpp/array_record_reader.h" #include "cpp/array_record_writer.h" #include "cpp/thread_pool.h" +#include "third_party/protobuf/text_format.h" #include "pybind11/gil.h" #include "pybind11/pybind11.h" #include "pybind11/pytypes.h" @@ -44,7 +46,8 @@ PYBIND11_MODULE(array_record_module, m) { array_record::ArrayRecordReader>; py::class_(m, "ArrayRecordWriter") - .def(py::init([](const std::string& path, const std::string& options) { + .def(py::init([](const std::string& path, const std::string& options, + const file::Options& file_options) { auto status_or_option = array_record::ArrayRecordWriterBase::Options::FromString( options); @@ -66,7 +69,8 @@ PYBIND11_MODULE(array_record_module, m) { return ArrayRecordWriter(std::move(file_writer), status_or_option.value()); }), - py::arg("path"), py::arg("options") = "") + py::arg("path"), py::arg("options") = "", + py::arg("file_options") = "") .def("ok", &ArrayRecordWriter::ok) .def("close", [](ArrayRecordWriter& writer) { @@ -84,10 +88,10 @@ PYBIND11_MODULE(array_record_module, m) { throw std::runtime_error(std::string(writer.status().message())); } }); - py::class_(m, "ArrayRecordReader") .def(py::init([](const std::string& path, const std::string& options, - const std::optional file_reader_buffer_size) { + const std::optional file_reader_buffer_size, + const std::string& file_options_str) { auto status_or_option = array_record::ArrayRecordReaderBase::Options::FromString( options); @@ -115,7 +119,8 @@ PYBIND11_MODULE(array_record_module, m) { array_record::ArrayRecordGlobalPool()); }), py::arg("path"), py::arg("options") = "", - py::arg("file_reader_buffer_size") = std::nullopt, R"( + py::arg("file_reader_buffer_size") = std::nullopt, + py::arg("file_options_str") = "", R"( ArrayRecordReader for fast sequential or random access. Args: @@ -124,6 +129,8 @@ PYBIND11_MODULE(array_record_module, m) { file_reader_buffer_size: Optional size of the buffer (in bytes) for the underlying file (Riegeli) reader. The default buffer size is 1 MiB. + file_options_str: Optional file::Options textproto to use for the underlying + file (Riegeli) reader. options ::= option? ("," option?)* option ::= diff --git a/python/array_record_module_test.py b/python/array_record_module_test.py index 4df8ff2..2ec7dbd 100644 --- a/python/array_record_module_test.py +++ b/python/array_record_module_test.py @@ -142,5 +142,30 @@ def test_writer_options(self): "group_size:42,transpose:false,pad_to_block_boundary:false,zstd:3," "window_log:20,max_parallelism:1") + def test_write_read_with_file_options(self): + writer = ArrayRecordWriter(self.test_file, "", "priority:200") + test_strs = [b"abc", b"def", b"ghi"] + for s in test_strs: + writer.write(s) + writer.close() + reader = ArrayRecordReader( + self.test_file, + "readahead_buffer_size:0,max_parallelism:0", + None, + "priority:200", + ) + num_strs = len(test_strs) + self.assertEqual(reader.num_records(), num_strs) + self.assertEqual(reader.record_index(), 0) + for gt in test_strs: + result = reader.read() + self.assertEqual(result, gt) + self.assertRaises(IndexError, reader.read) + reader.seek(0) + self.assertEqual(reader.record_index(), 0) + self.assertEqual(reader.read(), test_strs[0]) + self.assertEqual(reader.record_index(), 1) + + if __name__ == "__main__": absltest.main()