diff --git a/benchmarks_test.go b/benchmarks_test.go index 37c5c00b..b595a87f 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -165,6 +165,34 @@ func BenchmarkMul(bench *testing.B) { bench.Run("single/big", benchmarkBig) } +func BenchmarkMulOverflow(bench *testing.B) { + benchmarkUint256 := func(bench *testing.B) { + a := big.NewInt(0).SetBytes(hex2Bytes("f123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9")) + b := big.NewInt(0).SetBytes(hex2Bytes("f123456789abcdefaaaaaa9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9")) + fa, _ := FromBig(a) + fb, _ := FromBig(b) + + result := new(Int) + bench.ResetTimer() + for i := 0; i < bench.N; i++ { + result.MulOverflow(fa, fb) + } + } + benchmarkBig := func(bench *testing.B) { + a := new(big.Int).SetBytes(hex2Bytes("f123456789abcdeffedcba9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9")) + b := new(big.Int).SetBytes(hex2Bytes("f123456789abcdefaaaaaa9876543210f2f3f4f5f6f7f8f9fff3f4f5f6f7f8f9")) + + result := new(big.Int) + bench.ResetTimer() + for i := 0; i < bench.N; i++ { + U256(result.Mul(a, b)) + } + } + + bench.Run("single/uint256", benchmarkUint256) + bench.Run("single/big", benchmarkBig) +} + func BenchmarkSquare(bench *testing.B) { benchmarkUint256 := func(bench *testing.B) { diff --git a/uint256.go b/uint256.go index 278b1047..5faa400f 100644 --- a/uint256.go +++ b/uint256.go @@ -330,15 +330,8 @@ func (z *Int) Mul(x, y *Int) *Int { // MulOverflow sets z to the product x*y, and returns whether overflow occurred func (z *Int) MulOverflow(x, y *Int) bool { p := umul(x, y) - var ( - pl Int - ph Int - ) - copy(pl[:], p[:4]) - copy(ph[:], p[4:]) - - z.Set(&pl) - return !ph.IsZero() + copy(z[:], p[:4]) + return (p[4] | p[5] | p[6] | p[7]) != 0 } func (z *Int) squared() {