Skip to content

Commit

Permalink
Merge pull request #92 from Kijewski/pr-escape-at-compile-time
Browse files Browse the repository at this point in the history
derive: escape strings at compile-time when possible
  • Loading branch information
GuillaumeGomez authored Jul 28, 2024
2 parents ab485df + 7ea3484 commit 7408a98
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 74 deletions.
68 changes: 6 additions & 62 deletions rinja/src/filters/escape.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::convert::Infallible;
use std::fmt::{self, Display, Formatter, Write};
use std::num::NonZeroU8;
use std::{borrow, str};

/// Marks a string (or other `Display` type) as safe
Expand Down Expand Up @@ -83,69 +82,14 @@ pub fn e(text: impl fmt::Display, escaper: impl Escaper) -> Result<Safe<impl Dis
pub struct Html;

impl Escaper for Html {
fn write_escaped_str<W: Write>(&self, mut fmt: W, string: &str) -> fmt::Result {
let mut escaped_buf = *b"&#__;";
let mut last = 0;

for (index, byte) in string.bytes().enumerate() {
const MIN_CHAR: u8 = b'"';
const MAX_CHAR: u8 = b'>';

struct Table {
_align: [usize; 0],
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
}

const TABLE: Table = {
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
let n0 = match NonZeroU8::new(c / 10 + b'0') {
Some(n) => n,
None => panic!(),
};
let n1 = match NonZeroU8::new(c % 10 + b'0') {
Some(n) => n,
None => panic!(),
};
Some([n0, n1])
}

let mut table = Table {
_align: [],
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
};

table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
table
};

let escaped = match byte {
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
_ => None,
};
if let Some(escaped) = escaped {
escaped_buf[2] = escaped[0].get();
escaped_buf[3] = escaped[1].get();
fmt.write_str(&string[last..index])?;
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
last = index + 1;
}
}
fmt.write_str(&string[last..])
#[inline]
fn write_escaped_str<W: Write>(&self, fmt: W, string: &str) -> fmt::Result {
crate::html::write_escaped_str(fmt, string)
}

fn write_escaped_char<W: Write>(&self, mut fmt: W, c: char) -> fmt::Result {
fmt.write_str(match (c.is_ascii(), c as u8) {
(true, b'"') => "&#34;",
(true, b'&') => "&#38;",
(true, b'\'') => "&#39;",
(true, b'<') => "&#60;",
(true, b'>') => "&#62;",
_ => return fmt.write_char(c),
})
#[inline]
fn write_escaped_char<W: Write>(&self, fmt: W, c: char) -> fmt::Result {
crate::html::write_escaped_char(fmt, c)
}
}

Expand Down
71 changes: 71 additions & 0 deletions rinja/src/html.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::fmt;
use std::num::NonZeroU8;

#[allow(unused)]
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
let mut escaped_buf = *b"&#__;";
let mut last = 0;

for (index, byte) in string.bytes().enumerate() {
let escaped = match byte {
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
_ => None,
};
if let Some(escaped) = escaped {
escaped_buf[2] = escaped[0].get();
escaped_buf[3] = escaped[1].get();
fmt.write_str(&string[last..index])?;
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
last = index + 1;
}
}
fmt.write_str(&string[last..])
}

#[allow(unused)]
pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
fmt.write_str(match (c.is_ascii(), c as u8) {
(true, b'"') => "&#34;",
(true, b'&') => "&#38;",
(true, b'\'') => "&#39;",
(true, b'<') => "&#60;",
(true, b'>') => "&#62;",
_ => return fmt.write_char(c),
})
}

const MIN_CHAR: u8 = b'"';
const MAX_CHAR: u8 = b'>';

struct Table {
_align: [usize; 0],
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
}

const TABLE: Table = {
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
assert!(MIN_CHAR <= c && c <= MAX_CHAR);

let n0 = match NonZeroU8::new(c / 10 + b'0') {
Some(n) => n,
None => panic!(),
};
let n1 = match NonZeroU8::new(c % 10 + b'0') {
Some(n) => n,
None => panic!(),
};
Some([n0, n1])
}

let mut table = Table {
_align: [],
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
};

table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
table
};
1 change: 1 addition & 0 deletions rinja/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
mod error;
pub mod filters;
pub mod helpers;
mod html;

use std::{fmt, io};

Expand Down
97 changes: 85 additions & 12 deletions rinja_derive/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use quote::quote;

use crate::config::WhitespaceHandling;
use crate::heritage::{Context, Heritage};
use crate::html::write_escaped_str;
use crate::input::{Source, TemplateInput};
use crate::{CompileError, MsgValidEscapers, CRATE};

Expand Down Expand Up @@ -1162,8 +1163,76 @@ impl<'a> Generator<'a> {
}

