Skip to content

Commit

Permalink
perf(4D-fake-GLV/native): simpler scalar decomposition check
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Oct 17, 2024
1 parent 4f7abef commit 4bd0d48
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 96 deletions.
87 changes: 29 additions & 58 deletions std/algebra/native/sw_bls12377/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -709,81 +709,52 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte
// This can be done through a hinted half-GCD in the number field
// K=Q[w]/f(w). This corresponds to K being the Eisenstein ring of
// integers i.e. w is a primitive cube root of unity, f(w)=w^2+w+1=0.
sd, err := api.NewHint(halfGCDEisenstein, 10, _s, cc.lambda)
//
// The hint returns u1, u2, v1, v2 and the quotient q.
// In-circuit we check that (v1 + λ*v2)*s = (u1 + λ*u2) + r*q
sd, err := api.NewHint(halfGCDEisenstein, 5, _s, cc.lambda)
if err != nil {
panic(fmt.Sprintf("halfGCDEisenstein hint: %v", err))
}
u1, u2, v1, v2, w1, w2, _, _, s1, s2 := sd[0], sd[1], sd[2], sd[3], sd[4], sd[5], sd[6], sd[7], sd[8], sd[9]
// r is fixed and is equal to 91893752504881257701523279626832445440 - ω
var r1 big.Int
r1.SetString("91893752504881257701523279626832445440", 10)
u1, u2, v1, v2, q := sd[0], sd[1], sd[2], sd[3], sd[4]

// Eisenstein integers real and imaginary parts can be negative. So we
// return the absolute value in the hint and negate the corresponding
// points here when needed.
signs, err := api.NewHint(halfGCDEisensteinSigns, 6, _s, cc.lambda)
signs, err := api.NewHint(halfGCDEisensteinSigns, 5, _s, cc.lambda)
if err != nil {
panic(fmt.Sprintf("halfGCDEisensteinSigns hint: %v", err))
}
selector1, selector2, selector3, selector4, selector5, selector6 := signs[0], signs[1], signs[2], signs[3], signs[4], signs[5]
isNegu1, isNegu2, isNegv1, isNegv2, isNegq := signs[0], signs[1], signs[2], signs[3], signs[4]

// We need to check that:
// (s1 + j*s2)(v1 + j*v2) + (r1 + j*r2)(w1 + j*w2) - (u1 + j*u2) = 0
// which is equivalent to checking:
// s1*v1 + r1*w1 = s2*v2 + r2*w2 + u1 and
// s1*v2 + s2*v1 + r1*w2 + r2*w1 = s2*v2 + r2*w2 + u2
// or that:
// s1*v1 + r1*w1 + u2 = s1*v2 + s2*v1 + r1*w2 + r2*w1 + u1
//
// Since all these values can be negative, we gather all positive values
// either in the lhs or rhs and check equality.
s1v1 := api.Mul(s1, v1)
r1w1 := api.Mul(r1, w1)
s1v2 := api.Mul(s1, v2)
s2v1 := api.Mul(s2, v1)
r1w2 := api.Mul(r1, w2)

lhs1 := api.Select(selector3, s1v1, 0)
lhs2 := api.Select(selector5, 0, r1w1)
lhs3 := api.Select(selector4, 0, s1v2)
lhs4 := api.Select(selector3, 0, s2v1)
lhs5 := api.Select(selector6, r1w2, 0)
lhs6 := api.Select(selector5, 0, w1)
lhs7 := api.Select(selector1, u1, 0)
lhs8 := api.Select(selector2, 0, u2)
// s*(v1 + λ*v2) + u1 + λ*u2 - r * q = 0
sv1 := api.Mul(_s, v1)
sλv2 := api.Mul(_s, api.Mul(cc.lambda, v2))
λu2 := api.Mul(cc.lambda, u2)
rq := api.Mul(cc.fr, q)

