Skip to content

Commit

Permalink
*: fix several bugs in QueryRegion (#9055)
Browse files Browse the repository at this point in the history
ref #8690

Fix several bugs in QueryRegion, including:

- Perform a deep copy of the region results from a gRPC response if used multiple times.
- Store the `QueryRegion` stream when it is not nil.

Signed-off-by: JmPotato <[email protected]>
  • Loading branch information
JmPotato authored Feb 14, 2025
1 parent e893032 commit b7977dd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
52 changes: 48 additions & 4 deletions client/clients/router/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"sync/atomic"
"time"

"github.com/gogo/protobuf/proto"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"google.golang.org/grpc"
Expand Down Expand Up @@ -78,6 +79,28 @@ func ConvertToRegion(res regionResponse) *Region {
return r
}

// convertToRegionCopy converts and deep-copies the region response to a new region.
func convertToRegionCopy(res regionResponse) *Region {
region := res.GetRegion()
if region == nil {
return nil
}

r := &Region{
Meta: proto.Clone(region).(*metapb.Region),
Leader: proto.Clone(res.GetLeader()).(*metapb.Peer),
Buckets: proto.Clone(res.GetBuckets()).(*metapb.Buckets),
}
for _, s := range res.GetDownPeers() {
r.DownPeers = append(r.DownPeers, proto.Clone(s.Peer).(*metapb.Peer))
}
for _, p := range res.GetPendingPeers() {
r.PendingPeers = append(r.PendingPeers, proto.Clone(p).(*metapb.Peer))
}

return r
}

// KeyRange defines a range of keys in bytes.
type KeyRange struct {
StartKey []byte
Expand Down Expand Up @@ -204,13 +227,24 @@ func (c *Cli) newRequest(ctx context.Context) *Request {
req := c.reqPool.Get().(*Request)
req.requestCtx = ctx
req.clientCtx = c.ctx
// Reset the request fields before using it.
req.key = nil
req.prevKey = nil
req.id = 0
req.needBuckets = false
req.region = nil
// Initialize the runtime fields.
req.pool = c.reqPool

return req
}

func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request] {
var keyIdx, prevKeyIdx int
var (
keyIdx, prevKeyIdx int
// regionUsed is used to record whether the region has been used.
regionUsed = make(map[uint64]struct{})
)
return func(_ int, req *Request, err error) {
requestCtx := req.requestCtx
defer trace.StartRegion(requestCtx, "pdclient.regionReqDone").End()
Expand All @@ -230,8 +264,15 @@ func requestFinisher(resp *pdpb.QueryRegionResponse) batch.FinisherFunc[*Request
} else if req.id != 0 {
id = req.id
}
if region, ok := resp.RegionsById[id]; ok {
req.region = ConvertToRegion(region)
if regionResp, ok := resp.RegionsById[id]; ok {
// Since the region results may be modified by the requester,
// we need to ensure each region result returned is unique.
if _, used := regionUsed[id]; used {
req.region = convertToRegionCopy(regionResp)
} else {
req.region = ConvertToRegion(regionResp)
regionUsed[id] = struct{}{}
}
}
req.tryDone(err)
}
Expand Down Expand Up @@ -339,7 +380,10 @@ func (c *Cli) updateConnection(ctx context.Context) {
if err != nil {
log.Error("[router] failed to create the router stream connection", errs.ZapError(err))
}
c.conCtxMgr.Store(ctx, url, stream)
// Store the stream connection context if it is successfully created.
if stream != nil {
c.conCtxMgr.Store(ctx, url, stream)
}
// TODO: support the forwarding mechanism for the router client.
// TODO: support sending the router requests to the follower nodes.
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/core/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ func (r *RegionsInfo) QueryRegions(
panic("returned prev regions count mismatch with the input keys")
}
// Build the key -> ID map for the final results.
regionsByID := make(map[uint64]*pdpb.RegionResponse, len(regions))
regionsByID := make(map[uint64]*pdpb.RegionResponse, len(regions)+len(prevRegions)+len(ids))
keyIDMap := sortOutKeyIDMap(regionsByID, regions, needBuckets)
prevKeyIDMap := sortOutKeyIDMap(regionsByID, prevRegions, needBuckets)
// Iterate the region IDs to find the regions.
Expand Down

0 comments on commit b7977dd

Please sign in to comment.