diff --git a/hash/hash.go b/hash/hash.go index 4c005e976..73349d89a 100644 --- a/hash/hash.go +++ b/hash/hash.go @@ -19,6 +19,13 @@ var DefaultHash = sha256.New // size is a default size for hashing algorithm. var size = DefaultHash().Size() +func checkLen(data []byte) error { + if len(data) != size { + return fmt.Errorf("hash: invalid length") + } + return nil +} + // Digest represents the partial evaluation of a checksum. type Digest struct { hash.Hash @@ -67,11 +74,9 @@ func Random() (Hash, error) { if err != nil { return nil, fmt.Errorf("hash generate random error: %s", err) } - if n != size { return nil, fmt.Errorf("hash generate random error: invalid hash length") } - return hash, nil } @@ -82,8 +87,8 @@ func Decode(h string) (Hash, error) { if err != nil { return nil, fmt.Errorf("hash: %s", err) } - if len(hash) != size { - return nil, fmt.Errorf("hash: invalid length string") + if err := checkLen(hash); err != nil { + return nil, err } return Hash(hash), nil } @@ -105,16 +110,19 @@ func (h Hash) Equal(h1 Hash) bool { // Marshal marshals hash into slice of bytes. It's used by protobuf. func (h Hash) Marshal() ([]byte, error) { - return h, nil + return h, checkLen(h) } // MarshalTo marshals hash into slice of bytes. It's used by protobuf. func (h Hash) MarshalTo(data []byte) (int, error) { - return copy(data, h), nil + return copy(data, h), checkLen(h) } // Unmarshal unmarshals slice of bytes into hash. It's used by protobuf. func (h *Hash) Unmarshal(data []byte) error { + if err := checkLen(data); err != nil { + return err + } *h = make([]byte, len(data)) copy(*h, data) return nil @@ -122,10 +130,7 @@ func (h *Hash) Unmarshal(data []byte) error { // Size retruns size of hash. It's used by protobuf. func (h Hash) Size() int { - if len(h) == 0 { - return 0 - } - return size + return len(h) } // MarshalJSON mashals hash into encoded json string. @@ -139,16 +144,13 @@ func (h *Hash) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &str); err != nil { return err } - if str == "" { return nil } - h1, err := base58.Decode(str) if err != nil { return err } - *h = Hash(h1) return nil } diff --git a/hash/hash_test.go b/hash/hash_test.go index db0c786ab..6607d9237 100644 --- a/hash/hash_test.go +++ b/hash/hash_test.go @@ -60,7 +60,7 @@ func TestDecode(t *testing.T) { assert.Equal(t, "hash: invalid base58 digit ('0')", err.Error()) _, err = Decode("1") - assert.Equal(t, "hash: invalid length string", err.Error()) + assert.Equal(t, "hash: invalid length", err.Error()) } func TestIsZero(t *testing.T) { @@ -79,6 +79,7 @@ func TestEqual(t *testing.T) { func TestSize(t *testing.T) { assert.Equal(t, 0, Hash(nil).Size()) assert.Equal(t, size, zero.Size()) + assert.Equal(t, 5, Hash([]byte("hello")).Size()) } func TestMarshalJSON(t *testing.T) { @@ -101,3 +102,19 @@ func TestUnmarshal(t *testing.T) { // test if two slises do not share the same address assert.True(t, &hash[cap(hash)-1] != &data[cap(data)-1]) } + +func TestWrongLength(t *testing.T) { + var h Hash + wrongLenByte := []byte("hello") + t.Run("Marshal", func(t *testing.T) { + _, err := Hash(wrongLenByte).Marshal() + require.EqualError(t, err, "hash: invalid length") + }) + t.Run("MarshalTo", func(t *testing.T) { + _, err := h.MarshalTo(wrongLenByte) + require.EqualError(t, err, "hash: invalid length") + }) + t.Run("Unmarshal", func(t *testing.T) { + require.EqualError(t, h.Unmarshal(wrongLenByte), "hash: invalid length") + }) +}