diff --git a/src/solve/python.rs b/src/solve/python.rs index 3d110c59d8..b67deaeeb1 100644 --- a/src/solve/python.rs +++ b/src/solve/python.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // https://github.com/imageworks/spk use pyo3::prelude::*; +use pyo3::py_run; use super::errors::SolverError; use super::graph::{ @@ -17,7 +18,7 @@ fn init_submodule_errors(py: &Python, module: &PyModule) -> PyResult<()> { Ok(()) } -fn init_submodule_graph(module: &PyModule) -> PyResult<()> { +fn init_submodule_graph(_py: &Python, module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_class::()?; @@ -33,7 +34,7 @@ fn init_submodule_graph(module: &PyModule) -> PyResult<()> { Ok(()) } -fn init_submodule_solution(module: &PyModule) -> PyResult<()> { +fn init_submodule_solution(_py: &Python, module: &PyModule) -> PyResult<()> { module.add_class::()?; Ok(()) } @@ -44,37 +45,34 @@ fn init_submodule_solver(py: &Python, module: &PyModule) -> PyResult<()> { Ok(()) } -fn init_submodule_validation(module: &PyModule) -> PyResult<()> { +fn init_submodule_validation(_py: &Python, module: &PyModule) -> PyResult<()> { module.add_class::()?; Ok(()) } +macro_rules! add_submodule { + ($m:ident, $py:ident, $mod_name:expr, $init_fn:ident) => { + let submod = PyModule::new(*$py, $mod_name)?; + // Hack to make `from spk.solve.foo import ...` work + py_run!( + *$py, + submod, + &format!( + "import sys; sys.modules['spkrs.solve.{}'] = submod", + $mod_name + ) + ); + $init_fn($py, submod)?; + $m.add_submodule(submod)?; + }; +} + pub fn init_module(py: &Python, m: &PyModule) -> PyResult<()> { - { - let submod_errors = PyModule::new(*py, "_errors")?; - init_submodule_errors(py, submod_errors)?; - m.add_submodule(submod_errors)?; - } - { - let submod_graph = PyModule::new(*py, "graph")?; - init_submodule_graph(submod_graph)?; - m.add_submodule(submod_graph)?; - } - { - let submod_solver = PyModule::new(*py, "_solver")?; - init_submodule_solver(py, submod_solver)?; - m.add_submodule(submod_solver)?; - } - { - let submod_solution = PyModule::new(*py, "_solution")?; - init_submodule_solution(submod_solution)?; - m.add_submodule(submod_solution)?; - } - { - let submod_validation = PyModule::new(*py, "validation")?; - init_submodule_validation(submod_validation)?; - m.add_submodule(submod_validation)?; - } + add_submodule!(m, py, "_errors", init_submodule_errors); + add_submodule!(m, py, "graph", init_submodule_graph); + add_submodule!(m, py, "_solver", init_submodule_solver); + add_submodule!(m, py, "_solution", init_submodule_solution); + add_submodule!(m, py, "validation", init_submodule_validation); m.add_class::()?; m.add_class::()?;