Skip to content

Commit

Permalink
refactor: use emulated.FieldParams as type parameter to generic Curve…
Browse files Browse the repository at this point in the history
… and Pairing (#901)

* feat: newScalar method in 2-chains

* refactor: better error msgs

* refactor: change curve and pairing definition

* refactor: 2-chains implement

* refactor: sw_bls12381 implement

* refactor: sw_bn254 implement

* refactor: sw_bw6761 implement

* test: remove unused pairing test

* refactor: groth16 implement

* refactor: kzg implement

* refactor: recursion hash

* fix: use scalar limb

* fix: marshal tests with short hash

* feat: define 2-chain emulation params

* feat: use emulated decomposition for scalar marshal

* feat: repack emulated scalars to frontend.Variable for compatibility
  • Loading branch information
ivokub authored Nov 10, 2023
1 parent 3f98e9b commit 29cadaa
Show file tree
Hide file tree
Showing 29 changed files with 1,541 additions and 1,432 deletions.
25 changes: 16 additions & 9 deletions std/algebra/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,46 @@ import (
"github.com/consensys/gnark/std/algebra/emulated/sw_emulated"
"github.com/consensys/gnark/std/algebra/native/sw_bls12377"
"github.com/consensys/gnark/std/algebra/native/sw_bls24315"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/math/emulated/emparams"
)

// GetCurve returns the [Curve] implementation corresponding to the scalar and
// G1 type parameters. The method allows to have a fully generic implementation
// without taking into consideration the initialization differences of different
// curves.
func GetCurve[S ScalarT, G1El G1ElementT](api frontend.API) (Curve[S, G1El], error) {
var ret Curve[S, G1El]
func GetCurve[FR emulated.FieldParams, G1El G1ElementT](api frontend.API) (Curve[FR, G1El], error) {
var ret Curve[FR, G1El]
switch s := any(&ret).(type) {
case *Curve[sw_bn254.Scalar, sw_bn254.G1Affine]:
case *Curve[sw_bn254.ScalarField, sw_bn254.G1Affine]:
c, err := sw_emulated.New[emparams.BN254Fp, emparams.BN254Fr](api, sw_emulated.GetBN254Params())
if err != nil {
return ret, fmt.Errorf("new curve: %w", err)
}
*s = c
case *Curve[sw_bw6761.Scalar, sw_bw6761.G1Affine]:
case *Curve[sw_bw6761.ScalarField, sw_bw6761.G1Affine]:
c, err := sw_emulated.New[emparams.BW6761Fp, emparams.BW6761Fr](api, sw_emulated.GetBW6761Params())
if err != nil {
return ret, fmt.Errorf("new curve: %w", err)
}
*s = c
case *Curve[sw_bls12381.Scalar, sw_bls12381.G1Affine]:
case *Curve[sw_bls12381.ScalarField, sw_bls12381.G1Affine]:
c, err := sw_emulated.New[emparams.BLS12381Fp, emparams.BLS12381Fr](api, sw_emulated.GetBLS12381Params())
if err != nil {
return ret, fmt.Errorf("new curve: %w", err)
}
*s = c
case *Curve[sw_bls12377.Scalar, sw_bls12377.G1Affine]:
c := sw_bls12377.NewCurve(api)
case *Curve[sw_bls12377.ScalarField, sw_bls12377.G1Affine]:
c, err := sw_bls12377.NewCurve(api)
if err != nil {
return ret, fmt.Errorf("new curve: %w", err)
}
*s = c
case *Curve[sw_bls24315.Scalar, sw_bls24315.G1Affine]:
c := sw_bls24315.NewCurve(api)
case *Curve[sw_bls24315.ScalarField, sw_bls24315.G1Affine]:
c, err := sw_bls24315.NewCurve(api)
if err != nil {
return ret, fmt.Errorf("new curve: %w", err)
}
*s = c
default:
return ret, fmt.Errorf("unknown type parametrisation")
Expand Down
25 changes: 15 additions & 10 deletions std/algebra/emulated/sw_bls12381/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,31 @@ import (

// G1Affine is the point in G1. It is an alias to the generic emulated affine
// point.
type G1Affine = sw_emulated.AffinePoint[emulated.BLS12381Fp]
type G1Affine = sw_emulated.AffinePoint[BaseField]

// Scalar is the scalar in the groups. It is an alias to the emulated element
// defined over the scalar field of the groups.
type Scalar = emulated.Element[emulated.BLS12381Fr]
type Scalar = emulated.Element[ScalarField]

// NewG1Affine allocates a witness from the native G1 element and returns it.

func NewG1Affine(v bls12381.G1Affine) G1Affine {
return G1Affine{
X: emulated.ValueOf[emulated.BLS12381Fp](v.X),
Y: emulated.ValueOf[emulated.BLS12381Fp](v.Y),
X: emulated.ValueOf[BaseField](v.X),
Y: emulated.ValueOf[BaseField](v.Y),
}
}

type G1 struct {
curveF *emulated.Field[emulated.BLS12381Fp]
w *emulated.Element[emulated.BLS12381Fp]
curveF *emulated.Field[BaseField]
w *emulated.Element[BaseField]
}

func NewG1(api frontend.API) (*G1, error) {
ba, err := emulated.NewField[emulated.BLS12381Fp](api)
ba, err := emulated.NewField[BaseField](api)
if err != nil {
return nil, fmt.Errorf("new base api: %w", err)
}
w := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436")
w := emulated.ValueOf[BaseField]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436")
return &G1{
curveF: ba,
w: &w,
Expand All @@ -55,5 +54,11 @@ func (g1 *G1) phi(q *G1Affine) *G1Affine {

// NewScalar allocates a witness from the native scalar and returns it.
func NewScalar(v fr_bls12381.Element) Scalar {
return emulated.ValueOf[emulated.BLS12381Fr](v)
return emulated.ValueOf[ScalarField](v)
}

// ScalarField is the [emulated.FieldParams] impelementation of the curve scalar field.
type ScalarField = emulated.BLS12381Fr

// BaseField is the [emulated.FieldParams] impelementation of the curve base field.
type BaseField = emulated.BLS12381Fp
18 changes: 9 additions & 9 deletions std/algebra/emulated/sw_bls12381/g2.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

type G2 struct {
*fields_bls12381.Ext2
u1, w *emulated.Element[emulated.BLS12381Fp]
u1, w *emulated.Element[BaseField]
v *fields_bls12381.E2
}

Expand All @@ -20,11 +20,11 @@ type G2Affine struct {
}

func NewG2(api frontend.API) *G2 {
w := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436")
u1 := emulated.ValueOf[emulated.BLS12381Fp]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437")
w := emulated.ValueOf[BaseField]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436")
u1 := emulated.ValueOf[BaseField]("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437")
v := fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp]("2973677408986561043442465346520108879172042883009249989176415018091420807192182638567116318576472649347015917690530"),
A1: emulated.ValueOf[emulated.BLS12381Fp]("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"),
A0: emulated.ValueOf[BaseField]("2973677408986561043442465346520108879172042883009249989176415018091420807192182638567116318576472649347015917690530"),
A1: emulated.ValueOf[BaseField]("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257"),
}
return &G2{
Ext2: fields_bls12381.NewExt2(api),
Expand All @@ -37,12 +37,12 @@ func NewG2(api frontend.API) *G2 {
func NewG2Affine(v bls12381.G2Affine) G2Affine {
return G2Affine{
X: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.X.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.X.A1),
A0: emulated.ValueOf[BaseField](v.X.A0),
A1: emulated.ValueOf[BaseField](v.X.A1),
},
Y: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.Y.A1),
A0: emulated.ValueOf[BaseField](v.Y.A0),
A1: emulated.ValueOf[BaseField](v.Y.A1),
},
}
}
Expand Down
46 changes: 23 additions & 23 deletions std/algebra/emulated/sw_bls12381/pairing.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (
type Pairing struct {
api frontend.API
*fields_bls12381.Ext12
curveF *emulated.Field[emulated.BLS12381Fp]
curveF *emulated.Field[BaseField]
g2 *G2
g1 *G1
curve *sw_emulated.Curve[emulated.BLS12381Fp, emulated.BLS12381Fr]
curve *sw_emulated.Curve[BaseField, ScalarField]
bTwist *fields_bls12381.E2
lines [4][63]fields_bls12381.E2
}
Expand All @@ -29,47 +29,47 @@ func NewGTEl(v bls12381.GT) GTEl {
return GTEl{
C0: fields_bls12381.E6{
B0: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B0.A1),
A0: emulated.ValueOf[BaseField](v.C0.B0.A0),
A1: emulated.ValueOf[BaseField](v.C0.B0.A1),
},
B1: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B1.A1),
A0: emulated.ValueOf[BaseField](v.C0.B1.A0),
A1: emulated.ValueOf[BaseField](v.C0.B1.A1),
},
B2: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C0.B2.A1),
A0: emulated.ValueOf[BaseField](v.C0.B2.A0),
A1: emulated.ValueOf[BaseField](v.C0.B2.A1),
},
},
C1: fields_bls12381.E6{
B0: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B0.A1),
A0: emulated.ValueOf[BaseField](v.C1.B0.A0),
A1: emulated.ValueOf[BaseField](v.C1.B0.A1),
},
B1: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B1.A1),
A0: emulated.ValueOf[BaseField](v.C1.B1.A0),
A1: emulated.ValueOf[BaseField](v.C1.B1.A1),
},
B2: fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A0),
A1: emulated.ValueOf[emulated.BLS12381Fp](v.C1.B2.A1),
A0: emulated.ValueOf[BaseField](v.C1.B2.A0),
A1: emulated.ValueOf[BaseField](v.C1.B2.A1),
},
},
}
}

