Skip to content

Commit a02aa88

Browse files
Move RPC byte-counting logic to separate gRPC stream wrappers
1 parent 262ae36 commit a02aa88

File tree

7 files changed

+286
-25
lines changed

7 files changed

+286
-25
lines changed

enterprise/server/byte_stream_server_proxy/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ go_library(
1515
"//server/real_environment",
1616
"//server/remote_cache/digest",
1717
"//server/util/authutil",
18+
"//server/util/grpc_stream",
1819
"//server/util/log",
1920
"//server/util/status",
2021
"//server/util/tracing",

enterprise/server/byte_stream_server_proxy/byte_stream_server_proxy.go

+12-17
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/buildbuddy-io/buildbuddy/server/real_environment"
1313
"github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest"
1414
"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
15+
"github.com/buildbuddy-io/buildbuddy/server/util/grpc_stream"
1516
"github.com/buildbuddy-io/buildbuddy/server/util/log"
1617
"github.com/buildbuddy-io/buildbuddy/server/util/status"
1718
"github.com/buildbuddy-io/buildbuddy/server/util/tracing"
@@ -75,7 +76,7 @@ func (s *ByteStreamServerProxy) Read(req *bspb.ReadRequest, stream bspb.ByteStre
7576
return err
7677
}
7778

78-
localReadStream, err := s.local.Read(ctx, req)
79+
localGRPCReadStream, err := s.local.Read(ctx, req)
7980
if err != nil {
8081
if skipRemote {
8182
recordReadMetrics(metrics.MissStatusLabel, requestTypeLabel, err, 0)
@@ -96,12 +97,9 @@ func (s *ByteStreamServerProxy) Read(req *bspb.ReadRequest, stream bspb.ByteStre
9697
}
9798

9899
responseSent := false
99-
bytesRead := 0
100+
localReadStream := grpc_stream.NewByteCountingServerStream(localGRPCReadStream)
100101
for {
101102
rsp, err := localReadStream.Recv()
102-
if rsp != nil {
103-
bytesRead += len(rsp.Data)
104-
}
105103
if err == io.EOF {
106104
break
107105
}
@@ -112,10 +110,10 @@ func (s *ByteStreamServerProxy) Read(req *bspb.ReadRequest, stream bspb.ByteStre
112110
// the remote cache, but keep it simple for now.
113111
if responseSent {
114112
log.CtxInfof(ctx, "error midstream of local read: %s", err)
115-
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, err, bytesRead)
113+
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, err, int(localReadStream.GetByteCount()))
116114
return err
117115
} else if skipRemote {
118-
recordReadMetrics(metrics.MissStatusLabel, requestTypeLabel, err, bytesRead)
116+
recordReadMetrics(metrics.MissStatusLabel, requestTypeLabel, err, int(localReadStream.GetByteCount()))
119117
return err
120118
} else {
121119
// Fall back to reading remotely if the local read fails.
@@ -126,12 +124,12 @@ func (s *ByteStreamServerProxy) Read(req *bspb.ReadRequest, stream bspb.ByteStre
126124
}
127125

128126
if err := stream.Send(rsp); err != nil {
129-
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, err, bytesRead)
127+
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, err, int(localReadStream.GetByteCount()))
130128
return err
131129
}
132130
responseSent = true
133131
}
134-
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, nil, bytesRead)
132+
recordReadMetrics(metrics.HitStatusLabel, requestTypeLabel, nil, int(localReadStream.GetByteCount()))
135133
return nil
136134
}
137135

