diff --git a/stats/opentelemetry/client_metrics.go b/stats/opentelemetry/client_metrics.go index 4fffba60fb33..c3fe81b5e3e0 100644 --- a/stats/opentelemetry/client_metrics.go +++ b/stats/opentelemetry/client_metrics.go @@ -21,7 +21,9 @@ import ( "sync/atomic" "time" + otelattribute "go.opentelemetry.io/otel/attribute" otelcodes "go.opentelemetry.io/otel/codes" + otelmetric "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/trace" "google.golang.org/grpc" grpccodes "google.golang.org/grpc/codes" @@ -30,9 +32,6 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/stats" "google.golang.org/grpc/status" - - otelattribute "go.opentelemetry.io/otel/attribute" - otelmetric "go.opentelemetry.io/otel/metric" ) type clientStatsHandler struct { @@ -41,6 +40,14 @@ type clientStatsHandler struct { clientMetrics clientMetrics } +type clientMetricsStatsHandler struct { + *clientStatsHandler +} + +type clientTracingStatsHandler struct { + *clientStatsHandler +} + func (h *clientStatsHandler) initializeMetrics() { // Will set no metrics to record, logically making this stats handler a // no-op. @@ -71,7 +78,7 @@ func (h *clientStatsHandler) initializeMetrics() { rm.registerMetrics(metrics, meter) } -func (h *clientStatsHandler) unaryInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { +func (h *clientMetricsStatsHandler) unaryInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ci := &callInfo{ target: cc.CanonicalTarget(), method: h.determineMethod(method, opts...), @@ -88,12 +95,8 @@ func (h *clientStatsHandler) unaryInterceptor(ctx context.Context, method string } startTime := time.Now() - var span trace.Span - if h.options.isTracingEnabled() { - ctx, span = h.createCallTraceSpan(ctx, method) - } err := invoker(ctx, method, req, reply, cc, opts...) - h.perCallTracesAndMetrics(ctx, err, startTime, ci, span) + h.perCallMetrics(ctx, err, startTime, ci) return err } @@ -109,7 +112,7 @@ func (h *clientStatsHandler) determineMethod(method string, opts ...grpc.CallOpt return "other" } -func (h *clientStatsHandler) streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { +func (h *clientMetricsStatsHandler) streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { ci := &callInfo{ target: cc.CanonicalTarget(), method: h.determineMethod(method, opts...), @@ -126,37 +129,63 @@ func (h *clientStatsHandler) streamInterceptor(ctx context.Context, desc *grpc.S } startTime := time.Now() + callback := func(err error) { + h.perCallMetrics(ctx, err, startTime, ci) + } + opts = append([]grpc.CallOption{grpc.OnFinish(callback)}, opts...) + return streamer(ctx, desc, cc, method, opts...) +} + +// perCallMetrics records per call metrics. +func (h *clientMetricsStatsHandler) perCallMetrics(ctx context.Context, err error, startTime time.Time, ci *callInfo) { + callLatency := float64(time.Since(startTime)) / float64(time.Second) + attrs := otelmetric.WithAttributeSet(otelattribute.NewSet( + otelattribute.String("grpc.method", ci.method), + otelattribute.String("grpc.target", ci.target), + otelattribute.String("grpc.status", canonicalString(status.Code(err))), + )) + h.clientMetrics.callDuration.Record(ctx, callLatency, attrs) +} + +func (h *clientTracingStatsHandler) unaryInterceptor(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ci := &callInfo{ + target: cc.CanonicalTarget(), + method: h.determineMethod(method, opts...), + } + ctx = setCallInfo(ctx, ci) + var span trace.Span - if h.options.isTracingEnabled() { - ctx, span = h.createCallTraceSpan(ctx, method) + ctx, span = h.createCallTraceSpan(ctx, method) + err := invoker(ctx, method, req, reply, cc, opts...) + h.perCallTraces(err, span) + return err +} + +func (h *clientTracingStatsHandler) streamInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ci := &callInfo{ + target: cc.CanonicalTarget(), + method: h.determineMethod(method, opts...), } + ctx = setCallInfo(ctx, ci) + + var span trace.Span + ctx, span = h.createCallTraceSpan(ctx, method) callback := func(err error) { - h.perCallTracesAndMetrics(ctx, err, startTime, ci, span) + h.perCallTraces(err, span) } opts = append([]grpc.CallOption{grpc.OnFinish(callback)}, opts...) return streamer(ctx, desc, cc, method, opts...) } -// perCallTracesAndMetrics records per call trace spans and metrics. -func (h *clientStatsHandler) perCallTracesAndMetrics(ctx context.Context, err error, startTime time.Time, ci *callInfo, ts trace.Span) { - if h.options.isTracingEnabled() { - s := status.Convert(err) - if s.Code() == grpccodes.OK { - ts.SetStatus(otelcodes.Ok, s.Message()) - } else { - ts.SetStatus(otelcodes.Error, s.Message()) - } - ts.End() - } - if h.options.isMetricsEnabled() { - callLatency := float64(time.Since(startTime)) / float64(time.Second) - attrs := otelmetric.WithAttributeSet(otelattribute.NewSet( - otelattribute.String("grpc.method", ci.method), - otelattribute.String("grpc.target", ci.target), - otelattribute.String("grpc.status", canonicalString(status.Code(err))), - )) - h.clientMetrics.callDuration.Record(ctx, callLatency, attrs) +// perCallTraces records per call trace spans. +func (h *clientTracingStatsHandler) perCallTraces(err error, ts trace.Span) { + s := status.Convert(err) + if s.Code() == grpccodes.OK { + ts.SetStatus(otelcodes.Ok, s.Message()) + } else { + ts.SetStatus(otelcodes.Error, s.Message()) } + ts.End() } // TagConn exists to satisfy stats.Handler. @@ -198,6 +227,7 @@ func (h *clientStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) }) } +// HandleRPC implements per RPC tracing and stats implementation. func (h *clientStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { ri := getRPCInfo(ctx) if ri == nil { diff --git a/stats/opentelemetry/opentelemetry.go b/stats/opentelemetry/opentelemetry.go index d99169e2da67..2ad1451981c4 100644 --- a/stats/opentelemetry/opentelemetry.go +++ b/stats/opentelemetry/opentelemetry.go @@ -120,7 +120,17 @@ type MetricsOptions struct { func DialOption(o Options) grpc.DialOption { csh := &clientStatsHandler{options: o} csh.initializeMetrics() - return joinDialOptions(grpc.WithChainUnaryInterceptor(csh.unaryInterceptor), grpc.WithChainStreamInterceptor(csh.streamInterceptor), grpc.WithStatsHandler(csh)) + var interceptors []grpc.DialOption + if o.isMetricsEnabled() { + metricsHandler := &clientMetricsStatsHandler{clientStatsHandler: csh} + interceptors = append(interceptors, grpc.WithChainUnaryInterceptor(metricsHandler.unaryInterceptor), grpc.WithChainStreamInterceptor(metricsHandler.streamInterceptor)) + } + if o.isTracingEnabled() { + tracingHandler := &clientTracingStatsHandler{clientStatsHandler: csh} + interceptors = append(interceptors, grpc.WithChainUnaryInterceptor(tracingHandler.unaryInterceptor), grpc.WithChainStreamInterceptor(tracingHandler.streamInterceptor)) + } + interceptors = append(interceptors, grpc.WithStatsHandler(csh)) + return joinDialOptions(interceptors...) } var joinServerOptions = internal.JoinServerOptions.(func(...grpc.ServerOption) grpc.ServerOption) @@ -140,7 +150,17 @@ var joinServerOptions = internal.JoinServerOptions.(func(...grpc.ServerOption) g func ServerOption(o Options) grpc.ServerOption { ssh := &serverStatsHandler{options: o} ssh.initializeMetrics() - return joinServerOptions(grpc.ChainUnaryInterceptor(ssh.unaryInterceptor), grpc.ChainStreamInterceptor(ssh.streamInterceptor), grpc.StatsHandler(ssh)) + var interceptors []grpc.ServerOption + if o.isMetricsEnabled() { + metricsHandler := &serverMetricsStatsHandler{serverStatsHandler: ssh} + interceptors = append(interceptors, grpc.ChainUnaryInterceptor(metricsHandler.unaryInterceptor), grpc.ChainStreamInterceptor(metricsHandler.streamInterceptor)) + } + if o.isTracingEnabled() { + tracingHandler := &serverTracingStatsHandler{serverStatsHandler: ssh} + interceptors = append(interceptors, grpc.ChainUnaryInterceptor(tracingHandler.unaryInterceptor), grpc.ChainStreamInterceptor(tracingHandler.streamInterceptor)) + } + interceptors = append(interceptors, grpc.StatsHandler(ssh)) + return joinServerOptions(interceptors...) } // callInfo is information pertaining to the lifespan of the RPC client side. diff --git a/stats/opentelemetry/server_metrics.go b/stats/opentelemetry/server_metrics.go index da3f60a9ebe5..cd50f5243ec3 100644 --- a/stats/opentelemetry/server_metrics.go +++ b/stats/opentelemetry/server_metrics.go @@ -38,6 +38,14 @@ type serverStatsHandler struct { serverMetrics serverMetrics } +type serverMetricsStatsHandler struct { + *serverStatsHandler +} + +type serverTracingStatsHandler struct { + *serverStatsHandler +} + func (h *serverStatsHandler) initializeMetrics() { // Will set no metrics to record, logically making this stats handler a // no-op. @@ -211,24 +219,41 @@ func (h *serverStatsHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) // HandleRPC implements per RPC tracing and stats implementation. func (h *serverStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { + if h.options.isMetricsEnabled() { + metricsHandler := &serverMetricsStatsHandler{serverStatsHandler: h} + metricsHandler.HandleRPC(ctx, rs) + } + if h.options.isTracingEnabled() { + tracingHandler := &serverTracingStatsHandler{serverStatsHandler: h} + tracingHandler.HandleRPC(ctx, rs) + } +} + +// HandleRPC implements per RPC stats handling for metrics. +func (h *serverMetricsStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { ri := getRPCInfo(ctx) if ri == nil { - logger.Error("ctx passed into server side stats handler metrics event handling has no server call data present") + logger.Error("ctx passed into server metrics stats handler metrics event handling has no server call data present") return } - if h.options.isTracingEnabled() { - populateSpan(rs, ri.ai) - } - if h.options.isMetricsEnabled() { - h.processRPCData(ctx, rs, ri.ai) + h.processRPCData(ctx, rs, ri.ai) +} + +// HandleRPC implements per RPC tracing handling for tracing. +func (h *serverTracingStatsHandler) HandleRPC(ctx context.Context, rs stats.RPCStats) { + ri := getRPCInfo(ctx) + if ri == nil { + logger.Error("ctx passed into server tracing stats handler tracing event handling has no server call data present") + return } + populateSpan(rs, ri.ai) } -func (h *serverStatsHandler) processRPCData(ctx context.Context, s stats.RPCStats, ai *attemptInfo) { +func (h *serverMetricsStatsHandler) processRPCData(ctx context.Context, s stats.RPCStats, ai *attemptInfo) { switch st := s.(type) { case *stats.InHeader: - if ai.pluginOptionLabels == nil && h.options.MetricsOptions.pluginOption != nil { - labels := h.options.MetricsOptions.pluginOption.GetLabels(st.Header) + if ai.pluginOptionLabels == nil && h.serverStatsHandler.options.MetricsOptions.pluginOption != nil { + labels := h.serverStatsHandler.options.MetricsOptions.pluginOption.GetLabels(st.Header) if labels == nil { labels = map[string]string{} // Shouldn't return a nil map. Make it empty if so to ignore future Get Calls for this Attempt. } @@ -237,7 +262,7 @@ func (h *serverStatsHandler) processRPCData(ctx context.Context, s stats.RPCStat attrs := otelmetric.WithAttributeSet(otelattribute.NewSet( otelattribute.String("grpc.method", ai.method), )) - h.serverMetrics.callStarted.Add(ctx, 1, attrs) + h.serverStatsHandler.serverMetrics.callStarted.Add(ctx, 1, attrs) case *stats.OutPayload: atomic.AddInt64(&ai.sentCompressedBytes, int64(st.CompressedLength)) case *stats.InPayload: