diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index c7bcacd3214..71dcb52821a 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -65,16 +65,10 @@ impl PyFunctionAttr { syn::Lit::Str(ref lits) => { // "*" if lits.value() == "*" { - if self.has_kwargs { + if self.has_kwargs || self.has_varargs { return Err(syn::Error::new_spanned( item, - "syntax error, keyword self.arguments is defined", - )); - } - if self.has_varargs { - return Err(syn::Error::new_spanned( - item, - "self.arguments already define * (var args)", + "* is not allowed after varargs(*) or kwargs(**)", )); } self.has_varargs = true; @@ -94,20 +88,52 @@ impl PyFunctionAttr { } fn add_work(&mut self, item: &NestedMeta, path: &Path) -> syn::Result<()> { - // self.arguments in form somename - if self.has_kwargs { + self.pos_arg_is_ok(item)?; + self.arguments.push(Argument::Arg(path.clone(), None)); + Ok(()) + } + + fn pos_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { + if self.has_kw || self.has_kwargs || self.has_varargs { return Err(syn::Error::new_spanned( item, - "syntax error, keyword self.arguments is defined", + "Positional argument or varargs(*) is not allowed after keyword arguments", )); } - if self.has_kw { + if self.has_varargs { return Err(syn::Error::new_spanned( item, - "syntax error, argument is not allowed after keyword argument", + "Positional argument or varargs(*) is not allowed after *", + )); + } + Ok(()) + } + + fn kw_arg_is_ok(&self, item: &NestedMeta) -> syn::Result<()> { + if self.has_kwargs { + return Err(syn::Error::new_spanned( + item, + "Keyword argument or kwargs(**) is not allowed after kwargs(**)", )); } - self.arguments.push(Argument::Arg(path.clone(), None)); + Ok(()) + } + + fn add_nv_common( + &mut self, + item: &NestedMeta, + name: &syn::Path, + value: String, + ) -> syn::Result<()> { + self.kw_arg_is_ok(item)?; + if self.has_varargs { + // kw only + self.arguments.push(Argument::Kwarg(name.clone(), value)); + } else { + self.has_kw = true; + self.arguments + .push(Argument::Arg(name.clone(), Some(value))); + } Ok(()) } @@ -116,75 +142,23 @@ impl PyFunctionAttr { syn::Lit::Str(ref litstr) => { if litstr.value() == "*" { // args="*" - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "* - syntax error, keyword self.arguments is defined", - )); - } - if self.has_varargs { - return Err(syn::Error::new_spanned(item, "*(var args) is defined")); - } + self.pos_arg_is_ok(item)?; self.has_varargs = true; self.arguments.push(Argument::VarArgs(nv.path.clone())); } else if litstr.value() == "**" { // kwargs="**" - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "self.arguments already define ** (kw args)", - )); - } + self.kw_arg_is_ok(item)?; self.has_kwargs = true; self.arguments.push(Argument::KeywordArgs(nv.path.clone())); - } else if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), litstr.value())) } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments - .push(Argument::Arg(nv.path.clone(), Some(litstr.value()))) + self.add_nv_common(item, &nv.path, litstr.value())?; } } syn::Lit::Int(ref litint) => { - if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), format!("{}", litint))); - } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments - .push(Argument::Arg(nv.path.clone(), Some(format!("{}", litint)))); - } + self.add_nv_common(item, &nv.path, format!("{}", litint))?; } syn::Lit::Bool(ref litb) => { - if self.has_varargs { - self.arguments - .push(Argument::Kwarg(nv.path.clone(), format!("{}", litb.value))); - } else { - if self.has_kwargs { - return Err(syn::Error::new_spanned( - item, - "syntax error, keyword self.arguments is defined", - )); - } - self.has_kw = true; - self.arguments.push(Argument::Arg( - nv.path.clone(), - Some(format!("{}", litb.value)), - )); - } + self.add_nv_common(item, &nv.path, format!("{}", litb.value))?; } _ => { return Err(syn::Error::new_spanned( diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 35906d8183f..b5547f42d28 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -231,7 +231,7 @@ impl MethArgs { [a.to_object(py), args.into(), kwargs.to_object(py)].to_object(py) } - #[args("*", c = 10)] + #[args(a, b, "*", c = 10)] fn get_pos_arg_kw_sep(&self, a: i32, b: i32, c: i32) -> PyResult { Ok(a + b + c) } diff --git a/tests/test_module.rs b/tests/test_module.rs index 535a74c943e..2798be41387 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -246,15 +246,14 @@ fn test_module_nesting() { } // Test that argument parsing specification works for pyfunctions - -#[pyfunction(a = 5, vararg = "*")] +#[pyfunction(a, vararg = "*")] fn ext_vararg_fn(py: Python, a: i32, vararg: &PyTuple) -> PyObject { [a.to_object(py), vararg.into()].to_object(py) } #[pymodule] fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> { - #[pyfn(m, "int_vararg_fn", a = 5, vararg = "*")] + #[pyfn(m, "int_vararg_fn", a, vararg = "*")] fn int_vararg_fn(py: Python, a: i32, vararg: &PyTuple) -> PyObject { ext_vararg_fn(py, a, vararg) } @@ -270,9 +269,9 @@ fn test_vararg_module() { let py = gil.python(); let m = pyo3::wrap_pymodule!(vararg_module)(py); - py_assert!(py, m, "m.ext_vararg_fn() == [5, ()]"); + py_assert!(py, m, "m.ext_vararg_fn(1) == [1, ()]"); py_assert!(py, m, "m.ext_vararg_fn(1, 2) == [1, (2,)]"); - py_assert!(py, m, "m.int_vararg_fn() == [5, ()]"); + py_assert!(py, m, "m.int_vararg_fn(1) == [1, ()]"); py_assert!(py, m, "m.int_vararg_fn(1, 2) == [1, (2,)]"); } diff --git a/tests/ui/invalid_args.rs b/tests/ui/invalid_args.rs new file mode 100644 index 00000000000..ed7ecefc439 --- /dev/null +++ b/tests/ui/invalid_args.rs @@ -0,0 +1,14 @@ +use pyo3::prelude::*; + +#[pyfunction(a = 5, vararg = "*")] +fn invalid_fn(py: Python, a: i32, vararg: &PyTuple) -> PyObject { + [a.to_object(py), vararg.into()].to_object(py) +} + +#[pyclass] +struct Class {} + +impl Class { + #[args("*", a = 5)] + fn invalid_method(&self, a: i32) {} +}