From d0f7a43758cafe86843e58fe3d760aea9dd6b1b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= Date: Mon, 20 Mar 2023 22:24:19 +0200 Subject: [PATCH 1/4] Add support for MustFromBig for nicer struct initialization --- conversion.go | 11 ++++++++ conversion_test.go | 62 +++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/conversion.go b/conversion.go index eb5ca364..e5c92d2a 100644 --- a/conversion.go +++ b/conversion.go @@ -66,6 +66,17 @@ func FromBig(b *big.Int) (*Int, bool) { return z, overflow } +// MustFromBig is a convenience-constructor from big.Int. +// Returns a new Int and panics if overflow occurred. +func MustFromBig(b *big.Int) *Int { + z := &Int{} + overflow := z.SetFromBig(b) + if overflow { + panic("overflow") + } + return z +} + // SetFromHex sets z from the given string, interpreted as a hexadecimal number. // OBS! This method is _not_ strictly identical to the (*big.Int).SetString(..., 16) method. // Notable differences: diff --git a/conversion_test.go b/conversion_test.go index b999ff20..822e5a52 100644 --- a/conversion_test.go +++ b/conversion_test.go @@ -27,6 +27,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 := MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(1) b, o = FromBig(a) @@ -36,6 +40,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(0x1000000000000000) b, o = FromBig(a) @@ -45,6 +53,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(0x1234) b, o = FromBig(a) @@ -54,6 +66,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(1) a.Lsh(a, 256) @@ -65,6 +81,18 @@ func TestFromBig(t *testing.T) { if !b.Eq(new(Int)) { t.Fatalf("got %x exp 0", b.Bytes()) } + done := make(chan struct{}) + go func() { + defer func() { + o = recover() != nil + done <- struct{}{} + }() + MustFromBig(a) + }() + <-done + if !o { + t.Fatalf("expected overflow") + } a.Sub(a, big.NewInt(1)) b, o = FromBig(a) @@ -74,6 +102,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } } func TestScanScientific(t *testing.T) { @@ -160,19 +192,41 @@ func TestScanScientific(t *testing.T) { } func TestFromBigOverflow(t *testing.T) { - _, o := FromBig(new(big.Int).SetBytes(hex2Bytes("ababee444444444444ffcc333333333333ddaa222222222222bb8811111111111199"))) + // Test overflow with error returns + b := new(big.Int).SetBytes(hex2Bytes("ababee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + _, o := FromBig(b) if !o { t.Errorf("expected overflow, got %v", o) } - _, o = FromBig(new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199"))) + // Test overflow with panic (recovery is a bit unwieldy) + done := make(chan struct{}) + go func() { + defer func() { + o = recover() != nil + done <- struct{}{} + }() + MustFromBig(b) + }() + <-done + if !o { + t.Fatalf("expected overflow") + } + // Test no overflow + b = new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + _, o = FromBig(b) if o { t.Errorf("expected no overflow, got %v", o) } - b := new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) - _, o = FromBig(b.Neg(b)) + MustFromBig(b) + + b = new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + b.Neg(b) + + _, o = FromBig(b) if o { t.Errorf("expected no overflow, got %v", o) } + MustFromBig(b) } func TestToBig(t *testing.T) { From 5e044ce5677301aa47139e691868c2cead2d64f1 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Tue, 21 Mar 2023 03:22:38 -0400 Subject: [PATCH 2/4] minor nit --- conversion.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/conversion.go b/conversion.go index e5c92d2a..822497c8 100644 --- a/conversion.go +++ b/conversion.go @@ -70,8 +70,7 @@ func FromBig(b *big.Int) (*Int, bool) { // Returns a new Int and panics if overflow occurred. func MustFromBig(b *big.Int) *Int { z := &Int{} - overflow := z.SetFromBig(b) - if overflow { + if overflow := z.SetFromBig(b) { panic("overflow") } return z From 0ef04d88fe20c45666710fd076a8f6c63a0ecbba Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Tue, 21 Mar 2023 03:24:05 -0400 Subject: [PATCH 3/4] Update conversion.go --- conversion.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conversion.go b/conversion.go index 822497c8..76c76ea8 100644 --- a/conversion.go +++ b/conversion.go @@ -70,7 +70,7 @@ func FromBig(b *big.Int) (*Int, bool) { // Returns a new Int and panics if overflow occurred. func MustFromBig(b *big.Int) *Int { z := &Int{} - if overflow := z.SetFromBig(b) { + if overflow := z.SetFromBig(b); overflow { panic("overflow") } return z From 416c76a4a46ea0f0bae61b98d014c21a15873c3a Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Tue, 21 Mar 2023 03:29:43 -0400 Subject: [PATCH 4/4] Update conversion.go --- conversion.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conversion.go b/conversion.go index 76c76ea8..b209433f 100644 --- a/conversion.go +++ b/conversion.go @@ -70,7 +70,7 @@ func FromBig(b *big.Int) (*Int, bool) { // Returns a new Int and panics if overflow occurred. func MustFromBig(b *big.Int) *Int { z := &Int{} - if overflow := z.SetFromBig(b); overflow { + if z.SetFromBig(b) { panic("overflow") } return z