Skip to content

Commit

Permalink
Move the branches logic out of the write_if method
Browse files Browse the repository at this point in the history
  • Loading branch information
GuillaumeGomez committed Jul 22, 2024
1 parent 43fe75a commit 39e9f21
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 79 deletions.
212 changes: 138 additions & 74 deletions rinja_derive/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,9 @@ impl<'a> Generator<'a> {
Ok(size_hint)
}

fn is_var_defined(&self, expr: &Expr<'_>) -> bool {
match expr {
Expr::Var(s) => {
self.locals.get(&(*s).into()).is_some() || self.input.fields.iter().any(|f| f == s)
}
_ => false,
}
fn is_var_defined(&self, var_name: &str) -> bool {
self.locals.get(&var_name.into()).is_some()
|| self.input.fields.iter().any(|f| f == var_name)
}

fn evaluate_condition(
Expand Down Expand Up @@ -456,74 +452,44 @@ impl<'a> Generator<'a> {
&mut self,
ctx: &Context<'a>,
buf: &mut Buffer,
i: &'a If<'_>,
if_: &'a If<'_>,
) -> Result<usize, CompileError> {
fn write_if_cond(buf: &mut Buffer, nb_written_branches: &mut usize) {
if *nb_written_branches == 0 {
buf.write("if ");
} else {
buf.write("} else if ");
}
*nb_written_branches += 1;
}

let mut flushed = 0;
let mut arm_sizes = Vec::new();
let mut has_else = false;

let mut nb_written_branches = 0;
let mut prev_was_generated = false;
let mut stop_loop = false;
let conds = Conds::compute_branches(self, if_);

if let Some(ws_before) = conds.ws_before {
self.handle_ws(ws_before);
}

for (pos, cond_info) in conds.conds.iter().enumerate() {
// It's completely fine here since we got these indexes by iterator `if_.branches`.
let cond = unsafe { if_.branches.get_unchecked(cond_info.cond_index) };

for cond in i.branches.iter() {
self.handle_ws(cond.ws);
flushed += self.write_buf_writable(ctx, buf)?;
if stop_loop {
break;
}
if prev_was_generated {
if pos > 0 {
self.locals.pop();
}
prev_was_generated = false;
let mut generate_content = true;

self.locals.push();
let mut arm_size = 0;
if let Some(CondTest { target, expr }) = &cond.cond {
let mut only_contains_is_defined = true;
let mut generate_condition = true;

match self.evaluate_condition(expr, &mut only_contains_is_defined) {
// We generate the condition in case some calls are changing a variable, but
// no need to generate the condition body since it will never be called.
//
// However, if the condition only contains "is (not) defined" checks, then we
// can completely skip it.
EvaluatedResult::AlwaysFalse => {
if only_contains_is_defined {
continue;
}
generate_content = false;
write_if_cond(buf, &mut nb_written_branches);
}
// This case is more interesting: it means that we will always enter this
// condition, meaning that any following should not be generated. Another
// thing to take into account: if there are no if branches before this one,
// no need to generate an `else`.
EvaluatedResult::AlwaysTrue => {
stop_loop = true;
if only_contains_is_defined {
generate_condition = false;
if nb_written_branches != 0 {
buf.writeln("} else {");
has_else = true;
}
} else {
write_if_cond(buf, &mut nb_written_branches);
}
if let Some(CondTest { target, expr }) = &cond.cond {
if pos == 0 {
if cond_info.generate_condition {
buf.write("if ");
}
EvaluatedResult::Unknown => write_if_cond(buf, &mut nb_written_branches),
// Otherwise it means it will be the only condition generated,
// so nothing to be added here.
} else if cond_info.generate_condition {
buf.write("} else if ");
} else {
buf.write("} else {");
has_else = true;
}
self.locals.push();

if let Some(target) = target {
let mut expr_buf = Buffer::new();
Expand All @@ -546,7 +512,7 @@ impl<'a> Generator<'a> {
buf.write(" = &");
buf.write(expr_buf.buf);
buf.writeln(" {");
} else if generate_condition {
} else if cond_info.generate_condition {
// The following syntax `*(&(...) as &bool)` is used to
// trigger Rust's automatic dereferencing, to coerce
// e.g. `&&&&&bool` to `bool`. First `&(...) as &bool`
Expand All @@ -556,30 +522,30 @@ impl<'a> Generator<'a> {
buf.write(self.visit_expr_root(ctx, expr)?);
buf.writeln(") as &bool) {");
}
} else {
self.locals.push();
if nb_written_branches > 0 {
buf.writeln("} else {");
}
} else if pos != 0 {
buf.writeln("} else {");
has_else = true;
}

prev_was_generated = true;
if generate_content {
if cond_info.generate_content {
arm_size += self.handle(ctx, &cond.nodes, buf, AstLevel::Nested)?;
}
arm_sizes.push(arm_size);
}
self.handle_ws(i.ws);

if let Some(ws_after) = conds.ws_after {
self.handle_ws(ws_after);
}
self.handle_ws(if_.ws);
flushed += self.write_buf_writable(ctx, buf)?;
if nb_written_branches > 0 {
if conds.nb_conds > 0 {
buf.writeln("}");
}
if prev_was_generated {
if !conds.conds.is_empty() {
self.locals.pop();
}

if !has_else && nb_written_branches > 0 {
if !has_else && !conds.conds.is_empty() {
arm_sizes.push(0);
}
Ok(flushed + median(&mut arm_sizes))
Expand Down Expand Up @@ -1426,7 +1392,7 @@ impl<'a> Generator<'a> {
&mut self,
buf: &mut Buffer,
is_defined: bool,
left: &WithSpan<'_, Expr<'_>>,
left: &str,
) -> Result<DisplayWrap, CompileError> {
match (is_defined, self.is_var_defined(left)) {
(true, true) | (false, false) => buf.write("true"),
Expand Down Expand Up @@ -2187,6 +2153,104 @@ impl BufferFmt for Arguments<'_> {
}
}

struct CondInfo {
cond_index: usize,
generate_condition: bool,
generate_content: bool,
}

struct Conds {
conds: Vec<CondInfo>,
ws_before: Option<Ws>,
ws_after: Option<Ws>,
nb_conds: usize,
}

impl Conds {
fn compute_branches(generator: &Generator<'_>, i: &If<'_>) -> Self {
let mut conds = Vec::with_capacity(i.branches.len());
let mut ws_before = None;
let mut ws_after = None;
let mut nb_conds = 0;
let mut stop_loop = false;

for (i, cond) in i.branches.iter().enumerate() {
if stop_loop {
ws_after = Some(cond.ws);
break;
}
if let Some(CondTest { expr, .. }) = &cond.cond {
let mut only_contains_is_defined = true;

match generator.evaluate_condition(expr, &mut only_contains_is_defined) {
// We generate the condition in case some calls are changing a variable, but
// no need to generate the condition body since it will never be called.
//
// However, if the condition only contains "is (not) defined" checks, then we
// can completely skip it.
EvaluatedResult::AlwaysFalse => {
if only_contains_is_defined {
if conds.is_empty() && ws_before.is_none() {
// If this is the first `if` and it's skipped, we definitely don't
// want its whitespace control to be lost.
ws_before = Some(cond.ws);
}
continue;
}
nb_conds += 1;
conds.push(CondInfo {
cond_index: i,
generate_condition: true,
generate_content: false,
});
}
// This case is more interesting: it means that we will always enter this
// condition, meaning that any following should not be generated. Another
// thing to take into account: if there are no if branches before this one,
// no need to generate an `else`.
EvaluatedResult::AlwaysTrue => {
let generate_condition = !only_contains_is_defined;
if generate_condition {
nb_conds += 1;
}
conds.push(CondInfo {
cond_index: i,
generate_condition,
generate_content: true,
});
// Since it's always true, we can stop here.
stop_loop = true;
}
EvaluatedResult::Unknown => {
nb_conds += 1;
conds.push(CondInfo {
cond_index: i,
generate_condition: true,
generate_content: true,
});
}
}
} else {
let generate_condition = !conds.is_empty();
if generate_condition {
nb_conds += 1;
}
conds.push(CondInfo {
cond_index: i,
generate_condition,
generate_content: true,
});
}
}
Self {
conds,
ws_before,
ws_after,
nb_conds,
}
}
}

struct SeparatedPath<I>(I);

impl<I: IntoIterator<Item = E> + Copy, E: BufferFmt> BufferFmt for SeparatedPath<I> {
Expand Down Expand Up @@ -2365,7 +2429,7 @@ pub(crate) fn is_cacheable(expr: &WithSpan<'_, Expr<'_>>) -> bool {
Expr::Filter(Filter { arguments, .. }) => arguments.iter().all(is_cacheable),
Expr::Unary(_, arg) => is_cacheable(arg),
Expr::BinOp(_, lhs, rhs) => is_cacheable(lhs) && is_cacheable(rhs),
Expr::IsDefined(lhs) | Expr::IsNotDefined(lhs) => is_cacheable(lhs),
Expr::IsDefined(_) | Expr::IsNotDefined(_) => true,
Expr::Range(_, lhs, rhs) => {
lhs.as_ref().map_or(true, |v| is_cacheable(v))
&& rhs.as_ref().map_or(true, |v| is_cacheable(v))
Expand Down
13 changes: 8 additions & 5 deletions rinja_derive/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ struct Foo {{ {} }}"##,
);
let generated = build_template(&syn::parse_str::<syn::DeriveInput>(&jinja).unwrap()).unwrap();

let generated_s = syn::parse_str::<proc_macro2::TokenStream>(&generated)
.unwrap()
.to_string();
let generated_s = match syn::parse_str::<proc_macro2::TokenStream>(&generated) {
Ok(s) => s.to_string(),
Err(e) => {
panic!("===== Invalid generated code =====\n{generated}\n===== ERROR =====\n{e:?}")
}
};
let mut new_expected = String::with_capacity(expected.len());
for line in expected.split('\n') {
new_expected.write_fmt(format_args!("{line}\n")).unwrap();
Expand Down Expand Up @@ -53,8 +56,8 @@ impl ::std::fmt::Display for Foo {{
.to_string();
assert_eq!(
generated_s, expected_s,
"=== Expected ===\n{:#}\n=== Found ===\n{:#}\n=====",
generated, expected
"\n=== Expected ===\n{:#}\n=== Found ===\n{:#}\n=====",
expected, generated
);
}

Expand Down

0 comments on commit 39e9f21

Please sign in to comment.