lhs1 := api.Select(isNegv1, 0, sv1)
lhs2 := api.Select(isNegv2, 0, sλv2)
lhs3 := api.Select(isNegu1, 0, u1)
lhs4 := api.Select(isNegu2, 0, λu2)
lhs5 := api.Select(isNegq, rq, 0)
lhs := api.Add(
api.Add(lhs1, lhs2),
api.Add(lhs3, lhs4),
)
lhs = api.Add(
lhs,
api.Add(lhs5, lhs6),
)
lhs = api.Add(
lhs,
api.Add(lhs7, lhs8),
)
lhs = api.Add(lhs, lhs5)

rhs1 := api.Select(selector3, 0, s1v1)
rhs2 := api.Select(selector5, r1w1, 0)
rhs3 := api.Select(selector4, s1v2, 0)
rhs4 := api.Select(selector3, s2v1, 0)
rhs5 := api.Select(selector6, 0, r1w2)
rhs6 := api.Select(selector5, w1, 0)
rhs7 := api.Select(selector1, 0, u1)
rhs8 := api.Select(selector2, u2, 0)
rhs1 := api.Select(isNegv1, sv1, 0)
rhs2 := api.Select(isNegv2, sλv2, 0)
rhs3 := api.Select(isNegu1, u1, 0)
rhs4 := api.Select(isNegu2, λu2, 0)
rhs5 := api.Select(isNegq, 0, rq)
rhs := api.Add(
api.Add(rhs1, rhs2),
api.Add(rhs3, rhs4),
)
rhs = api.Add(
rhs,
api.Add(rhs5, rhs6),
)
rhs = api.Add(
rhs,
api.Add(rhs7, rhs8),
)
rhs = api.Add(rhs, rhs5)

api.AssertIsEqual(lhs, rhs)

Expand All @@ -810,12 +781,12 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte
negPY := api.Neg(_P.Y)
tableP[1] = G1Affine{
X: _P.X,
Y: api.Select(selector1, negPY, _P.Y),
Y: api.Select(isNegu1, negPY, _P.Y),
}
tableP[0].Neg(api, tableP[1])
tablePhiP[1] = G1Affine{
X: api.Mul(_P.X, cc.thirdRootOne1),
Y: api.Select(selector2, negPY, _P.Y),
Y: api.Select(isNegu2, negPY, _P.Y),
}
tablePhiP[0].Neg(api, tablePhiP[1])

