Skip to content

Commit

Permalink
Merge pull request #453 from SeokminHong/impl-encoder
Browse files Browse the repository at this point in the history
Change map helper functions' arguments
  • Loading branch information
evnu authored May 27, 2022
2 parents bc6238b + 150cae1 commit adc0b9d
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 90 deletions.
8 changes: 3 additions & 5 deletions rustler/src/types/elixir_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ use super::map::map_new;
use crate::{Env, NifResult, Term};

pub fn get_ex_struct_name(map: Term) -> NifResult<Atom> {
let env = map.get_env();
// In an Elixir struct the value in the __struct__ field is always an atom.
map.map_get(atom::__struct__().to_term(env))
.and_then(Atom::from_term)
map.map_get(atom::__struct__()).and_then(Atom::from_term)
}

pub fn make_ex_struct<'a>(env: Env<'a>, struct_module: &str) -> NifResult<Term<'a>> {
let map = map_new(env);

let struct_atom = atom::__struct__().to_term(env);
let module_atom = Atom::from_str(env, struct_module)?.to_term(env);
let struct_atom = atom::__struct__();
let module_atom = Atom::from_str(env, struct_module)?;

map.map_put(struct_atom, module_atom)
}
89 changes: 36 additions & 53 deletions rustler/src/types/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use super::atom;
use crate::wrapper::map;
use crate::{Decoder, Env, Error, NifResult, Term};
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};
use std::ops::RangeInclusive;