@@ -213,11 +211,12 @@ func (s *ByteStreamServerProxy) readRemote(req *bspb.ReadRequest, stream bspb.By
213211
ctx, spn := tracing.StartSpan(stream.Context())
214212
defer spn.End()
215213

216-
remoteReadStream, err := s.remote.Read(ctx, req)
214+
remoteGRPCReadStream, err := s.remote.Read(ctx, req)
217215
if err != nil {
218216
log.CtxInfof(ctx, "error reading from remote: %s", err)
219217
return 0, err
220218
}
219+
remoteReadStream := grpc_stream.NewByteCountingServerStream(remoteGRPCReadStream)
221220

222221
var localWriteStream localWriter = &discardingLocalWriter{}
223222
if req.ReadOffset == 0 && !authutil.EncryptionEnabled(ctx, s.authenticator) {
@@ -235,12 +234,8 @@ func (s *ByteStreamServerProxy) readRemote(req *bspb.ReadRequest, stream bspb.By
235234
}
236235
}
237236

238-
bytesRead := 0
239237
for {
240238
rsp, err := remoteReadStream.Recv()
241-
if rsp != nil {
242-
bytesRead += len(rsp.Data)
243-
}
244239
if err != nil {
245240
if err == io.EOF {
246241
if err := localWriteStream.commit(); err != nil {
@@ -249,18 +244,18 @@ func (s *ByteStreamServerProxy) readRemote(req *bspb.ReadRequest, stream bspb.By
249244
break
250245
}
251246
log.CtxInfof(ctx, "error streaming from remote for read through: %s", err)
252-
return bytesRead, err
247+
return int(remoteReadStream.GetByteCount()), err
253248
}
254249

255250
if err := localWriteStream.send(rsp.Data); err != nil {
256251
log.CtxInfof(ctx, "Error writing locally for read through: %s", err)
257252
localWriteStream = &discardingLocalWriter{}
258253
}
259254
if err = stream.Send(rsp); err != nil {
260-
return bytesRead, err
255+
return int(remoteReadStream.GetByteCount()), err
261256
}
262257
}
263-
return bytesRead, nil
258+
return int(remoteReadStream.GetByteCount()), nil
264259
}
265260

266261
func (s *ByteStreamServerProxy) Write(stream bspb.ByteStream_WriteServer) error {

enterprise/server/content_addressable_storage_server_proxy/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ go_library(
1616
"//server/real_environment",
1717
"//server/remote_cache/digest",
1818
"//server/util/authutil",
19+
"//server/util/grpc_stream",
1920
"//server/util/log",
2021
"//server/util/proto",
2122
"//server/util/rpcutil",

enterprise/server/content_addressable_storage_server_proxy/content_addressable_storage_server_proxy.go

+7-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/buildbuddy-io/buildbuddy/server/real_environment"
1515
"github.com/buildbuddy-io/buildbuddy/server/remote_cache/digest"
1616
"github.com/buildbuddy-io/buildbuddy/server/util/authutil"
17+
"github.com/buildbuddy-io/buildbuddy/server/util/grpc_stream"
1718
"github.com/buildbuddy-io/buildbuddy/server/util/log"
1819
"github.com/buildbuddy-io/buildbuddy/server/util/proto"
1920
"github.com/buildbuddy-io/buildbuddy/server/util/rpcutil"
@@ -292,7 +293,6 @@ func (s *CASServerProxy) batchReadBlobsRemote(ctx context.Context, readReq *repb
292293
return readResp, nil
293294
}
294295

295-
// TODO(iain): record per-byte metrics here as well as above.
296296
func (s *CASServerProxy) GetTree(req *repb.GetTreeRequest, stream repb.ContentAddressableStorage_GetTreeServer) error {
297297
if proxy_util.SkipRemote(stream.Context()) {
298298
return status.UnimplementedError("Skip remote not implemented")
@@ -306,18 +306,18 @@ func (s *CASServerProxy) GetTree(req *repb.GetTreeRequest, stream repb.ContentAd
306306

307307
func (s *CASServerProxy) getTreeWithoutCaching(req *repb.GetTreeRequest, stream repb.ContentAddressableStorage_GetTreeServer) error {
308308
digests := 0
309-
bytes := 0
309+
remoteGRPCStream, err := s.remote.GetTree(stream.Context(), req)
310+
if err != nil {
311+
return err
312+
}
313+
remoteStream := grpc_stream.NewByteCountingServerStream(remoteGRPCStream)
310314
defer func() {
311315
recordMetrics(
312316
"GetTree",
313317
metrics.MissStatusLabel,
314318
map[string]int{metrics.MissStatusLabel: digests},
315-
map[string]int{metrics.MissStatusLabel: bytes})
319+
map[string]int{metrics.MissStatusLabel: int(remoteStream.GetByteCount())})
316320
}()
317-
remoteStream, err := s.remote.GetTree(stream.Context(), req)
318-
if err != nil {
319-
return err
320-
}
321321
for {
322322
rsp, err := remoteStream.Recv()
323323
if err == io.EOF {
@@ -330,7 +330,6 @@ func (s *CASServerProxy) getTreeWithoutCaching(req *repb.GetTreeRequest, stream
330330
digests += len(dir.GetFiles())
331331
digests += len(dir.GetDirectories())
332332
}
333-
bytes += proto.Size(rsp)
334333
if err = stream.Send(rsp); err != nil {
335334
return err
336335
}

server/util/grpc_stream/BUILD

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
2+
3+
go_library(
4+
name = "grpc_stream",
5+
srcs = ["grpc_stream.go"],
6+
importpath = "github.com/buildbuddy-io/buildbuddy/server/util/grpc_stream",
7+
visibility = ["//visibility:public"],
8+
deps = [
9+
"//server/util/proto",
10+
"@org_golang_google_grpc//:grpc",
11+
"@org_golang_google_grpc//metadata",
12+
"@org_golang_google_protobuf//reflect/protoreflect",
13+
],
14+
)
15+
16+
go_test(
17+
name = "grpc_stream_test",
18+
srcs = ["grpc_stream_test.go"],
19+
deps = [
20+
":grpc_stream",
21+
"//server/environment",
22+
"//server/testutil/testenv",
23+
"//server/testutil/testport",
24+
"//server/util/grpc_client",
25+
"//server/util/grpc_server",
26+
"//server/util/proto",
27+
"@com_github_stretchr_testify//require",
28+
"@org_golang_google_genproto_googleapis_bytestream//:bytestream",
29+
],
30+
)
+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package grpc_stream
2+
3+
import (
4+
"context"
5+
"sync/atomic"
6+
7+
"github.com/buildbuddy-io/buildbuddy/server/util/proto"
8+
"google.golang.org/grpc"
9+
"google.golang.org/grpc/metadata"
10+
"google.golang.org/protobuf/reflect/protoreflect"
11+
)
12+
13+
// gRPC has three types of streaming RPCs, as described here:
14+
// https://grpc.io/docs/what-is-grpc/core-concepts/#server-streaming-rpc
15+
// - Server streaming (client sends one message, receives many)
16+
// - Client treaming (client sends many messages, receives one)
17+
// - Bidirectional streaming (client and server send many message)
18+
//
19+
// This generic interface represents the server streaming RPC.
20+
type ServerStream[T protoreflect.ProtoMessage] interface {
21+
Recv() (T, error)
22+
grpc.ClientStream
23+
}
24+
25+
// A wrapper around a ServerStream that counts the number of bytes returned by
26+
// the server over the lifetime of the stream.
27+
type ByteCountingServerStream[T protoreflect.ProtoMessage] struct {
28+
bytes *atomic.Int64
29+
proxy ServerStream[T]
30+
}
31+
32+
func NewByteCountingServerStream[T protoreflect.ProtoMessage](proxy ServerStream[T]) ByteCountingServerStream[T] {
33+
bytes := &atomic.Int64{}
34+
return ByteCountingServerStream[T]{
35+
bytes: bytes,
36+
proxy: proxy,
37+
}
38+
}
39+
40+
// Returns the number of bytes received on the stream so far.
41+
func (s *ByteCountingServerStream[T]) GetByteCount() int64 {
42+
return s.bytes.Load()
43+
}
44+
45+
func (s *ByteCountingServerStream[T]) Recv() (T, error) {
46+
message, err := s.proxy.Recv()
47+
s.bytes.Add(int64(proto.Size(message)))
48+
return message, err
49+
}
50+
51+
func (s *ByteCountingServerStream[T]) Header() (metadata.MD, error) {
52+
return s.proxy.Header()
53+
}
54+
55+
func (s *ByteCountingServerStream[T]) Trailer() metadata.MD {
56+
return s.proxy.Trailer()
57+
}
58+
59+
func (s *ByteCountingServerStream[T]) CloseSend() error {
60+
return s.proxy.CloseSend()
61+
}
62+
63+
func (s *ByteCountingServerStream[T]) Context() context.Context {
64+
return s.proxy.Context()
65+
}
66+
67+
func (s *ByteCountingServerStream[T]) SendMsg(message any) error {
68+
return s.proxy.SendMsg(message)
69+
}
70+
71+
func (s *ByteCountingServerStream[T]) RecvMsg(message any) error {
72+
return s.proxy.RecvMsg(message)
73+
}
74+
75+
// This generic interface represents the client streaming RPC.
76+
type ClientStream[S, T protoreflect.ProtoMessage] interface {
77+
Send(S) error
78+
CloseAndRecv() (T, error)
79+
grpc.ClientStream
80+
}
81+
82+
// A wrapper around a ClientStream that counts the number of bytes sent by the
83+
// client over the lifetime of the stream.
84+
type ByteCountingClientStream[S, T protoreflect.ProtoMessage] struct {
85+
bytes *atomic.Int64
86+
proxy ClientStream[S, T]
87+
}
88+
89+
func NewByteCountingClientStream[S, T protoreflect.ProtoMessage](proxy ClientStream[S, T]) ByteCountingClientStream[S, T] {
90+
bytes := &atomic.Int64{}
91+
return ByteCountingClientStream[S, T]{
92+
bytes: bytes,
93+
proxy: proxy,
94+
}
95+
}
96+
97+
// Returns the number of bytes sent to the stream so far.
98+
func (s *ByteCountingClientStream[S, T]) GetByteCount() int64 {
99+
return s.bytes.Load()
100+
}
101+
102+
func (s *ByteCountingClientStream[S, T]) Send(message S) error {
103+
s.bytes.Add(int64(proto.Size(message)))
104+
return s.proxy.Send(message)
105+
}
106+
107+
func (s *ByteCountingClientStream[S, T]) CloseAndRecv() (T, error) {
108+
return s.proxy.CloseAndRecv()
109+
}
110+
111+
func (s *ByteCountingClientStream[S, T]) Header() (metadata.MD, error) {
112+
return s.proxy.Header()
113+
}
114+
115+
func (s *ByteCountingClientStream[S, T]) Trailer() metadata.MD {
116+
return s.proxy.Trailer()
117+
}
118+
119+
func (s *ByteCountingClientStream[S, T]) CloseSend() error {
120+
return s.proxy.CloseSend()
121+
}
122+
123+
func (s *ByteCountingClientStream[S, T]) Context() context.Context {
124+
return s.proxy.Context()
125+
}
126+
127+
func (s *ByteCountingClientStream[S, T]) SendMsg(message any) error {
128+
return s.proxy.SendMsg(message)
129+
}
130+
131+
func (s *ByteCountingClientStream[S, T]) RecvMsg(message any) error {
132+
return s.proxy.RecvMsg(message)
133+
}

0 commit comments

Comments
 (0)