Expand All @@ -824,12 +795,12 @@ func (R *G1Affine) scalarMulGLVAndFakeGLV(api frontend.API, P G1Affine, s fronte
negQY := api.Neg(Q.Y)
tableQ[1] = G1Affine{
X: Q.X,
Y: api.Select(selector3, negQY, Q.Y),
Y: api.Select(isNegv1, negQY, Q.Y),
}
tableQ[0].Neg(api, tableQ[1])
tablePhiQ[1] = G1Affine{
X: api.Mul(Q.X, cc.thirdRootOne1),
Y: api.Select(selector4, negQY, Q.Y),
Y: api.Select(isNegv2, negQY, Q.Y),
}
tablePhiQ[0].Neg(api, tablePhiQ[1])

Expand Down
66 changes: 28 additions & 38 deletions std/algebra/native/sw_bls12377/hints.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,23 @@ func scalarMulGLVG1Hint(scalarField *big.Int, inputs []*big.Int, outputs []*big.
return nil
}

func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error {
func halfGCDEisenstein(scalarField *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting two input")
}
if len(outputs) != 10 {
return fmt.Errorf("expecting ten outputs")
if len(outputs) != 5 {
return fmt.Errorf("expecting five outputs")
}
cc := getInnerCurveConfig(scalarField)
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(mod, inputs[1], glvBasis)
ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis)
r := eisenstein.ComplexNumber{
A0: &glvBasis.V1[0],
A1: &glvBasis.V1[1],
}
// r = 91893752504881257701523279626832445440 - ω
sp := ecc.SplitScalar(inputs[0], glvBasis)
// in-circuit we check that Q - [s]P = 0 or equivalently Q + [-s]P = 0
// so here we return -s instead of s.
// s.A0 and s.A1 are always positive.
s := eisenstein.ComplexNumber{
A0: &sp[0],
A1: &sp[1],
Expand All @@ -140,12 +139,14 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro
outputs[1].Set(res[0].A1)
outputs[2].Set(res[1].A0)
outputs[3].Set(res[1].A1)
outputs[4].Set(res[2].A0)
outputs[5].Set(res[2].A1)
outputs[6].Set(r.A0)
outputs[7].Set(r.A1)
outputs[8].Set(s.A0)
outputs[9].Set(s.A1)
outputs[4].Mul(res[1].A1, inputs[1]).
Add(outputs[4], res[1].A0).
Mul(outputs[4], inputs[0]).
Add(outputs[4], res[0].A0)
s.A0.Mul(res[0].A1, inputs[1])
outputs[4].Add(outputs[4], s.A0).
Div(outputs[4], cc.fr)

if outputs[0].Sign() == -1 {
outputs[0].Neg(outputs[0])
}
Expand All @@ -161,34 +162,20 @@ func halfGCDEisenstein(mod *big.Int, inputs []*big.Int, outputs []*big.Int) erro
if outputs[4].Sign() == -1 {
outputs[4].Neg(outputs[4])
}
if outputs[5].Sign() == -1 {
outputs[5].Neg(outputs[5])
}
if outputs[6].Sign() == -1 {
outputs[6].Neg(outputs[6])
}
if outputs[7].Sign() == -1 {
outputs[7].Neg(outputs[7])
}
if outputs[8].Sign() == -1 {
outputs[8].Neg(outputs[8])
}
if outputs[9].Sign() == -1 {
outputs[9].Neg(outputs[9])
}

return nil
}

func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error {
func halfGCDEisensteinSigns(scalarField *big.Int, inputs, outputs []*big.Int) error {
if len(inputs) != 2 {
return fmt.Errorf("expecting two input")
}
if len(outputs) != 6 {
return fmt.Errorf("expecting six outputs")
if len(outputs) != 5 {
return fmt.Errorf("expecting five outputs")
}
cc := getInnerCurveConfig(scalarField)
glvBasis := new(ecc.Lattice)
ecc.PrecomputeLattice(mod, inputs[1], glvBasis)
// r = 91893752504881257701523279626832445440 - ω
ecc.PrecomputeLattice(cc.fr, inputs[1], glvBasis)
r := eisenstein.ComplexNumber{
A0: &glvBasis.V1[0],
A1: &glvBasis.V1[1],
Expand All @@ -207,8 +194,15 @@ func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error {
outputs[2].SetUint64(0)
outputs[3].SetUint64(0)
outputs[4].SetUint64(0)
outputs[5].SetUint64(0)
res := eisenstein.HalfGCD(&r, &s)
s.A1.Mul(res[1].A1, inputs[1]).
Add(s.A1, res[1].A0).
Mul(s.A1, inputs[0]).
Add(s.A1, res[0].A0)
s.A0.Mul(res[0].A1, inputs[1])
s.A1.Add(s.A1, s.A0).
Div(s.A1, cc.fr)

if res[0].A0.Sign() == -1 {
outputs[0].SetUint64(1)
}
Expand All @@ -221,12 +215,8 @@ func halfGCDEisensteinSigns(mod *big.Int, inputs, outputs []*big.Int) error {
if res[1].A1.Sign() == -1 {
outputs[3].SetUint64(1)
}
if res[2].A0.Sign() == -1 {
if s.A1.Sign() == -1 {
outputs[4].SetUint64(1)
}
if res[2].A1.Sign() == -1 {
outputs[5].SetUint64(1)
}

return nil
}

0 comments on commit 4bd0d48

Please sign in to comment.