fn write_expr(&mut self, ws: Ws, s: &'a WithSpan<'a, Expr<'a>>) {
// In here, we inspect in the expression if it is a literal, and if it is, whether it
// can be escaped at compile time. We use an IIFE to make the code more readable
// (immediate returns, try expressions).
let writable = (|| -> Option<Writable<'a>> {
enum InputKind<'a> {
StrLit(&'a str),
CharLit(&'a str),
}
enum OutputKind {
Html,
Text,
}

// for now, we only escape strings and chars at compile time
let lit = match &**s {
Expr::StrLit(input) => InputKind::StrLit(input),
Expr::CharLit(input) => InputKind::CharLit(input),
_ => return None,
};

// we only optimize for known escapers
let output = match self.input.escaper.strip_prefix(CRATE)? {
"::filters::Html" => OutputKind::Html,
"::filters::Text" => OutputKind::Text,
_ => return None,
};

// the input could be string escaped if it contains any backslashes
let escaped = match lit {
InputKind::StrLit(s) => s,
InputKind::CharLit(s) => s,
};
let unescaped = if escaped.find('\\').is_none() {
// if the literal does not contain any backslashes, then it does not need unescaping
Cow::Borrowed(escaped)
} else {
// convert the input into a TokenStream and extract the first token
Cow::Owned(match lit {
InputKind::StrLit(escaped) => {
let input = format!(r#""{escaped}""#);
let input = input.parse().ok()?;
let input = syn::parse2::<syn::LitStr>(input).ok()?;
input.value()
}
InputKind::CharLit(escaped) => {
let input = format!(r#"'{escaped}'"#);
let input = input.parse().ok()?;
let input = syn::parse2::<syn::LitChar>(input).ok()?;
input.value().to_string()
}
})
};

// escape the un-string-escaped input using the selected escaper
Some(Writable::Lit(match output {
OutputKind::Text => unescaped,
OutputKind::Html => {
let mut escaped = String::with_capacity(unescaped.len() + 20);
write_escaped_str(&mut escaped, &unescaped).ok()?;
match escaped == unescaped {
true => unescaped,
false => Cow::Owned(escaped),
}
}
}))
})()
.unwrap_or(Writable::Expr(s));

self.handle_ws(ws);
self.buf_writable.push(Writable::Expr(s));
self.buf_writable.push(writable);
}

// Write expression buffer and empty
Expand All @@ -1174,7 +1243,7 @@ impl<'a> Generator<'a> {
) -> Result<usize, CompileError> {
let mut size_hint = 0;
let items = mem::take(&mut self.buf_writable.buf);
let mut it = items.into_iter().enumerate().peekable();
let mut it = items.iter().enumerate().peekable();

while let Some((_, Writable::Lit(s))) = it.peek() {
size_hint += buf.write_writer(s);
Expand Down Expand Up @@ -1267,20 +1336,23 @@ impl<'a> Generator<'a> {
assert!(rws.is_empty());
self.next_ws = Some(lws);
}
WhitespaceHandling::Preserve => self.buf_writable.push(Writable::Lit(lws)),
WhitespaceHandling::Preserve => {
self.buf_writable.push(Writable::Lit(Cow::Borrowed(lws)))
}
WhitespaceHandling::Minimize => {
self.buf_writable
.push(Writable::Lit(match lws.contains('\n') {
self.buf_writable.push(Writable::Lit(Cow::Borrowed(
match lws.contains('\n') {
true => "\n",
false => " ",
}));
},
)));
}
}
}

if !val.is_empty() {
self.skip_ws = WhitespaceHandling::Preserve;
self.buf_writable.push(Writable::Lit(val));
self.buf_writable.push(Writable::Lit(Cow::Borrowed(val)));
}

if !rws.is_empty() {
Expand Down Expand Up @@ -2031,17 +2103,18 @@ impl<'a> Generator<'a> {
WhitespaceHandling::Preserve => {
let val = self.next_ws.unwrap();
if !val.is_empty() {
self.buf_writable.push(Writable::Lit(val));
self.buf_writable.push(Writable::Lit(Cow::Borrowed(val)));
}
}
WhitespaceHandling::Minimize => {
let val = self.next_ws.unwrap();
if !val.is_empty() {
self.buf_writable
.push(Writable::Lit(match val.contains('\n') {
self.buf_writable.push(Writable::Lit(Cow::Borrowed(
match val.contains('\n') {
true => "\n",
false => " ",
}));
},
)));
}
}
WhitespaceHandling::Suppress => {}
Expand Down Expand Up @@ -2481,7 +2554,7 @@ impl<'a> Deref for WritableBuffer<'a> {

#[derive(Debug)]
enum Writable<'a> {
Lit(&'a str),
Lit(Cow<'a, str>),
Expr(&'a WithSpan<'a, Expr<'a>>),
}

Expand Down
1 change: 1 addition & 0 deletions rinja_derive/src/html.rs
1 change: 1 addition & 0 deletions rinja_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
mod config;
mod generator;
mod heritage;
mod html;
mod input;
#[cfg(test)]
mod tests;
Expand Down
35 changes: 35 additions & 0 deletions rinja_derive/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,38 @@ fn check_bool_conditions() {
3,
);
}

#[test]
fn check_escaping_at_compile_time() {
compare(
r#"The card is
{%- match suit %}
{%- when Suit::Clubs or Suit::Spades -%}
{{ " black" }}
{%- when Suit::Diamonds or Suit::Hearts -%}
{{ " red" }}
{%- endmatch %}"#,
r#"writer.write_str("The card is")?;
match &self.suit {
Suit::Clubs | Suit::Spades => {
writer.write_str(" black")?;
}
Suit::Diamonds | Suit::Hearts => {
writer.write_str(" red")?;
}
}"#,
&[("suit", "Suit")],
16,
);

compare(
r#"{{ '\x41' }}{{ '\n' }}{{ '\r' }}{{ '\t' }}{{ '\\' }}{{ '\u{2665}' }}{{ '\'' }}{{ '\"' }}{{ '"' }}
{{ "\x41\n\r\t\\\u{2665}\'\"'" }}"#,
r#"writer.write_str("A
\r \\♥'\"\"
A
\r \\♥'\"'")?;"#,
&[],
23,
);
}

0 comments on commit 7408a98

Please sign in to comment.