From 80373d612c023e3e165b49b6d1729486b0ba3b4b Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 13 Mar 2024 17:04:16 +0000 Subject: [PATCH] feat: optimize sha2 implementation (#4441) # Description ## Problem\* Resolves ## Summary\* We're currently performing byte decompositions in the sha2 functions through repeated division. This would be more efficient if we just did the full byte decomposition at once and then iterate through the results. I've also removed some noop casts. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- noir_stdlib/src/sha256.nr | 22 ++++++++++++---------- noir_stdlib/src/sha512.nr | 22 ++++++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/noir_stdlib/src/sha256.nr b/noir_stdlib/src/sha256.nr index 2f686a64165..8ca6808568d 100644 --- a/noir_stdlib/src/sha256.nr +++ b/noir_stdlib/src/sha256.nr @@ -6,9 +6,11 @@ fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] { let mut msg32: [u32; 16] = [0; 16]; for i in 0..16 { + let mut msg_field: Field = 0; for j in 0..4 { - msg32[15 - i] = (msg32[15 - i] << 8) + msg[64 - 4*(i + 1) + j] as u32; + msg_field = msg_field * 256 + msg[64 - 4*(i + 1) + j] as Field; } + msg32[15 - i] = msg_field as u32; } msg32 @@ -21,7 +23,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { let mut i: u64 = 0; // Message byte pointer for k in 0..N { // Populate msg_block - msg_block[i as Field] = msg[k]; + msg_block[i] = msg[k]; i = i + 1; if i == 64 { // Enough to hash block @@ -32,7 +34,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { } // Pad the rest such that we have a [u32; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i as Field] = 1 << 7; + msg_block[i] = 1 << 7; i = i + 1; // If i >= 57, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. @@ -41,7 +43,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { if i < 64 { for _i in 57..64 { if i <= 63 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } } @@ -51,16 +53,16 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { i = 0; } + let len = 8 * msg.len(); + let len_bytes = (len as Field).to_le_bytes(8); for _i in 0..64 { // In any case, fill blocks up with zeros until the last 64 (i.e. until i = 56). if i < 56 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i = i + 1; } else if i < 64 { - let mut len = 8 * msg.len(); for j in 0..8 { - msg_block[63 - j] = len as u8; - len >>= 8; + msg_block[63 - j] = len_bytes[j]; } i += 8; } @@ -70,9 +72,9 @@ pub fn digest(msg: [u8; N]) -> [u8; 32] { // Return final hash as byte array for j in 0..8 { + let h_bytes = (h[7 - j] as Field).to_le_bytes(4); for k in 0..4 { - out_h[31 - 4*j - k] = h[7 - j] as u8; - h[7-j] >>= 8; + out_h[31 - 4*j - k] = h_bytes[k]; } } diff --git a/noir_stdlib/src/sha512.nr b/noir_stdlib/src/sha512.nr index 4dfe78308e2..a766ae50d55 100644 --- a/noir_stdlib/src/sha512.nr +++ b/noir_stdlib/src/sha512.nr @@ -77,9 +77,11 @@ fn msg_u8_to_u64(msg: [u8; 128]) -> [u64; 16] { let mut msg64: [u64; 16] = [0; 16]; for i in 0..16 { + let mut msg_field: Field = 0; for j in 0..8 { - msg64[15 - i] = (msg64[15 - i] << 8) + msg[128 - 8*(i + 1) + j] as u64; + msg_field = msg_field * 256 + msg[128 - 8*(i + 1) + j] as Field; } + msg64[15 - i] = msg_field as u64; } msg64 @@ -94,7 +96,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { let mut i: u64 = 0; // Message byte pointer for k in 0..msg.len() { // Populate msg_block - msg_block[i as Field] = msg[k]; + msg_block[i] = msg[k]; i = i + 1; if i == 128 { // Enough to hash block @@ -108,7 +110,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { } // Pad the rest such that we have a [u64; 2] block at the end representing the length // of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]). - msg_block[i as Field] = 1 << 7; + msg_block[i] = 1 << 7; i += 1; // If i >= 113, there aren't enough bits in the current message block to accomplish this, so // the 1 and 0s fill up the current block, which we then compress accordingly. @@ -117,7 +119,7 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { if i < 128 { for _i in 113..128 { if i <= 127 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } } @@ -130,16 +132,16 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { i = 0; } + let len = 8 * msg.len(); + let len_bytes = (len as Field).to_le_bytes(16); for _i in 0..128 { // In any case, fill blocks up with zeros until the last 128 (i.e. until i = 112). if i < 112 { - msg_block[i as Field] = 0; + msg_block[i] = 0; i += 1; } else if i < 128 { - let mut len = 8 * msg.len(); for j in 0..16 { - msg_block[127 - j] = len as u8; - len >>= 8; + msg_block[127 - j] = len_bytes[j]; } i += 16; // Done. } @@ -151,9 +153,9 @@ pub fn digest(msg: [u8; N]) -> [u8; 64] { } // Return final hash as byte array for j in 0..8 { + let h_bytes = (h[7 - j] as Field).to_le_bytes(8); for k in 0..8 { - out_h[63 - 8*j - k] = h[7 - j] as u8; - h[7-j] >>= 8; + out_h[63 - 8*j - k] = h_bytes[k]; } }