diff --git a/conversion.go b/conversion.go index eb5ca364..b209433f 100644 --- a/conversion.go +++ b/conversion.go @@ -66,6 +66,16 @@ func FromBig(b *big.Int) (*Int, bool) { return z, overflow } +// MustFromBig is a convenience-constructor from big.Int. +// Returns a new Int and panics if overflow occurred. +func MustFromBig(b *big.Int) *Int { + z := &Int{} + if z.SetFromBig(b) { + panic("overflow") + } + return z +} + // SetFromHex sets z from the given string, interpreted as a hexadecimal number. // OBS! This method is _not_ strictly identical to the (*big.Int).SetString(..., 16) method. // Notable differences: diff --git a/conversion_test.go b/conversion_test.go index b999ff20..822e5a52 100644 --- a/conversion_test.go +++ b/conversion_test.go @@ -27,6 +27,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 := MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(1) b, o = FromBig(a) @@ -36,6 +40,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(0x1000000000000000) b, o = FromBig(a) @@ -45,6 +53,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(0x1234) b, o = FromBig(a) @@ -54,6 +66,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } a = big.NewInt(1) a.Lsh(a, 256) @@ -65,6 +81,18 @@ func TestFromBig(t *testing.T) { if !b.Eq(new(Int)) { t.Fatalf("got %x exp 0", b.Bytes()) } + done := make(chan struct{}) + go func() { + defer func() { + o = recover() != nil + done <- struct{}{} + }() + MustFromBig(a) + }() + <-done + if !o { + t.Fatalf("expected overflow") + } a.Sub(a, big.NewInt(1)) b, o = FromBig(a) @@ -74,6 +102,10 @@ func TestFromBig(t *testing.T) { if exp, got := a.Bytes(), b.Bytes(); !bytes.Equal(got, exp) { t.Fatalf("got %x exp %x", got, exp) } + b2 = MustFromBig(a) + if exp, got := a.Bytes(), b2.Bytes(); !bytes.Equal(got, exp) { + t.Fatalf("got %x exp %x", got, exp) + } } func TestScanScientific(t *testing.T) { @@ -160,19 +192,41 @@ func TestScanScientific(t *testing.T) { } func TestFromBigOverflow(t *testing.T) { - _, o := FromBig(new(big.Int).SetBytes(hex2Bytes("ababee444444444444ffcc333333333333ddaa222222222222bb8811111111111199"))) + // Test overflow with error returns + b := new(big.Int).SetBytes(hex2Bytes("ababee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + _, o := FromBig(b) if !o { t.Errorf("expected overflow, got %v", o) } - _, o = FromBig(new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199"))) + // Test overflow with panic (recovery is a bit unwieldy) + done := make(chan struct{}) + go func() { + defer func() { + o = recover() != nil + done <- struct{}{} + }() + MustFromBig(b) + }() + <-done + if !o { + t.Fatalf("expected overflow") + } + // Test no overflow + b = new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + _, o = FromBig(b) if o { t.Errorf("expected no overflow, got %v", o) } - b := new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) - _, o = FromBig(b.Neg(b)) + MustFromBig(b) + + b = new(big.Int).SetBytes(hex2Bytes("ee444444444444ffcc333333333333ddaa222222222222bb8811111111111199")) + b.Neg(b) + + _, o = FromBig(b) if o { t.Errorf("expected no overflow, got %v", o) } + MustFromBig(b) } func TestToBig(t *testing.T) {