From c15497560bf2f4d49d1ddd112bfef2e7f4ff6d25 Mon Sep 17 00:00:00 2001 From: Sebastien Binet Date: Mon, 20 Jan 2020 10:58:47 +0100 Subject: [PATCH] zmq4: add r/w test for greeting Updates go-zeromq/zmq4#56. --- protocol.go | 30 ++++++------- protocol_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 16 deletions(-) create mode 100644 protocol_test.go diff --git a/protocol.go b/protocol.go index ddaef7b..dc16eb9 100644 --- a/protocol.go +++ b/protocol.go @@ -35,6 +35,8 @@ const ( hasMoreBitFlag = 0x1 isLongBitFlag = 0x2 isCommandBitFlag = 0x4 + + zmtpMsgLen = 64 ) var ( @@ -90,46 +92,42 @@ type greeting struct { } func (g *greeting) read(r io.Reader) error { - var data [64]byte + var data [zmtpMsgLen]byte _, err := io.ReadFull(r, data[:]) if err != nil { - return err + return xerrors.Errorf("could not read ZMTP greeting: %w", err) } - err = g.unmarshal(data[:]) - if err != nil { - return err - } + g.unmarshal(data[:]) if g.Sig.Header != sigHeader { - return errGreeting + return xerrors.Errorf("invalid ZMTP signature header: %w", errGreeting) } if g.Sig.Footer != sigFooter { - return errGreeting + return xerrors.Errorf("invalid ZMTP signature footer: %w", errGreeting) } // FIXME(sbinet): handle version negotiations as per // https://rfc.zeromq.org/spec:23/ZMTP/#version-negotiation if g.Version != defaultVersion { - return errGreeting + return xerrors.Errorf( + "invalid ZMTP version (got=%v, want=%v): %w", + g.Version, defaultVersion, errGreeting, + ) } return nil } -func (g *greeting) unmarshal(data []byte) error { - if len(data) < 64 { - return io.ErrShortBuffer - } - _ = data[:64] +func (g *greeting) unmarshal(data []byte) { + _ = data[:zmtpMsgLen] g.Sig.Header = data[0] g.Sig.Footer = data[9] g.Version[0] = data[10] g.Version[1] = data[11] copy(g.Mechanism[:], data[12:32]) g.Server = data[32] - return nil } func (g *greeting) write(w io.Writer) error { @@ -138,7 +136,7 @@ func (g *greeting) write(w io.Writer) error { } func (g *greeting) marshal() []byte { - var buf [64]byte + var buf [zmtpMsgLen]byte buf[0] = g.Sig.Header // padding 1 ignored buf[9] = g.Sig.Footer diff --git a/protocol_test.go b/protocol_test.go new file mode 100644 index 0000000..c25afe9 --- /dev/null +++ b/protocol_test.go @@ -0,0 +1,115 @@ +// Copyright 2020 The go-zeromq Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package zmq4 + +import ( + "bytes" + "io" + "testing" + + "golang.org/x/xerrors" +) + +func TestGreeting(t *testing.T) { + for _, tc := range []struct { + name string + data []byte + want error + }{ + { + name: "valid", + data: func() []byte { + w := new(bytes.Buffer) + g := greeting{ + Version: defaultVersion, + } + g.Sig.Header = sigHeader + g.Sig.Footer = sigFooter + g.write(w) + return w.Bytes() + }(), + }, + { + name: "empty-buffer", + data: nil, + want: xerrors.Errorf("could not read ZMTP greeting: %w", io.EOF), + }, + { + name: "unexpected-EOF", + data: make([]byte, 1), + want: xerrors.Errorf("could not read ZMTP greeting: %w", io.ErrUnexpectedEOF), + }, + { + name: "invalid-header", + data: func() []byte { + w := new(bytes.Buffer) + g := greeting{ + Version: defaultVersion, + } + g.Sig.Header = sigFooter // err + g.Sig.Footer = sigFooter + g.write(w) + return w.Bytes() + }(), + want: xerrors.Errorf("invalid ZMTP signature header: %w", errGreeting), + }, + { + name: "invalid-footer", + data: func() []byte { + w := new(bytes.Buffer) + g := greeting{ + Version: defaultVersion, + } + g.Sig.Header = sigHeader + g.Sig.Footer = sigHeader // err + g.write(w) + return w.Bytes() + }(), + want: xerrors.Errorf("invalid ZMTP signature footer: %w", errGreeting), + }, + { + name: "invalid-version", // FIXME(sbinet): adapt for when/if we support multiple ZMTP versions + data: func() []byte { + w := new(bytes.Buffer) + g := greeting{ + Version: [2]uint8{1, 1}, + } + g.Sig.Header = sigHeader + g.Sig.Footer = sigFooter + g.write(w) + return w.Bytes() + }(), + want: xerrors.Errorf("invalid ZMTP version (got=%v, want=%v): %w", + [2]uint{1, 1}, + defaultVersion, + errGreeting, + ), + }, + } { + t.Run(tc.name, func(t *testing.T) { + var ( + g greeting + r = bytes.NewReader(tc.data) + ) + + err := g.read(r) + switch { + case err == nil && tc.want == nil: + // ok + case err == nil && tc.want != nil: + t.Fatalf("expected an error (%s)", tc.want) + case err != nil && tc.want == nil: + t.Fatalf("could not read ZMTP greeting: %+v", err) + case err != nil && tc.want != nil: + if got, want := err.Error(), tc.want.Error(); got != want { + t.Fatalf("invalid ZMTP greeting error:\ngot= %+v\nwant=%+v\n", + got, want, + ) + } + } + + }) + } +}