Skip to content

Commit

Permalink
Add support for .. in let pattern matching for structs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kijewski committed Jun 28, 2024
1 parent 214c445 commit a0a6c44
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 12 deletions.
9 changes: 7 additions & 2 deletions rinja_derive/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1788,8 +1788,8 @@ impl<'a> Generator<'a> {
target: &Target<'a>,
) {
match target {
Target::Name("_") => {
buf.write("_");
Target::Placeholder(s) | Target::Rest(s) => {
buf.write(s);
}
Target::Name(name) => {
let name = normalize_identifier(name);
Expand Down Expand Up @@ -1824,6 +1824,11 @@ impl<'a> Generator<'a> {
buf.write(SeparatedPath(path));
buf.write(" { ");
for (name, target) in targets {
if let Target::Rest(s) = target {
buf.write(s);
continue;
}

buf.write(normalize_identifier(name));
buf.write(": ");
self.visit_target(buf, initialized, false, target);
Expand Down
69 changes: 59 additions & 10 deletions rinja_parser/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::str;

use nom::branch::alt;
use nom::bytes::complete::{tag, take_till};
use nom::character::complete::char;
use nom::character::complete::{char, one_of};
use nom::combinator::{
complete, consumed, cut, eof, map, map_res, not, opt, peek, recognize, value,
};
Expand All @@ -11,7 +11,7 @@ use nom::error_position;
use nom::multi::{fold_many0, many0, many1, separated_list0, separated_list1};
use nom::sequence::{delimited, pair, preceded, terminated, tuple};

use crate::{not_ws, ErrorContext, ParseResult, WithSpan};
use crate::{not_ws, ErrorContext, ParseErr, ParseResult, WithSpan};

use super::{
bool_lit, char_lit, filter, identifier, is_ws, keyword, num_lit, path_or_identifier, skip_till,
Expand Down Expand Up @@ -188,6 +188,8 @@ pub enum Target<'a> {
BoolLit(&'a str),
Path(Vec<&'a str>),
OrChain(Vec<Target<'a>>),
Placeholder(&'a str),
Rest(&'a str),
}

impl<'a> Target<'a> {
Expand Down Expand Up @@ -221,7 +223,7 @@ impl<'a> Target<'a> {
return Ok((i, Self::Tuple(Vec::new(), Vec::new())));
}

let (i, first_target) = Self::parse(i, s)?;
let (i, first_target) = Self::unnamed(i, s)?;
let (i, is_unused_paren) = opt_closing_paren(i)?;
if is_unused_paren {
return Ok((i, first_target));
Expand All @@ -230,7 +232,7 @@ impl<'a> Target<'a> {
let mut targets = vec![first_target];
let (i, _) = cut(tuple((
fold_many0(
preceded(ws(char(',')), |i| Self::parse(i, s)),
preceded(ws(char(',')), |i| Self::unnamed(i, s)),
|| (),
|_, target| {
targets.push(target);
Expand All @@ -239,7 +241,7 @@ impl<'a> Target<'a> {
opt(ws(char(','))),
ws(cut(char(')'))),
)))(i)?;
return Ok((i, Self::Tuple(Vec::new(), targets)));
return Ok((i, Self::Tuple(Vec::new(), only_one_rest_pattern(targets)?)));
}

let path = |i| {
Expand All @@ -260,11 +262,11 @@ impl<'a> Target<'a> {
let (i, targets) = alt((
map(char(')'), |_| Vec::new()),
terminated(
cut(separated_list1(ws(char(',')), |i| Self::parse(i, s))),
cut(separated_list1(ws(char(',')), |i| Self::unnamed(i, s))),
pair(opt(ws(char(','))), ws(cut(char(')')))),
),
))(i)?;
return Ok((i, Self::Tuple(path, targets)));
return Ok((i, Self::Tuple(path, only_one_rest_pattern(targets)?)));
}

let (i, is_named_struct) = opt_opening_brace(i)?;
Expand All @@ -284,7 +286,11 @@ impl<'a> Target<'a> {

// neither literal nor struct nor path
let (new_i, name) = identifier(i)?;
Ok((new_i, Self::verify_name(i, name)?))
let target = match name {
"_" => Self::Placeholder(name),
_ => Self::verify_name(i, name)?,
};
Ok((new_i, target))
}

fn lit(i: &'a str) -> ParseResult<'a, Self> {
Expand All @@ -296,20 +302,46 @@ impl<'a> Target<'a> {
))(i)
}

fn unnamed(i: &'a str, s: &State<'_>) -> ParseResult<'a, Self> {
alt((Self::rest, |i| Self::parse(i, s)))(i)
}

fn named(init_i: &'a str, s: &State<'_>) -> ParseResult<'a, (&'a str, Self)> {
let (i, rest) = opt(consumed(Self::rest))(init_i)?;
if let Some(rest) = rest {
let (_, chr) = ws(opt(one_of(",:")))(i)?;
if let Some(chr) = chr {
return Err(nom::Err::Failure(ErrorContext::new(
format!("unexpected `{chr}` character after `..`"),
i,
)));
}
return Ok((i, rest));
}

let (i, (src, target)) = pair(
identifier,
opt(preceded(ws(char(':')), |i| Self::parse(i, s))),
)(init_i)?;

if src == "_" {
return Err(nom::Err::Failure(ErrorContext::new(
"cannot use placeholder `_` as source in named struct",
init_i,
)));
}

let target = match target {
Some(target) => target,
None => Self::verify_name(init_i, src)?,
};

Ok((i, (src, target)))
}

fn rest(i: &'a str) -> ParseResult<'a, Self> {
map(tag(".."), Self::Rest)(i)
}

fn verify_name(input: &'a str, name: &'a str) -> Result<Self, nom::Err<ErrorContext<'a>>> {
match name {
"self" | "writer" => Err(nom::Err::Failure(ErrorContext::new(
Expand All @@ -321,6 +353,23 @@ impl<'a> Target<'a> {
}
}

fn only_one_rest_pattern<'a>(targets: Vec<Target<'a>>) -> Result<Vec<Target<'a>>, ParseErr<'a>> {
let snd_wildcard = targets
.iter()
.filter_map(|t| match t {
Target::Rest(s) => Some(s),
_ => None,
})
.nth(1);
if let Some(snd_wildcard) = snd_wildcard {
return Err(nom::Err::Failure(ErrorContext::new(
"`..` can only be used once per tuple pattern",
snd_wildcard,
)));
}
Ok(targets)
}

#[derive(Debug, PartialEq)]
pub struct When<'a> {
pub ws: Ws,
Expand All @@ -347,7 +396,7 @@ impl<'a> When<'a> {
WithSpan::new(
Self {
ws: Ws(pws, nws),
target: Target::Name("_"),
target: Target::Placeholder("_"),
nodes,
},
start,
Expand Down
78 changes: 78 additions & 0 deletions testing/tests/rest_pattern.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use rinja::Template;

#[test]
fn a() {
#[derive(Template)]
#[template(source = "{% if let (a, ..) = abc %}-{{a}}-{% endif %}", ext = "txt")]
struct Tmpl {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl { abc: (1, 2, 3) }.to_string(), "-1-");
}

#[test]
fn ab() {
#[derive(Template)]
#[template(
source = "{% if let (a, b, ..) = abc %}-{{a}}{{b}}-{% endif %}",
ext = "txt"
)]
struct Tmpl {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl { abc: (1, 2, 3) }.to_string(), "-12-");
}

#[test]
fn abc() {
#[derive(Template)]
#[template(
source = "{% if let (a, b, c, ..) = abc %}-{{a}}{{b}}{{c}}-{% endif %}",
ext = "txt"
)]
struct Tmpl1 {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl1 { abc: (1, 2, 3) }.to_string(), "-123-");

assert_eq!(Tmpl2 { abc: (1, 2, 3) }.to_string(), "-123-");

#[derive(Template)]
#[template(
source = "{% if let (a, b, c, ..) = abc %}-{{a}}{{b}}{{c}}-{% endif %}",
ext = "txt"
)]
struct Tmpl2 {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl2 { abc: (1, 2, 3) }.to_string(), "-123-");
}

#[test]
fn bc() {
#[derive(Template)]
#[template(
source = "{% if let (.., b, c) = abc %}-{{b}}{{c}}-{% endif %}",
ext = "txt"
)]
struct Tmpl {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl { abc: (1, 2, 3) }.to_string(), "-23-");
}

#[test]
fn c() {
#[derive(Template)]
#[template(source = "{% if let (.., c) = abc %}-{{c}}-{% endif %}", ext = "txt")]
struct Tmpl {
abc: (u32, u32, u32),
}

assert_eq!(Tmpl { abc: (1, 2, 3) }.to_string(), "-3-");
}

0 comments on commit a0a6c44

Please sign in to comment.