diff --git a/.golangci.yml b/.golangci.yml index 6ce6ccf..a13ad34 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -60,6 +60,9 @@ issues: - path: internal/hmac/ text: "Write\\` is not checked" linters: [errcheck] + - path: internal/hmac/ + source: justHash + linters: [godot] # Ease linting on benchmarking code. - path: cmd/stun-bench/ diff --git a/internal/hmac/hmac.go b/internal/hmac/hmac.go index 801ece6..a6ba71c 100644 --- a/internal/hmac/hmac.go +++ b/internal/hmac/hmac.go @@ -34,18 +34,36 @@ import ( // opad = 0x5c byte repeated for key length // hmac = H([key ^ opad] H([key ^ ipad] text)) +// Marshalable is the combination of encoding.BinaryMarshaler and +// encoding.BinaryUnmarshaler. Their method definitions are repeated here to +// avoid a dependency on the encoding package. +type marshalable interface { + MarshalBinary() ([]byte, error) + UnmarshalBinary([]byte) error +} + type hmac struct { - size int - blocksize int opad, ipad []byte outer, inner hash.Hash + + // If marshaled is true, then opad and ipad do not contain a padded + // copy of the key, but rather the marshaled state of outer/inner after + // opad/ipad has been fed into it. + marshaled bool } func (h *hmac) Sum(in []byte) []byte { origLen := len(in) in = h.inner.Sum(in) - h.outer.Reset() - h.outer.Write(h.opad) + + if h.marshaled { + if err := h.outer.(marshalable).UnmarshalBinary(h.opad); err != nil { + panic(err) + } + } else { + h.outer.Reset() + h.outer.Write(h.opad) + } h.outer.Write(in[origLen:]) return h.outer.Sum(in[:origLen]) } @@ -54,13 +72,51 @@ func (h *hmac) Write(p []byte) (n int, err error) { return h.inner.Write(p) } -func (h *hmac) Size() int { return h.size } - -func (h *hmac) BlockSize() int { return h.blocksize } +func (h *hmac) Size() int { return h.outer.Size() } +func (h *hmac) BlockSize() int { return h.inner.BlockSize() } func (h *hmac) Reset() { + if h.marshaled { + if err := h.inner.(marshalable).UnmarshalBinary(h.ipad); err != nil { + panic(err) + } + return + } + h.inner.Reset() h.inner.Write(h.ipad) + + // If the underlying hash is marshalable, we can save some time by + // saving a copy of the hash state now, and restoring it on future + // calls to Reset and Sum instead of writing ipad/opad every time. + // + // If either hash is unmarshalable for whatever reason, + // it's safe to bail out here. + marshalableInner, innerOK := h.inner.(marshalable) + if !innerOK { + return + } + marshalableOuter, outerOK := h.outer.(marshalable) + if !outerOK { + return + } + + imarshal, err := marshalableInner.MarshalBinary() + if err != nil { + return + } + + h.outer.Reset() + h.outer.Write(h.opad) + omarshal, err := marshalableOuter.MarshalBinary() + if err != nil { + return + } + + // Marshaling succeeded; save the marshaled state for later + h.ipad = imarshal + h.opad = omarshal + h.marshaled = true } // New returns a new HMAC hash using the given hash.Hash type and key. @@ -71,11 +127,10 @@ func New(h func() hash.Hash, key []byte) hash.Hash { hm := new(hmac) hm.outer = h() hm.inner = h() - hm.size = hm.inner.Size() - hm.blocksize = hm.inner.BlockSize() - hm.ipad = make([]byte, hm.blocksize) - hm.opad = make([]byte, hm.blocksize) - if len(key) > hm.blocksize { + blocksize := hm.inner.BlockSize() + hm.ipad = make([]byte, blocksize) + hm.opad = make([]byte, blocksize) + if len(key) > blocksize { // If key is too big, hash it. hm.outer.Write(key) key = hm.outer.Sum(nil) @@ -89,6 +144,7 @@ func New(h func() hash.Hash, key []byte) hash.Hash { hm.opad[i] ^= 0x5c } hm.inner.Write(hm.ipad) + return hm } diff --git a/internal/hmac/hmac_test.go b/internal/hmac/hmac_test.go index eea345e..453bfb3 100644 --- a/internal/hmac/hmac_test.go +++ b/internal/hmac/hmac_test.go @@ -529,7 +529,7 @@ func TestHMAC(t *testing.T) { if b := h.BlockSize(); b != tt.blocksize { t.Errorf("BlockSize: got %v, want %v", b, tt.blocksize) } - for j := 0; j < 2; j++ { + for j := 0; j < 4; j++ { n, err := h.Write(tt.in) if n != len(tt.in) || err != nil { t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err) @@ -546,10 +546,21 @@ func TestHMAC(t *testing.T) { // Second iteration: make sure reset works. h.Reset() + + // Third and fourth iteration: make sure hmac works on + // hashes without MarshalBinary/UnmarshalBinary + if j == 1 { + h = New(func() hash.Hash { return justHash{tt.hash()} }, tt.key) + } } } } +// justHash implements just the hash.Hash methods and nothing else +type justHash struct { + hash.Hash +} + func TestEqual(t *testing.T) { a := []byte("test") b := []byte("test1") diff --git a/internal/hmac/pool.go b/internal/hmac/pool.go index 62d9240..d62fc25 100644 --- a/internal/hmac/pool.go +++ b/internal/hmac/pool.go @@ -7,21 +7,16 @@ import ( "sync" ) -// setZeroes sets all bytes from b to zeroes. -// -// See https://github.com/golang/go/issues/5373 -func setZeroes(b []byte) { - for i := range b { - b[i] = 0 - } -} - func (h *hmac) resetTo(key []byte) { h.outer.Reset() h.inner.Reset() - setZeroes(h.ipad) - setZeroes(h.opad) - if len(key) > h.blocksize { + blocksize := h.inner.BlockSize() + + // Reset size and zero of ipad and opad. + h.ipad = append(h.ipad[:0], make([]byte, blocksize)...) + h.opad = append(h.opad[:0], make([]byte, blocksize)...) + + if len(key) > blocksize { // If key is too big, hash it. h.outer.Write(key) key = h.outer.Sum(nil) @@ -35,6 +30,8 @@ func (h *hmac) resetTo(key []byte) { h.opad[i] ^= 0x5c } h.inner.Write(h.ipad) + + h.marshaled = false } var hmacSHA1Pool = &sync.Pool{ @@ -86,7 +83,7 @@ func PutSHA256(h hash.Hash) { // Put and Acquire functions are internal functions to project, so // checking it via such assert is optimal. func assertHMACSize(h *hmac, size, blocksize int) { - if h.size != size || h.blocksize != blocksize { + if h.Size() != size || h.BlockSize() != blocksize { panic("BUG: hmac size invalid") } }