From 4bd0d489071eb087a620bb7e9696ec17a8fb336d Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Thu, 17 Oct 2024 16:57:58 -0400 Subject: [PATCH] perf(4D-fake-GLV/native): simpler scalar decomposition check --- std/algebra/native/sw_bls12377/g1.go | 87 +++++++++---------------- std/algebra/native/sw_bls12377/hints.go | 66 ++++++++----------- 2 files changed, 57 insertions(+), 96 deletions(-) diff --git a/std/algebra/native/sw_bls12377/g1.go b/std/algebra/native/sw_bls12377/g1.go index c16d0195c..049a9819e 100644 --- a/std/algebra/native/sw_bls12377/g1.go +++ b/std/algebra/native/sw_bls12377/g1.go @@ -709,81 +709,52 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte // This can be done through a hinted half-GCD in the number field // K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of // integers i.e. w is a primitive cube root of unity, f(w)=w^2+w+1=0. - sd, err := api.NewHint(halfGCDEisenstein, 10, _s, cc.lambda) + // + // The hint returns u1, u2, v1, v2 and the quotient q. + // In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) + r*q + sd, err := api.NewHint(halfGCDEisenstein, 5, _s, cc.lambda) if err != nil { panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err)) } - u1, u2, v1, v2, w1, w2, _, _, s1, s2 := sd[0], sd[1], sd[2], sd[3], sd[4], sd[5], sd[6], sd[7], sd[8], sd[9] - // r is fixed and is equal to 91893752504881257701523279626832445440 - ω - var r1 big.Int - r1.SetString("91893752504881257701523279626832445440", 10) + u1, u2, v1, v2, q := sd[0], sd[1], sd[2], sd[3], sd[4] // Eisenstein integers real and imaginary parts can be negative. So we // return the absolute value in the hint and negate the corresponding // points here when needed. - signs, err := api.NewHint(halfGCDEisensteinSigns, 6, _s, cc.lambda) + signs, err := api.NewHint(halfGCDEisensteinSigns, 5, _s, cc.lambda) if err != nil { panic(fmt.Sprintf("halfGCDEisensteinSigns hint: %v", err)) } - selector1, selector2, selector3, selector4, selector5, selector6 := signs[0], signs[1], signs[2], signs[3], signs[4], signs[5] + isNegu1, isNegu2, isNegv1, isNegv2, isNegq := signs[0], signs[1], signs[2], signs[3], signs[4] // We need to check that: - // (s1 + j*s2)(v1 + j*v2) + (r1 + j*r2)(w1 + j*w2) - (u1 + j*u2) = 0 - // which is equivalent to checking: - // s1*v1 + r1*w1 = s2*v2 + r2*w2 + u1 and - // s1*v2 + s2*v1 + r1*w2 + r2*w1 = s2*v2 + r2*w2 + u2 - // or that: - // s1*v1 + r1*w1 + u2 = s1*v2 + s2*v1 + r1*w2 + r2*w1 + u1 - // - // Since all these values can be negative, we gather all positive values - // either in the lhs or rhs and check equality. - s1v1 := api.Mul(s1, v1) - r1w1 := api.Mul(r1, w1) - s1v2 := api.Mul(s1, v2) - s2v1 := api.Mul(s2, v1) - r1w2 := api.Mul(r1, w2) - - lhs1 := api.Select(selector3, s1v1, 0) - lhs2 := api.Select(selector5, 0, r1w1) - lhs3 := api.Select(selector4, 0, s1v2) - lhs4 := api.Select(selector3, 0, s2v1) - lhs5 := api.Select(selector6, r1w2, 0) - lhs6 := api.Select(selector5, 0, w1) - lhs7 := api.Select(selector1, u1, 0) - lhs8 := api.Select(selector2, 0, u2) + // s*(v1 + λ*v2) + u1 + λ*u2 - r * q = 0 + sv1 := api.Mul(_s, v1) + sλv2 := api.Mul(_s, api.Mul(cc.lambda, v2)) + λu2 := api.Mul(cc.lambda, u2) + rq := api.Mul(cc.fr, q) + + lhs1 := api.Select(isNegv1, 0, sv1) + lhs2 := api.Select(isNegv2, 0, sλv2) + lhs3 := api.Select(isNegu1, 0, u1) + lhs4 := api.Select(isNegu2, 0, λu2) + lhs5 := api.Select(isNegq, rq, 0) lhs := api.Add( api.Add(lhs1, lhs2), api.Add(lhs3, lhs4), ) - lhs = api.Add( - lhs, - api.Add(lhs5, lhs6), - ) - lhs = api.Add( - lhs, - api.Add(lhs7, lhs8), - ) + lhs = api.Add(lhs, lhs5) - rhs1 := api.Select(selector3, 0, s1v1) - rhs2 := api.Select(selector5, r1w1, 0) - rhs3 := api.Select(selector4, s1v2, 0) - rhs4 := api.Select(selector3, s2v1, 0) - rhs5 := api.Select(selector6, 0, r1w2) - rhs6 := api.Select(selector5, w1, 0) - rhs7 := api.Select(selector1, 0, u1) - rhs8 := api.Select(selector2, u2, 0) + rhs1 := api.Select(isNegv1, sv1, 0) + rhs2 := api.Select(isNegv2, sλv2, 0) + rhs3 := api.Select(isNegu1, u1, 0) + rhs4 := api.Select(isNegu2, λu2, 0) + rhs5 := api.Select(isNegq, 0, rq) rhs := api.Add( api.Add(rhs1, rhs2), api.Add(rhs3, rhs4), ) - rhs = api.Add( - rhs, - api.Add(rhs5, rhs6), - ) - rhs = api.Add( - rhs, - api.Add(rhs7, rhs8), - ) + rhs = api.Add(rhs, rhs5) api.AssertIsEqual(lhs, rhs) @@ -810,12 +781,12 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte negPY := api.Neg(_P.Y) tableP[1] = G1Affine{ X: _P.X, - Y: api.Select(selector1, negPY, _P.Y), + Y: api.Select(isNegu1, negPY, _P.Y), } tableP[0].Neg(api, tableP[1]) tablePhiP[1] = G1Affine{ X: api.Mul(_P.X, cc.thirdRootOne1), - Y: api.Select(selector2, negPY, _P.Y), + Y: api.Select(isNegu2, negPY, _P.Y), } tablePhiP[0].Neg(api, tablePhiP[1]) @@ -824,12 +795,12 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte negQY := api.Neg(Q.Y) tableQ[1] = G1Affine{ X: Q.X, - Y: api.Select(selector3, negQY, Q.Y), + Y: api.Select(isNegv1, negQY, Q.Y), } tableQ[0].Neg(api, tableQ[1]) tablePhiQ[1] = G1Affine{ X: api.Mul(Q.X, cc.thirdRootOne1), - Y: api.Select(selector4, negQY, Q.Y), + Y: api.Select(isNegv2, negQY, Q.Y), } tablePhiQ[0].Neg(api, tablePhiQ[1]) diff --git a/std/algebra/native/sw_bls12377/hints.go b/std/algebra/native/sw_bls12377/hints.go index da3f71a05..9c0c86333 100644 --- a/std/algebra/native/sw_bls12377/hints.go +++ b/std/algebra/native/sw_bls12377/hints.go @@ -112,24 +112,23 @@ func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big. return nil } -func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { +func halfGCDEisenstein(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error { if len(inputs) != 2 { return fmt.Errorf("expecting two input") } - if len(outputs) != 10 { - return fmt.Errorf("expecting ten outputs") + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") } + cc := getInnerCurveConfig(scalarField) glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(mod, inputs[1], glvBasis) + ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) r := eisenstein.ComplexNumber{ A0: &glvBasis.V1[0], A1: &glvBasis.V1[1], } - // r = 91893752504881257701523279626832445440 - ω sp := ecc.SplitScalar(inputs[0], glvBasis) // in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0 // so here we return -s instead of s. - // s.A0 and s.A1 are always positive. s := eisenstein.ComplexNumber{ A0: &sp[0], A1: &sp[1], @@ -140,12 +139,14 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro outputs[1].Set(res[0].A1) outputs[2].Set(res[1].A0) outputs[3].Set(res[1].A1) - outputs[4].Set(res[2].A0) - outputs[5].Set(res[2].A1) - outputs[6].Set(r.A0) - outputs[7].Set(r.A1) - outputs[8].Set(s.A0) - outputs[9].Set(s.A1) + outputs[4].Mul(res[1].A1, inputs[1]). + Add(outputs[4], res[1].A0). + Mul(outputs[4], inputs[0]). + Add(outputs[4], res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + outputs[4].Add(outputs[4], s.A0). + Div(outputs[4], cc.fr) + if outputs[0].Sign() == -1 { outputs[0].Neg(outputs[0]) } @@ -161,34 +162,20 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro if outputs[4].Sign() == -1 { outputs[4].Neg(outputs[4]) } - if outputs[5].Sign() == -1 { - outputs[5].Neg(outputs[5]) - } - if outputs[6].Sign() == -1 { - outputs[6].Neg(outputs[6]) - } - if outputs[7].Sign() == -1 { - outputs[7].Neg(outputs[7]) - } - if outputs[8].Sign() == -1 { - outputs[8].Neg(outputs[8]) - } - if outputs[9].Sign() == -1 { - outputs[9].Neg(outputs[9]) - } + return nil } -func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error { +func halfGCDEisensteinSigns(scalarField *big.Int, inputs, outputs []*big.Int) error { if len(inputs) != 2 { return fmt.Errorf("expecting two input") } - if len(outputs) != 6 { - return fmt.Errorf("expecting six outputs") + if len(outputs) != 5 { + return fmt.Errorf("expecting five outputs") } + cc := getInnerCurveConfig(scalarField) glvBasis := new(ecc.Lattice) - ecc.PrecomputeLattice(mod, inputs[1], glvBasis) - // r = 91893752504881257701523279626832445440 - ω + ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis) r := eisenstein.ComplexNumber{ A0: &glvBasis.V1[0], A1: &glvBasis.V1[1], @@ -207,8 +194,15 @@ func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error { outputs[2].SetUint64(0) outputs[3].SetUint64(0) outputs[4].SetUint64(0) - outputs[5].SetUint64(0) res := eisenstein.HalfGCD(&r, &s) + s.A1.Mul(res[1].A1, inputs[1]). + Add(s.A1, res[1].A0). + Mul(s.A1, inputs[0]). + Add(s.A1, res[0].A0) + s.A0.Mul(res[0].A1, inputs[1]) + s.A1.Add(s.A1, s.A0). + Div(s.A1, cc.fr) + if res[0].A0.Sign() == -1 { outputs[0].SetUint64(1) } @@ -221,12 +215,8 @@ func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error { if res[1].A1.Sign() == -1 { outputs[3].SetUint64(1) } - if res[2].A0.Sign() == -1 { + if s.A1.Sign() == -1 { outputs[4].SetUint64(1) } - if res[2].A1.Sign() == -1 { - outputs[5].SetUint64(1) - } - return nil }