pub fn map_new(env: Env) -> Term {
Expand Down Expand Up @@ -31,12 +31,12 @@ impl<'a> Term<'a> {
/// ```
pub fn map_from_arrays(
env: Env<'a>,
keys: &[Term<'a>],
values: &[Term<'a>],
keys: &[impl Encoder],
values: &[impl Encoder],
) -> NifResult<Term<'a>> {
if keys.len() == values.len() {
let keys: Vec<_> = keys.iter().map(|k| k.as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.as_c_arg()).collect();
let keys: Vec<_> = keys.iter().map(|k| k.encode(env).as_c_arg()).collect();
let values: Vec<_> = values.iter().map(|v| v.encode(env).as_c_arg()).collect();

unsafe {
map::make_map_from_arrays(env.as_c_arg(), &keys, &values)
Expand All @@ -57,10 +57,13 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.new([{"foo", 1}, {"bar", 2}])
/// ```
pub fn map_from_pairs(env: Env<'a>, pairs: &[(Term<'a>, Term<'a>)]) -> NifResult<Term<'a>> {
pub fn map_from_pairs(
env: Env<'a>,
pairs: &[(impl Encoder, impl Encoder)],
) -> NifResult<Term<'a>> {
let (keys, values): (Vec<_>, Vec<_>) = pairs
.iter()
.map(|(k, v)| (k.as_c_arg(), v.as_c_arg()))
.map(|(k, v)| (k.encode(env).as_c_arg(), v.encode(env).as_c_arg()))
.unzip();

unsafe {
Expand All @@ -78,9 +81,11 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.get(self_term, key)
/// ```
pub fn map_get(self, key: Term) -> NifResult<Term<'a>> {
pub fn map_get(self, key: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();
match unsafe { map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } {
match unsafe {
map::get_map_value(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg())
} {
Some(value) => Ok(unsafe { Term::new(env, value) }),
None => Err(Error::BadArg),
}
Expand Down Expand Up @@ -108,27 +113,18 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.put(self_term, key, value)
/// ```
pub fn map_put(self, key: Term<'a>, value: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
assert!(
map_env == value.get_env(),
"value is from different environment as map"
);
pub fn map_put(self, key: impl Encoder, value: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();

match unsafe {
map::map_put(
map_env.as_c_arg(),
env.as_c_arg(),
self.as_c_arg(),
key.as_c_arg(),
value.as_c_arg(),
key.encode(env).as_c_arg(),
value.encode(env).as_c_arg(),
)
} {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand All @@ -142,16 +138,13 @@ impl<'a> Term<'a> {
/// ```elixir
/// Map.delete(self_term, key)
/// ```
pub fn map_remove(self, key: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
pub fn map_remove(self, key: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();

match unsafe { map::map_remove(map_env.as_c_arg(), self.as_c_arg(), key.as_c_arg()) } {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
match unsafe {
map::map_remove(env.as_c_arg(), self.as_c_arg(), key.encode(env).as_c_arg())
} {
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand All @@ -160,27 +153,18 @@ impl<'a> Term<'a> {
///
/// Returns Err(Error::BadArg) if the term is not a map of if key
/// doesn't exist.
pub fn map_update(self, key: Term<'a>, new_value: Term<'a>) -> NifResult<Term<'a>> {
let map_env = self.get_env();

assert!(
map_env == key.get_env(),
"key is from different environment as map"
);
assert!(
map_env == new_value.get_env(),
"value is from different environment as map"
);
pub fn map_update(self, key: impl Encoder, new_value: impl Encoder) -> NifResult<Term<'a>> {
let env = self.get_env();

match unsafe {
map::map_update(
map_env.as_c_arg(),
env.as_c_arg(),
self.as_c_arg(),
key.as_c_arg(),
new_value.as_c_arg(),
key.encode(env).as_c_arg(),
new_value.encode(env).as_c_arg(),
)
} {
Some(inner) => Ok(unsafe { Term::new(map_env, inner) }),
Some(inner) => Ok(unsafe { Term::new(env, inner) }),
None => Err(Error::BadArg),
}
}
Expand Down Expand Up @@ -234,17 +218,16 @@ where
T: Decoder<'a>,
{
fn decode(term: Term<'a>) -> NifResult<Self> {
let env = term.get_env();
let name = term.map_get(atom::__struct__().to_term(env))?;
let name = term.map_get(atom::__struct__())?;

match name.atom_to_string()?.as_ref() {
"Elixir.Range" => (),
_ => return Err(Error::BadArg),
}

let first = term.map_get(atom::first().to_term(env))?.decode::<T>()?;
let last = term.map_get(atom::last().to_term(env))?.decode::<T>()?;
if let Ok(step) = term.map_get(atom::step().to_term(env)) {
let first = term.map_get(atom::first())?.decode::<T>()?;
let last = term.map_get(atom::last())?.decode::<T>()?;
if let Ok(step) = term.map_get(atom::step()) {
match step.decode::<i64>()? {
1 => (),
_ => return Err(Error::BadArg),
Expand Down
5 changes: 1 addition & 4 deletions rustler/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,7 @@ where
V: Encoder,
{
fn encode<'c>(&self, env: Env<'c>) -> Term<'c> {
let (keys, values): (Vec<_>, Vec<_>) = self
.iter()
.map(|(k, v)| (k.encode(env), v.encode(env)))
.unzip();
let (keys, values): (Vec<_>, Vec<_>) = self.iter().unzip();
Term::map_from_arrays(env, &keys, &values).unwrap()
}
}
15 changes: 6 additions & 9 deletions rustler_codegen/src/ex_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let variable = Context::escape_ident_with_index(&ident.to_string(), index, "struct");

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, term, #atom_fun())?;
let #variable = try_decode_field(term, #atom_fun())?;
};

let field_def = quote! {
Expand All @@ -100,18 +100,15 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
use #atoms_module_name::*;
use ::rustler::Encoder;

let env = term.get_env();

fn try_decode_field<'a, T>(
env: rustler::Env<'a>,
term: rustler::Term<'a>,
field: rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: rustler::Decoder<'a>,
{
use rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{}{{}}",
field, #struct_name_str
Expand All @@ -120,7 +117,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
}
};

let module: ::rustler::types::atom::Atom = term.map_get(atom_struct().to_term(env))?.decode()?;
let module: ::rustler::types::atom::Atom = term.map_get(atom_struct())?.decode()?;
if module != atom_module() {
return Err(::rustler::Error::RaiseAtom("invalid_struct"));
}
Expand Down Expand Up @@ -149,14 +146,14 @@ fn gen_encoder(
let field_ident = field.ident.as_ref().unwrap();
let atom_fun = Context::field_to_atom_fun(field);
quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap();
map = map.map_put(#atom_fun(), &self.#field_ident).unwrap();
}
})
.collect();

let exception_field = if add_exception {
quote! {
map = map.map_put(atom_exception().encode(env), true.encode(env)).unwrap();
map = map.map_put(atom_exception(), true).unwrap();
}
} else {
quote! {}
Expand All @@ -167,7 +164,7 @@ fn gen_encoder(
fn encode<'a>(&self, env: ::rustler::Env<'a>) -> ::rustler::Term<'a> {
use #atoms_module_name::*;
let mut map = ::rustler::types::map::map_new(env);
map = map.map_put(atom_struct().encode(env), atom_module().encode(env)).unwrap();
map = map.map_put(atom_struct(), atom_module()).unwrap();
#exception_field
#(#field_defs)*
map
Expand Down
9 changes: 3 additions & 6 deletions rustler_codegen/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let variable = Context::escape_ident_with_index(&ident.to_string(), index, "map");

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, term, #atom_fun())?;
let #variable = try_decode_field(term, #atom_fun())?;
};

let field_def = quote! {
Expand All @@ -81,18 +81,15 @@ fn gen_decoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

let env = term.get_env();

fn try_decode_field<'a, T>(
env: rustler::Env<'a>,
term: rustler::Term<'a>,
field: rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: rustler::Decoder<'a>,
{
use rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{{}}",
field
Expand Down Expand Up @@ -121,7 +118,7 @@ fn gen_encoder(ctx: &Context, fields: &[&Field], atoms_module_name: &Ident) -> T
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), self.#field_ident.encode(env)).unwrap();
map = map.map_put(#atom_fun(), self.#field_ident).unwrap();
}
})
.collect();
Expand Down
26 changes: 13 additions & 13 deletions rustler_codegen/src/tagged_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,17 @@ fn gen_decoder(ctx: &Context, variants: &[&Variant], atoms_module_name: &Ident)
fn decode(term: ::rustler::Term<'a>) -> ::rustler::NifResult<Self> {
use #atoms_module_name::*;

let env = term.get_env();
let value = ::rustler::types::atom::Atom::from_term(term);

fn try_decode_field<'a, T>(
env: ::rustler::Env<'a>,
term: ::rustler::Term<'a>,
field: ::rustler::Atom,
) -> ::rustler::NifResult<T>
where
T: ::rustler::Decoder<'a>,
{
use ::rustler::Encoder;
match ::rustler::Decoder::decode(term.map_get(field.encode(env))?) {
match ::rustler::Decoder::decode(term.map_get(field)?) {
Err(_) => Err(::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field :{:?} on %{{}}",
field
Expand Down Expand Up @@ -244,7 +242,7 @@ fn gen_named_decoder(
let enum_name_string = enum_name.to_string();

let assignment = quote_spanned! { field.span() =>
let #variable = try_decode_field(env, tuple[1], #atom_fun()).map_err(|_|{
let #variable = try_decode_field(tuple[1], #atom_fun()).map_err(|_|{
::rustler::Error::RaiseTerm(Box::new(format!(
"Could not decode field '{}' on Enum '{}'",
#ident_string, #enum_name_string
Expand Down Expand Up @@ -315,22 +313,24 @@ fn gen_named_encoder(
}
})
.collect::<Vec<_>>();
let field_defs = fields.named.iter()
let (keys, values): (Vec<_>, Vec<_>) = fields
.named
.iter()
.map(|field| {
let field_ident = field.ident.as_ref().expect("Named fields must have an ident.");
let field_ident = field
.ident
.as_ref()
.expect("Named fields must have an ident.");
let atom_fun = Context::field_to_atom_fun(field);

quote_spanned! { field.span() =>
map = map.map_put(#atom_fun().encode(env), #field_ident.encode(env)).expect("Failed to putting map");
}
(atom_fun, field_ident)
})
.collect::<Vec<_>>();
.unzip();
quote! {
#enum_name :: #variant_ident{
#(#field_decls)*
} => {
let mut map = ::rustler::types::map::map_new(env);
#(#field_defs)*
let map = ::rustler::Term::map_from_arrays(env, &[#(#keys()),*], &[#(#values),*])
.expect("Failed to create map");
::rustler::types::tuple::make_tuple(env, &[#atom_fn().encode(env), map])
}
}
Expand Down

0 comments on commit adc0b9d

Please sign in to comment.