Skip to content

Commit

Permalink
Check the length of hash in Marshal and Unmarshal and return always t…
Browse files Browse the repository at this point in the history
…he actual length.
  • Loading branch information
NicolasMahe committed Nov 15, 2019
1 parent f5902ec commit f4839ed
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
28 changes: 15 additions & 13 deletions hash/hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ var DefaultHash = sha256.New
// size is a default size for hashing algorithm.
var size = DefaultHash().Size()

func checkLen(data []byte) error {
if len(data) != size {
return fmt.Errorf("hash: invalid length")
}
return nil
}

// Digest represents the partial evaluation of a checksum.
type Digest struct {
hash.Hash
Expand Down Expand Up @@ -67,11 +74,9 @@ func Random() (Hash, error) {
if err != nil {
return nil, fmt.Errorf("hash generate random error: %s", err)
}

if n != size {
return nil, fmt.Errorf("hash generate random error: invalid hash length")
}

return hash, nil
}

Expand All @@ -82,8 +87,8 @@ func Decode(h string) (Hash, error) {
if err != nil {
return nil, fmt.Errorf("hash: %s", err)
}
if len(hash) != size {
return nil, fmt.Errorf("hash: invalid length string")
if err := checkLen(hash); err != nil {
return nil, err
}
return Hash(hash), nil
}
Expand All @@ -105,27 +110,27 @@ func (h Hash) Equal(h1 Hash) bool {

// Marshal marshals hash into slice of bytes. It's used by protobuf.
func (h Hash) Marshal() ([]byte, error) {
return h, nil
return h, checkLen(h)
}

// MarshalTo marshals hash into slice of bytes. It's used by protobuf.
func (h Hash) MarshalTo(data []byte) (int, error) {
return copy(data, h), nil
return copy(data, h), checkLen(h)
}

// Unmarshal unmarshals slice of bytes into hash. It's used by protobuf.
func (h *Hash) Unmarshal(data []byte) error {
if err := checkLen(data); err != nil {
return err
}
*h = make([]byte, len(data))
copy(*h, data)
return nil
}

// Size retruns size of hash. It's used by protobuf.
func (h Hash) Size() int {
if len(h) == 0 {
return 0
}
return size
return len(h)
}

// MarshalJSON mashals hash into encoded json string.
Expand All @@ -139,16 +144,13 @@ func (h *Hash) UnmarshalJSON(data []byte) error {
if err := json.Unmarshal(data, &str); err != nil {
return err
}

if str == "" {
return nil
}

h1, err := base58.Decode(str)
if err != nil {
return err
}

*h = Hash(h1)
return nil
}
19 changes: 18 additions & 1 deletion hash/hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestDecode(t *testing.T) {
assert.Equal(t, "hash: invalid base58 digit ('0')", err.Error())

_, err = Decode("1")
assert.Equal(t, "hash: invalid length string", err.Error())
assert.Equal(t, "hash: invalid length", err.Error())
}

func TestIsZero(t *testing.T) {
Expand All @@ -79,6 +79,7 @@ func TestEqual(t *testing.T) {
func TestSize(t *testing.T) {
assert.Equal(t, 0, Hash(nil).Size())
assert.Equal(t, size, zero.Size())
assert.Equal(t, 5, Hash([]byte("hello")).Size())
}

func TestMarshalJSON(t *testing.T) {
Expand All @@ -101,3 +102,19 @@ func TestUnmarshal(t *testing.T) {
// test if two slises do not share the same address
assert.True(t, &hash[cap(hash)-1] != &data[cap(data)-1])
}

func TestWrongLength(t *testing.T) {
var h Hash
wrongLenByte := []byte("hello")
t.Run("Marshal", func(t *testing.T) {
_, err := Hash(wrongLenByte).Marshal()
require.EqualError(t, err, "hash: invalid length")
})
t.Run("MarshalTo", func(t *testing.T) {
_, err := h.MarshalTo(wrongLenByte)
require.EqualError(t, err, "hash: invalid length")
})
t.Run("Unmarshal", func(t *testing.T) {
require.EqualError(t, h.Unmarshal(wrongLenByte), "hash: invalid length")
})
}

0 comments on commit f4839ed

Please sign in to comment.