func NewPairing(api frontend.API) (*Pairing, error) {
ba, err := emulated.NewField[emulated.BLS12381Fp](api)
ba, err := emulated.NewField[BaseField](api)
if err != nil {
return nil, fmt.Errorf("new base api: %w", err)
}
curve, err := sw_emulated.New[emulated.BLS12381Fp, emulated.BLS12381Fr](api, sw_emulated.GetBLS12381Params())
curve, err := sw_emulated.New[BaseField, ScalarField](api, sw_emulated.GetBLS12381Params())
if err != nil {
return nil, fmt.Errorf("new curve: %w", err)
}
bTwist := fields_bls12381.E2{
A0: emulated.ValueOf[emulated.BLS12381Fp]("4"),
A1: emulated.ValueOf[emulated.BLS12381Fp]("4"),
A0: emulated.ValueOf[BaseField]("4"),
A1: emulated.ValueOf[BaseField]("4"),
}
g1, err := NewG1(api)
if err != nil {
Expand Down Expand Up @@ -281,7 +281,7 @@ func (pr Pairing) AssertIsOnG1(P *G1Affine) {
// TODO: add phi and scalarMulBySeedSquare to g1.go
// [x²]ϕ(P)
phiP := pr.g1.phi(P)
seedSquare := emulated.ValueOf[emulated.BLS12381Fr]("228988810152649578064853576960394133504")
seedSquare := emulated.ValueOf[ScalarField]("228988810152649578064853576960394133504")
// TODO: use addchain to construct a fixed-scalar ScalarMul
_P := pr.curve.ScalarMul(phiP, &seedSquare)
_P = pr.curve.Neg(_P)
Expand Down Expand Up @@ -329,8 +329,8 @@ func (pr Pairing) MillerLoop(P []*G1Affine, Q []*G2Affine) (*GTEl, error) {

var l1, l2 *lineEvaluation
Qacc := make([]*G2Affine, n)
yInv := make([]*emulated.Element[emulated.BLS12381Fp], n)
xNegOverY := make([]*emulated.Element[emulated.BLS12381Fp], n)
yInv := make([]*emulated.Element[BaseField], n)
xNegOverY := make([]*emulated.Element[BaseField], n)

for k := 0; k < n; k++ {
Qacc[k] = Q[k]
Expand Down Expand Up @@ -669,7 +669,7 @@ func (pr Pairing) MillerLoopFixedQ(P *G1Affine) (*GTEl, error) {

res := pr.Ext12.One()

var yInv, xOverY *emulated.Element[emulated.BLS12381Fp]
var yInv, xOverY *emulated.Element[BaseField]

// P and Q are supposed to be on G1 and G2 respectively of prime order r.
// The point (x,0) is of order 2. But this function does not check
Expand Down Expand Up @@ -729,7 +729,7 @@ func (pr Pairing) DoubleMillerLoopFixedQ(P, T *G1Affine, Q *G2Affine) (*GTEl, er
var l1, l2 *lineEvaluation
var Qacc *G2Affine
Qacc = Q
var yInv, xNegOverY, y2Inv, x2OverY2 *emulated.Element[emulated.BLS12381Fp]
var yInv, xNegOverY, y2Inv, x2OverY2 *emulated.Element[BaseField]
yInv = pr.curveF.Inverse(&P.Y)
xNegOverY = pr.curveF.MulMod(&P.X, yInv)
xNegOverY = pr.curveF.Neg(xNegOverY)
Expand Down
Loading

0 comments on commit 29cadaa

Please sign in to comment.