diff --git a/chaincfg/chainhash/hashfuncs.go b/chaincfg/chainhash/hashfuncs.go index bf74f73c39..5be8a4d467 100644 --- a/chaincfg/chainhash/hashfuncs.go +++ b/chaincfg/chainhash/hashfuncs.go @@ -5,7 +5,10 @@ package chainhash -import "crypto/sha256" +import ( + "crypto/sha256" + "io" +) // HashB calculates hash(b) and returns the resulting bytes. func HashB(b []byte) []byte { @@ -31,3 +34,24 @@ func DoubleHashH(b []byte) Hash { first := sha256.Sum256(b) return Hash(sha256.Sum256(first[:])) } + +// DoubleHashRaw calculates hash(hash(w)) where w is the resulting bytes from +// the given serialize function and returns the resulting bytes as a Hash. +func DoubleHashRaw(serialize func(w io.Writer) error) Hash { + // Encode the transaction into the hash. Ignore the error returns + // since the only way the encode could fail is being out of memory + // or due to nil pointers, both of which would cause a run-time panic. + h := sha256.New() + _ = serialize(h) + + // This buf is here because Sum() will append the result to the passed + // in byte slice. Pre-allocating here saves an allocation on the second + // hash as we can reuse it. This allocation also does not escape to the + // heap, saving an allocation. + buf := make([]byte, 0, HashSize) + first := h.Sum(buf) + h.Reset() + h.Write(first) + res := h.Sum(buf) + return *(*Hash)(res) +} diff --git a/chaincfg/chainhash/hashfuncs_test.go b/chaincfg/chainhash/hashfuncs_test.go index bcd6f22200..6b9ff9a97f 100644 --- a/chaincfg/chainhash/hashfuncs_test.go +++ b/chaincfg/chainhash/hashfuncs_test.go @@ -6,6 +6,7 @@ package chainhash import ( "fmt" + "io" "testing" ) @@ -133,4 +134,20 @@ func TestDoubleHashFuncs(t *testing.T) { continue } } + + // Ensure the hash function which accepts a hash.Hash returns the expected + // result when given a hash.Hash that is of type SHA256. + for _, test := range tests { + serialize := func(w io.Writer) error { + w.Write([]byte(test.in)) + return nil + } + hash := DoubleHashRaw(serialize) + h := fmt.Sprintf("%x", hash[:]) + if h != test.out { + t.Errorf("DoubleHashRaw(%q) = %s, want %s", test.in, h, + test.out) + continue + } + } }