Skip to content

Commit

Permalink
feat: customized payload validator (cloudwego#1478)
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh authored Aug 28, 2024
1 parent b870445 commit ebd3dfe
Show file tree
Hide file tree
Showing 10 changed files with 722 additions and 23 deletions.
1 change: 1 addition & 0 deletions pkg/consts/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package consts
// Method key used in context.
const (
CtxKeyMethod = "K_METHOD"
CtxKeyLogID = "K_LOGID"
)

const (
Expand Down
2 changes: 2 additions & 0 deletions pkg/kerrors/kerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ var (
ErrRPCFinish = &basicError{"rpc call finished"}
// ErrRoute happens when router fail to route this call
ErrRoute = &basicError{"rpc route failed"}
// ErrPayloadValidation happens when payload validation failed
ErrPayloadValidation = &basicError{"payload validation error"}
)

// More detailed error types
Expand Down
110 changes: 102 additions & 8 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ import (
"fmt"
"sync/atomic"

"github.com/cloudwego/netpoll"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
netpolltrans "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand Down Expand Up @@ -59,21 +62,45 @@ var (
func NewDefaultCodec() remote.Codec {
// No size limit by default
return &defaultCodec{
maxSize: 0,
CodecConfig{MaxSize: 0},
}
}

// NewDefaultCodecWithSizeLimit creates the default protocol sniffing codec supporting thrift and protobuf but with size limit.
// maxSize is in bytes
func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec {
return &defaultCodec{
maxSize: maxSize,
CodecConfig{MaxSize: maxSize},
}
}

type defaultCodec struct {
// NewDefaultCodecWithConfig creates the default protocol sniffing codec supporting thrift and protobuf with the input config.
func NewDefaultCodecWithConfig(cfg CodecConfig) remote.Codec {
if cfg.CRC32Check {
// TODO: crc32 has higher priority now.
cfg.PayloadValidator = NewCRC32PayloadValidator()
}
return &defaultCodec{cfg}
}

// CodecConfig is the config of defaultCodec
type CodecConfig struct {
// maxSize limits the max size of the payload
maxSize int
MaxSize int

// If crc32Check is true, the codec will validate the payload using crc32c.
// Only effective when transport is TTHeader.
// Payload is all the data after TTHeader.
CRC32Check bool

// PayloadValidator is used to validate payload with customized checksum logic.
// It prepares a value based on payload in sender-side and validates the value in receiver-side.
// It can only be used when ttheader is enabled.
PayloadValidator PayloadValidator
}

type defaultCodec struct {
CodecConfig
}

// EncodePayload encode payload
Expand Down Expand Up @@ -111,23 +138,27 @@ func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message
return perrors.NewProtocolErrorWithMsg("no buffer allocated for the framed length field")
}
payloadLen = out.MallocLen() - headerLen
// FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `framedLenField` may not take effect.
binary.BigEndian.PutUint32(framedLenField, uint32(payloadLen))
} else if message.ProtocolInfo().CodecType == serviceinfo.Protobuf {
return perrors.NewProtocolErrorWithMsg("protobuf just support 'framed' trans proto")
}
if tp&transport.TTHeader == transport.TTHeader {
payloadLen = out.MallocLen() - Size32
}
err = checkPayloadSize(payloadLen, c.maxSize)
err = checkPayloadSize(payloadLen, c.MaxSize)
return err
}

// EncodeMetaAndPayload encode meta and payload
func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
var err error
var totalLenField []byte
tp := message.ProtocolInfo().TransProto
if c.PayloadValidator != nil && tp&transport.TTHeader == transport.TTHeader {
return c.encodeMetaAndPayloadWithPayloadValidator(ctx, message, out, me)
}

var err error
var totalLenField []byte
// 1. encode header and return totalLenField if needed
// totalLenField will be filled after payload encoded
if tp&transport.TTHeader == transport.TTHeader {
Expand All @@ -144,6 +175,7 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.
if totalLenField == nil {
return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field")
}
// FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `totalLenField` may not take effect.
payloadLen := out.MallocLen() - Size32
binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen))
}
Expand Down Expand Up @@ -176,6 +208,11 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
if flagBuf, err = in.Peek(2 * Size32); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if c.PayloadValidator != nil {
if pErr := payloadChecksumValidate(ctx, c.PayloadValidator, in, message); pErr != nil {
return pErr
}
}
} else if isMeshHeader(flagBuf) {
message.Tags()[remote.MeshHeader] = true
// MeshHeader
Expand All @@ -186,7 +223,7 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("meshHeader read payload first 8 byte failed: %s", err.Error()))
}
}
return checkPayload(flagBuf, message, in, isTTHeader, c.maxSize)
return checkPayload(flagBuf, message, in, isTTHeader, c.MaxSize)
}

// DecodePayload decode payload
Expand Down Expand Up @@ -229,6 +266,55 @@ func (c *defaultCodec) Name() string {
return "default"
}

// encodeMetaAndPayloadWithPayloadValidator encodes payload and meta with checksum of the payload.
func (c *defaultCodec) encodeMetaAndPayloadWithPayloadValidator(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) (err error) {
writer := netpoll.NewLinkBuffer()
payloadOut := netpolltrans.NewWriterByteBuffer(writer)
defer func() {
payloadOut.Release(err)
}()

// 1. encode payload and calculate value via payload validator
if err = me.EncodePayload(ctx, message, payloadOut); err != nil {
return err
}
// get the payload from buffer
// use copy api here because the payload will be used as an argument of Generate function in validator
payload, err := getWrittenBytes(writer)
if err != nil {
return err
}
if c.PayloadValidator != nil {
if err = payloadChecksumGenerate(ctx, c.PayloadValidator, payload, message); err != nil {
return err
}
}
// set payload length before encode TTHeader
message.SetPayloadLen(len(payload))

// 2. encode header and return totalLenField if needed
totalLenField, err := ttHeaderCodec.encode(ctx, message, out)
if err != nil {
return err
}

// 3. write payload to the buffer after TTHeader
if ncWriter, ok := out.(remote.NocopyWrite); ok {
err = ncWriter.WriteDirect(payload, 0)
} else {
_, err = out.WriteBinary(payload)
}

// 4. fill totalLen field for header if needed
// FIXME: if the `out` buffer using copy to grow when the capacity is not enough, setting the pre-allocated `totalLenField` may not take effect.
if totalLenField == nil {
return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field")
}
payloadLen := out.MallocLen() - Size32
binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen))
return err
}

// Select to use thrift or protobuf according to the protocol.
func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
pCodec, err := remote.GetPayloadCodec(message)
Expand Down Expand Up @@ -373,3 +459,11 @@ func checkPayloadSize(payloadLen, maxSize int) error {
}
return nil
}

// getWrittenBytes gets all written bytes from linkbuffer.
func getWrittenBytes(lb *netpoll.LinkBuffer) (buf []byte, err error) {
if err = lb.Flush(); err != nil {
return nil, err
}
return lb.Bytes(), nil
}
Loading

0 comments on commit ebd3dfe

Please sign in to comment.