Skip to content

Commit

Permalink
Raise Python exceptions instead of aborting the process
Browse files Browse the repository at this point in the history
Add InvalidFlatbuffer exception, raised upon invalid data into unpack
Raise ValueError upon out of bounds number being passed to __init__ for enums
  • Loading branch information
VirxEC committed May 8, 2024
1 parent e128f84 commit 71419d5
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 50 deletions.
48 changes: 24 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rlbot-flatbuffers-py"
version = "0.3.3"
version = "0.3.4"
edition = "2021"
description = "A Python module implemented in Rust for serializing and deserializing RLBot's flatbuffers"
repository = "https://github.com/VirxEC/rlbot-flatbuffers-py"
Expand Down
51 changes: 34 additions & 17 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ impl PythonBindGenerator {
let mut file_contents = vec![];

file_contents.push(Cow::Borrowed(match bind_type {
PythonBindType::Struct | PythonBindType::Union => "use crate::{generated::rlbot::flat, FromGil, IntoGil};",
PythonBindType::Enum => "use crate::generated::rlbot::flat;",
PythonBindType::Struct => "use crate::{flat_err_to_py, generated::rlbot::flat, FromGil, IntoGil};",
PythonBindType::Union => "use crate::{generated::rlbot::flat, FromGil, IntoGil};",
PythonBindType::Enum => "use crate::{flat_err_to_py, generated::rlbot::flat};",
}));

if bind_type != PythonBindType::Union {
Expand All @@ -82,8 +83,10 @@ impl PythonBindGenerator {
}

file_contents.push(Cow::Borrowed(match bind_type {
PythonBindType::Struct => "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, Py, Python};",
PythonBindType::Enum => "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, Python};",
PythonBindType::Struct => "use pyo3::{pyclass, pymethods, types::PyBytes, Bound, Py, PyResult, Python};",
PythonBindType::Enum => {
"use pyo3::{exceptions::PyValueError, pyclass, pymethods, types::PyBytes, Bound, PyResult, Python};"
}
PythonBindType::Union => "use pyo3::{pyclass, pymethods, Py, PyObject, Python, ToPyObject};",
}));

Expand Down Expand Up @@ -809,19 +812,20 @@ impl PythonBindGenerator {
assert!(u8::try_from(self.types.len()).is_ok());

self.write_str(" #[pyo3(signature = (value=Default::default()))]");
self.write_str(" pub fn new(value: u8) -> Self {");
self.write_str(" pub fn new(value: u8) -> PyResult<Self> {");
self.write_str(" match value {");

for variable_info in &self.types {
let variable_name = &variable_info[0];
let variable_value = &variable_info[1];

self.file_contents
.push(Cow::Owned(format!(" {variable_value} => Self::{variable_name},")));
self.file_contents.push(Cow::Owned(format!(
" {variable_value} => Ok(Self::{variable_name}),"
)));
}

if self.types.len() != usize::from(u8::MAX) {
self.write_str(" v => panic!(\"Unknown value: {v}\"),");
self.write_str(" v => Err(PyValueError::new_err(format!(\"Unknown value of {v}\"))),");
}

self.write_str(" }");
Expand Down Expand Up @@ -1118,14 +1122,17 @@ impl PythonBindGenerator {
self.write_str(" #[staticmethod]");

if self.bind_type == PythonBindType::Enum {
self.write_str(" fn unpack(data: &[u8]) -> Self {");
self.write_string(format!(" root::<flat::{}>(data).unwrap().into()", self.struct_name));
self.write_str(" fn unpack(data: &[u8]) -> PyResult<Self> {");
self.write_string(format!(" match root::<flat::{}>(data) {{", self.struct_name));
self.write_str(" Ok(flat_t) => Ok(flat_t.into()),");
self.write_str(" Err(e) => Err(flat_err_to_py(e)),");
self.write_str(" }");
} else {
self.write_str(" fn unpack(py: Python, data: &[u8]) -> Py<Self> {");
self.write_string(format!(
" root::<flat::{}>(data).unwrap().unpack().into_gil(py)",
self.struct_name
));
self.write_str(" fn unpack(py: Python, data: &[u8]) -> PyResult<Py<Self>> {");
self.write_string(format!(" match root::<flat::{}>(data) {{", self.struct_name));
self.write_str(" Ok(flat_t) => Ok(flat_t.unpack().into_gil(py)),");
self.write_str(" Err(e) => Err(flat_err_to_py(e)),");
self.write_str(" }");
}

self.write_str(" }");
Expand Down Expand Up @@ -1226,6 +1233,8 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
Cow::Borrowed("__doc__: str"),
Cow::Borrowed("__version__: str"),
Cow::Borrowed(""),
Cow::Borrowed("class InvalidFlatbuffer(ValueError): ..."),
Cow::Borrowed(""),
];

let primitive_map = [
Expand Down Expand Up @@ -1347,7 +1356,10 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
file_contents.push(Cow::Borrowed(""));

if is_enum {
file_contents.push(Cow::Borrowed(" def __init__(self, value: int = 0): ..."));
file_contents.push(Cow::Borrowed(" def __init__(self, value: int = 0):"));
file_contents.push(Cow::Borrowed(" \"\"\""));
file_contents.push(Cow::Borrowed(" :raises ValueError: If the `value` is not a valid enum value"));
file_contents.push(Cow::Borrowed(" \"\"\""));
} else {
file_contents.push(Cow::Borrowed(" def __init__("));
file_contents.push(Cow::Borrowed(" self,"));
Expand Down Expand Up @@ -1404,7 +1416,12 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
if !is_union {
file_contents.push(Cow::Borrowed(" def pack(self) -> bytes: ..."));
file_contents.push(Cow::Borrowed(" @staticmethod"));
file_contents.push(Cow::Owned(format!(" def unpack(data: bytes) -> {type_name}: ...")));
file_contents.push(Cow::Owned(format!(" def unpack(data: bytes) -> {type_name}:")));
file_contents.push(Cow::Borrowed(" \"\"\""));
file_contents.push(Cow::Borrowed(
" :raises InvalidFlatbuffer: If the `data` is invalid for this type",
));
file_contents.push(Cow::Borrowed(" \"\"\""));
}

file_contents.push(Cow::Borrowed(""));
Expand Down
43 changes: 38 additions & 5 deletions pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __add__(self, other):
)
dgs.game_info_state.world_gravity_z = Float(-650)
dgs.game_info_state.end_match.val = True
dgs.console_commands = [ConsoleCommand("freeze")]
dgs.console_commands = [ConsoleCommand("dump_items")]
dgs.ball_state = DesiredBallState()

print(repr(dgs))
Expand Down Expand Up @@ -52,20 +52,53 @@ def __add__(self, other):
print(comm.content.decode("utf-8"))
print()

num_trials = 1_000_000
try:
AirState(8)
except ValueError as e:
print(e)
print()

invalid_data = comm.pack()

try:
RenderMessage.unpack(invalid_data)
except InvalidFlatbuffer as e:
print(e)

print("Running quick benchmark...")

num_trials = 100_000

total_make_time = 0
total_pack_time = 0
total_unpack_time = 0
for _ in range(num_trials):
start = time_ns()
desired_game_state = DesiredGameState(
DesiredBallState(DesiredPhysics()),
car_states=[DesiredCarState(boost_amount=100)],
DesiredBallState(
DesiredPhysics(
Vector3Partial(0, 0, 0),
RotatorPartial(0, 0, 0),
Vector3Partial(0, 0, 0),
Vector3Partial(0, 0, 0),
)
),
car_states=[
DesiredCarState(
DesiredPhysics(
Vector3Partial(0, 0, 0),
RotatorPartial(0, 0, 0),
Vector3Partial(0, 0, 0),
Vector3Partial(0, 0, 0),
),
100,
)
for _ in range(8)
],
game_info_state=DesiredGameInfoState(
game_speed=1, world_gravity_z=-650, end_match=True
),
console_commands=[ConsoleCommand("freeze")],
console_commands=[ConsoleCommand("dump_items")],
)
total_make_time += time_ns() - start

Expand Down
20 changes: 17 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,18 @@ pub mod generated;
#[allow(clippy::enum_variant_names)]
mod python;

use pyo3::{prelude::*, types::PyBytes, PyClass};
use pyo3::{create_exception, exceptions::PyValueError, prelude::*, types::PyBytes, PyClass};
use python::*;
use std::panic::Location;

create_exception!(rlbot_flatbuffers, InvalidFlatbuffer, PyValueError, "Invalid FlatBuffer");

#[track_caller]
pub fn flat_err_to_py(err: flatbuffers::InvalidFlatbuffer) -> PyErr {
let caller = Location::caller();
let err_msg = format!("Can't make flatbuffer @ \"rlbot_flatbuffers/{}\":\n {err}", caller.file());
InvalidFlatbuffer::new_err(err_msg)
}

pub trait FromGil<T> {
fn from_gil(py: Python, obj: T) -> Self;
Expand Down Expand Up @@ -108,13 +118,14 @@ impl FromGil<Bools> for Py<Bool> {
}

macro_rules! pynamedmodule {
(doc: $doc:literal, name: $name:tt, classes: [$($class_name:ident),*], vars: [$(($var_name:literal, $value:expr)),*]) => {
(doc: $doc:literal, name: $name:tt, classes: [$($class_name:ident),*], vars: [$(($var_name:literal, $value:expr)),*], exceptions: [$($except:expr),*]) => {
#[doc = $doc]
#[pymodule]
#[allow(redundant_semicolons)]
fn $name(m: Bound<PyModule>) -> PyResult<()> {
fn $name(py: Python, m: Bound<PyModule>) -> PyResult<()> {
$(m.add_class::<$class_name>()?);*;
$(m.add($var_name, $value)?);*;
$(m.add(stringify!($except), py.get_type_bound::<$except>())?);*;
Ok(())
}
};
Expand Down Expand Up @@ -213,5 +224,8 @@ pynamedmodule! {
],
vars: [
("__version__", env!("CARGO_PKG_VERSION"))
],
exceptions: [
InvalidFlatbuffer
]
}

0 comments on commit 71419d5

Please sign in to comment.