diff --git a/internal/backoff/backoff_test.go b/internal/backoff/backoff_test.go index 81cf535318..8890ef406c 100644 --- a/internal/backoff/backoff_test.go +++ b/internal/backoff/backoff_test.go @@ -60,7 +60,7 @@ func TestNew(t *testing.T) { jitterLimit: float64(time.Minute), backoffFactor: 1.1, maxRetryCount: 50, - errLog: true, + errLog: false, durationLimit: float64(time.Hour) / 1.1, }, checkFunc: func(got *backoff, want *backoff) error { diff --git a/internal/backoff/option_test.go b/internal/backoff/option_test.go index c80df571ed..5a1f8af951 100644 --- a/internal/backoff/option_test.go +++ b/internal/backoff/option_test.go @@ -374,10 +374,6 @@ func TestDefaultOptions(t *testing.T) { return errors.New("invalid param (maxRetryCount) was set") } - if got.errLog != true { - return errors.New("invalid param (errLog) was set") - } - return nil }, }, diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 8b5a97d712..6bc26c3589 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -996,13 +996,11 @@ func TestGetActualValue(t *testing.T) { func() test { fname := "version" return test{ - name: "return file contents when val is file://env", + name: "return empty when not exists file contents", args: args{ val: "file://" + fname, }, - want: want{ - wantRes: "file://" + fname, - }, + want: want{}, } }(), } diff --git a/internal/config/grpc.go b/internal/config/grpc.go index 4279bdc001..431acbe86c 100644 --- a/internal/config/grpc.go +++ b/internal/config/grpc.go @@ -183,7 +183,7 @@ func (g *GRPCClient) Opts() ([]grpc.Option, error) { if g.ConnectionPool != nil { opts = append(opts, grpc.WithConnectionPoolSize(g.ConnectionPool.Size), - grpc.WithOldConnCloseDuration(g.ConnectionPool.OldConnCloseDuration), + grpc.WithOldConnCloseDelay(g.ConnectionPool.OldConnCloseDuration), grpc.WithResolveDNS(g.ConnectionPool.ResolveDNS), ) if g.ConnectionPool.EnableRebalance { diff --git a/internal/db/nosql/cassandra/option_test.go b/internal/db/nosql/cassandra/option_test.go index be7072d1f1..fb6cb39858 100644 --- a/internal/db/nosql/cassandra/option_test.go +++ b/internal/db/nosql/cassandra/option_test.go @@ -535,7 +535,7 @@ func TestWithConnectTimeout(t *testing.T) { dur: "dummy", }, want: want{ - err: errors.NewErrCriticalOption("connectTimeout", "dummy", errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy"))), + err: errors.NewErrCriticalOption("connectTimeout", "dummy", errors.New("time: invalid duration \"dummy\"")), obj: &T{}, }, }, @@ -1521,7 +1521,7 @@ func TestWithRetryPolicyMinDuration(t *testing.T) { err: errors.NewErrCriticalOption( "retryPolicyMinDuration", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -1613,7 +1613,7 @@ func TestWithRetryPolicyMaxDuration(t *testing.T) { err: errors.NewErrCriticalOption( "retryPolicyMaxDuration", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -1705,7 +1705,7 @@ func TestWithReconnectionPolicyInitialInterval(t *testing.T) { err: errors.NewErrCriticalOption( "reconnectionPolicyInitialInterval", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -1872,7 +1872,7 @@ func TestWithSocketKeepalive(t *testing.T) { err: errors.NewErrCriticalOption( "socketKeepalive", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -2977,7 +2977,7 @@ func TestWithMaxWaitSchemaAgreement(t *testing.T) { err: errors.NewErrCriticalOption( "maxWaitSchemaAgreement", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -3067,7 +3067,7 @@ func TestWithReconnectInterval(t *testing.T) { err: errors.NewErrCriticalOption( "reconnectInterval", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, @@ -3922,7 +3922,7 @@ func TestWithWriteCoalesceWaitTime(t *testing.T) { err: errors.NewErrCriticalOption( "writeCoalesceWaitTime", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), obj: &T{}, }, diff --git a/internal/errors/params.go b/internal/errors/params.go new file mode 100644 index 0000000000..114b9e02b1 --- /dev/null +++ b/internal/errors/params.go @@ -0,0 +1,20 @@ +// +// Copyright (C) 2019-2025 vdaas.org vald team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// You may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Package errors provides error types and function +package errors + +var ErrArgumentParserNotFound = New("argument parser not found") diff --git a/internal/net/grpc/client.go b/internal/net/grpc/client.go index 78fc51448b..aadca0c075 100644 --- a/internal/net/grpc/client.go +++ b/internal/net/grpc/client.go @@ -1066,7 +1066,7 @@ func (g *gRPCClient) Disconnect(ctx context.Context, addr string) error { atomic.AddUint64(&g.clientCount, ^uint64(0)) if p != nil { log.Debugf("gRPC client connection pool addr = %s will disconnect soon...", addr) - return nil, p.Disconnect() + return nil, p.Disconnect(ctx) } return nil, nil }) diff --git a/internal/net/grpc/option.go b/internal/net/grpc/option.go index b73514f559..e0f55a6882 100644 --- a/internal/net/grpc/option.go +++ b/internal/net/grpc/option.go @@ -653,7 +653,7 @@ func WithClientInterceptors(names ...string) Option { } } -func WithOldConnCloseDuration(dur string) Option { +func WithOldConnCloseDelay(dur string) Option { return func(g *gRPCClient) { if len(dur) == 0 { return diff --git a/internal/net/grpc/option_test.go b/internal/net/grpc/option_test.go index 0fa02acd09..32db9832fe 100644 --- a/internal/net/grpc/option_test.go +++ b/internal/net/grpc/option_test.go @@ -3258,7 +3258,7 @@ package grpc // } // } // -// func TestWithOldConnCloseDuration(t *testing.T) { +// func TestWithOldConnCloseDelay(t *testing.T) { // type args struct { // dur string // } @@ -3335,7 +3335,7 @@ package grpc // checkFunc = defaultCheckFunc // } // -// got := WithOldConnCloseDuration(test.args.dur) +// got := WithOldConnCloseDelay(test.args.dur) // if err := checkFunc(test.want, got); err != nil { // tt.Errorf("error = %v", err) // } diff --git a/internal/net/grpc/pool/option.go b/internal/net/grpc/pool/option.go index c7031c30e9..815d877142 100644 --- a/internal/net/grpc/pool/option.go +++ b/internal/net/grpc/pool/option.go @@ -19,133 +19,140 @@ package pool import ( "github.com/vdaas/vald/internal/backoff" + "github.com/vdaas/vald/internal/net" "github.com/vdaas/vald/internal/sync/errgroup" "github.com/vdaas/vald/internal/timeutil" ) +// Option defines a functional option for configuring the pool. type Option func(*pool) +// Default options. var defaultOptions = []Option{ WithSize(defaultPoolSize), WithStartPort(80), WithEndPort(65535), WithErrGroup(errgroup.Get()), WithDialTimeout("1s"), - WithOldConnCloseDuration("2m"), + WithOldConnCloseDelay("2m"), WithResolveDNS(true), } +// WithAddr sets the target address. It also extracts the host and port. func WithAddr(addr string) Option { return func(p *pool) { - if len(addr) == 0 { + if addr == "" { return } p.addr = addr + var err error + // Attempt to split host and port. + if p.host, p.port, err = net.SplitHostPort(addr); err != nil { + p.host = addr + } } } +// WithHost sets the target host. func WithHost(host string) Option { return func(p *pool) { - if len(host) == 0 { - return + if host != "" { + p.host = host } - p.host = host } } +// WithPort sets the target port. func WithPort(port int) Option { return func(p *pool) { if port > 0 { - return + p.port = uint16(port) } - p.port = uint16(port) } } +// WithStartPort sets the starting port for scanning. func WithStartPort(port int) Option { return func(p *pool) { if port > 0 { - return + p.startPort = uint16(port) } - p.startPort = uint16(port) } } +// WithEndPort sets the ending port for scanning. func WithEndPort(port int) Option { return func(p *pool) { if port > 0 { - return + p.endPort = uint16(port) } - p.endPort = uint16(port) } } -func WithResolveDNS(flg bool) Option { +// WithResolveDNS enables or disables DNS resolution. +func WithResolveDNS(enable bool) Option { return func(p *pool) { - p.resolveDNS = flg + p.enableDNSLookup = enable } } +// WithBackoff sets the backoff strategy. func WithBackoff(bo backoff.Backoff) Option { return func(p *pool) { if bo != nil { - return + p.bo = bo } - p.bo = bo } } +// WithSize sets the pool size. func WithSize(size uint64) Option { return func(p *pool) { if size < 1 { return } - p.size.Store(size) + p.poolSize.Store(size) } } +// WithDialOptions appends gRPC dial options. func WithDialOptions(opts ...DialOption) Option { return func(p *pool) { if len(opts) > 0 { - if len(p.dopts) > 0 { - p.dopts = append(p.dopts, opts...) - } else { - p.dopts = opts - } + p.dialOpts = append(p.dialOpts, opts...) } } } +// WithDialTimeout sets the dial timeout duration. func WithDialTimeout(dur string) Option { return func(p *pool) { - if len(dur) == 0 { + if dur == "" { return } - d, err := timeutil.Parse(dur) - if err != nil { - return + if t, err := timeutil.Parse(dur); err == nil { + p.dialTimeout = t } - p.dialTimeout = d } } -func WithOldConnCloseDuration(dur string) Option { +// WithOldConnCloseDelay sets the delay before closing old connections. +func WithOldConnCloseDelay(dur string) Option { return func(p *pool) { - if len(dur) == 0 { + if dur == "" { return } - d, err := timeutil.Parse(dur) - if err != nil { - return + if t, err := timeutil.Parse(dur); err == nil { + p.oldConnCloseDelay = t } - p.roccd = d } } +// WithErrGroup sets the errgroup for goroutine management. func WithErrGroup(eg errgroup.Group) Option { return func(p *pool) { if eg != nil { - p.eg = eg + p.errGroup = eg } } } diff --git a/internal/net/grpc/pool/option_test.go b/internal/net/grpc/pool/option_test.go index 4a9b6deaa1..b534c75e3f 100644 --- a/internal/net/grpc/pool/option_test.go +++ b/internal/net/grpc/pool/option_test.go @@ -869,7 +869,7 @@ package pool // } // } // -// func TestWithOldConnCloseDuration(t *testing.T) { +// func TestWithOldConnCloseDelay(t *testing.T) { // type args struct { // dur string // } @@ -946,7 +946,7 @@ package pool // checkFunc = defaultCheckFunc // } // -// got := WithOldConnCloseDuration(test.args.dur) +// got := WithOldConnCloseDelay(test.args.dur) // if err := checkFunc(test.want, got); err != nil { // tt.Errorf("error = %v", err) // } diff --git a/internal/net/grpc/pool/pool.go b/internal/net/grpc/pool/pool.go index 1af0edb867..4f9d3921c5 100644 --- a/internal/net/grpc/pool/pool.go +++ b/internal/net/grpc/pool/pool.go @@ -14,14 +14,15 @@ // limitations under the License. // -// Package pool provides gRPC connection pool client +// Package pool provides a lock-free gRPC connection pool client. +// This re-implementation maintains the public Conn interface unchanged while +// using atomic operations for efficient, lock-free connection management. +// Additional features such as DNS lookup, port scanning, and metrics collection are incorporated. package pool import ( "context" "fmt" - "math" - "slices" "strconv" "sync/atomic" "time" @@ -30,7 +31,8 @@ import ( "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/log" "github.com/vdaas/vald/internal/net" - "github.com/vdaas/vald/internal/safety" + "github.com/vdaas/vald/internal/net/grpc/codes" + "github.com/vdaas/vald/internal/net/grpc/status" "github.com/vdaas/vald/internal/strings" "github.com/vdaas/vald/internal/sync" "github.com/vdaas/vald/internal/sync/errgroup" @@ -38,88 +40,177 @@ import ( "google.golang.org/grpc/connectivity" ) +// Alias types for clarity. type ( ClientConn = grpc.ClientConn DialOption = grpc.DialOption ) +// Conn defines the interface for a gRPC connection pool. type Conn interface { + // Connect establishes connections for all slots. Connect(context.Context) (Conn, error) - Disconnect() error - Do(ctx context.Context, f func(*ClientConn) error) error - Get(ctx context.Context) (conn *ClientConn, ok bool) + // Disconnect gracefully closes all connections in the pool. + Disconnect(context.Context) error + // Do executes the provided function using a healthy connection. + Do(context.Context, func(*ClientConn) error) error + // Get returns a healthy connection from the pool, if available. + Get(context.Context) (*ClientConn, bool) + // IsHealthy checks the overall health of the pool. IsHealthy(context.Context) bool + // IsIPConn indicates whether the pool is using direct IP connections. IsIPConn() bool + // Len returns the number of connection slots. Len() uint64 + // Size returns the configured pool size. Size() uint64 - Reconnect(ctx context.Context, force bool) (Conn, error) + // Reconnect re-establishes connections if the pool is unhealthy or if forced. + Reconnect(context.Context, bool) (Conn, error) + // String returns a string representation of the pool's state. String() string } +// poolConn wraps a single gRPC connection and its target address. type poolConn struct { - conn *ClientConn - addr string + conn *ClientConn // Underlying gRPC connection. + addr string // Target address used for dialing this connection. } +// Close gracefully closes the connection with the specified delay. +// It periodically checks the connection state until either the connection is closed or the delay elapses. +func (pc *poolConn) Close(ctx context.Context, delay time.Duration) error { + // Determine the ticker interval (at least 5ms, at most 5s). + interval := delay / 10 + if interval < 5*time.Millisecond { + interval = 5 * time.Millisecond + } else if interval > time.Minute { + interval = 5 * time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Create a context with timeout to ensure closure does not hang indefinitely. + ctx, cancel := context.WithTimeout(ctx, delay) + defer cancel() + + log.Debugf("Closing connection for %s with delay %s", pc.addr, delay) + for { + err := pc.conn.Close() + if err != nil && !status.Is(err, codes.Canceled) { + return err + } + select { + case <-ctx.Done(): + if ctx.Err() != nil && + !errors.Is(ctx.Err(), context.DeadlineExceeded) && + !errors.Is(ctx.Err(), context.Canceled) { + return ctx.Err() + } + return nil + case <-ticker.C: + switch pc.conn.GetState() { + case connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure: + err := pc.conn.Close() + if err != nil && !status.Is(err, codes.Canceled) { + return err + } + return nil + case connectivity.Shutdown: + return nil + } + } + } +} + +// pool implements the Conn interface. +// It stores connection slots in a lock-free manner using an atomic.Value. type pool struct { - pool []atomic.Pointer[poolConn] - startPort uint16 - endPort uint16 - host string - port uint16 - addr string - size atomic.Uint64 - current atomic.Uint64 - bo backoff.Backoff - eg errgroup.Group - dopts []DialOption - dialTimeout time.Duration - roccd time.Duration // reconnection old connection closing duration - closing atomic.Bool - pmu sync.RWMutex - isIP bool - resolveDNS bool - reconnectHash atomic.Pointer[string] -} - -const defaultPoolSize = 4 - -func New(ctx context.Context, opts ...Option) (c Conn, err error) { - p := new(pool) + // connSlots holds a slice of atomic pointers to poolConn. + connSlots atomic.Pointer[[]atomic.Pointer[poolConn]] // holds []atomic.Pointer[poolConn] + + // Configuration parameters. + startPort uint16 // Starting port for scanning if needed. + endPort uint16 // Ending port for scanning if needed. + host string // Target host. + port uint16 // Target port. + addr string // Complete address (host:port). + isIPAddr bool // True if the target is an IP address. + enableDNSLookup bool // Whether to perform DNS resolution. + + // Pool management fields. + poolSize atomic.Uint64 // Configured pool size. + currentIndex atomic.Uint64 // Atomic counter for round-robin indexing. + + // gRPC dial options and timeouts. + dialOpts []DialOption + dialTimeout time.Duration // Timeout for dialing a connection. + oldConnCloseDelay time.Duration // Delay before closing old connections. + + // Retry/backoff strategy. + bo backoff.Backoff + + // Goroutine management. + errGroup errgroup.Group + + // Used for DNS change detection during reconnection. + dnsHash atomic.Pointer[string] + + // Flag indicating whether the pool is closing. + closing atomic.Bool +} + +// Default pool size. +const defaultPoolSize = uint64(4) +// Global metrics are stored in a sync.Map (key: address, value: healthy connection count). +var metrics sync.Map[string, uint64] + +// New creates a new connection pool with the provided options. +// It parses the target address, initializes the connection slots, and performs an initial dial check. +func New(ctx context.Context, opts ...Option) (Conn, error) { + p := &pool{ + dialTimeout: time.Second, + oldConnCloseDelay: 2 * time.Minute, + enableDNSLookup: false, + } + // Apply default and user-specified options. for _, opt := range append(defaultOptions, opts...) { opt(p) } - p.init(true) + if p.addr == "" { + return nil, errors.Errorf("target address is not provided") + } + + // Initialize the connection slots. + p.init() p.closing.Store(false) - var ( - isIPv4, isIPv6 bool - port uint16 - ) + // Parse the address to extract host and port. + var err error + var isIPv4, isIPv6 bool p.host, p.port, _, isIPv4, isIPv6, err = net.Parse(p.addr) - p.isIP = isIPv4 || isIPv6 + p.isIPAddr = isIPv4 || isIPv6 if err != nil { log.Warnf("failed to parse addr %s: %s", p.addr, err) + // Fallback: split using Cut. if p.host == "" { - var ( - ok bool - portStr string - ) - p.host, portStr, ok = strings.Cut(p.addr, ":") - if !ok { - p.host = p.addr - } else { - portNum, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - p.port = uint16(portNum) + p.host, p.port, err = net.SplitHostPort(p.addr) + if err != nil { + if host, portStr, ok := strings.Cut(p.addr, ":"); ok { + p.host = host + if portNum, err := strconv.ParseUint(portStr, 10, 16); err == nil { + p.port = uint16(portNum) + } + } else { + p.host = p.addr } } } + // If port is still zero, attempt port scanning. if p.port == 0 { - port, err = p.scanGRPCPort(ctx) - if err != nil { + var port uint16 + if port, err = p.scanGRPCPort(ctx); err != nil { return nil, err } p.port = port @@ -127,23 +218,19 @@ func New(ctx context.Context, opts ...Option) (c Conn, err error) { p.addr = net.JoinHostPort(p.host, p.port) } - conn, err := grpc.NewClient(p.addr, p.dopts...) + log.Debugf("Initial connection target: %s, host: %s, port: %d, isIP: %t", p.addr, p.host, p.port, p.isIPAddr) + // Perform an initial dial check. + conn, err := p.dial(ctx, p.addr) if err != nil { - log.Warnf("grpc.New initial Dial check to %s returned error: %v", p.addr, err) - if conn != nil { - err = conn.Close() - if err != nil { - log.Warn("failed to close connection:", err) - } - } - - port, err := p.scanGRPCPort(ctx) - if err != nil { + log.Warnf("Initial dial check to %s failed: %v", p.addr, err) + var port uint16 + if port, err = p.scanGRPCPort(ctx); err != nil { return nil, err } p.port = port p.addr = net.JoinHostPort(p.host, p.port) - conn, err = grpc.NewClient(p.addr, p.dopts...) + log.Debugf("Fallback target: %s, host: %s, port: %d, isIP: %t", p.addr, p.host, p.port, p.isIPAddr) + conn, err = p.dial(ctx, p.addr) if err != nil { if conn != nil { cerr := conn.Close() @@ -161,132 +248,104 @@ func New(ctx context.Context, opts ...Option) (c Conn, err error) { } } - if p.eg == nil { - p.eg = errgroup.Get() + if p.errGroup == nil { + p.errGroup = errgroup.Get() } return p, nil } -func (p *pool) init(force bool) { - if p == nil { - return - } - if p.Size() < 1 { - p.size.Store(defaultPoolSize) - } - p.pmu.RLock() - if force || p.pool == nil || cap(p.pool) == 0 || len(p.pool) == 0 { - p.pmu.RUnlock() - p.pmu.Lock() - p.pool = make([]atomic.Pointer[poolConn], p.Size()) - p.pmu.Unlock() - } else { - p.pmu.RUnlock() +// init initializes the connection slots slice using an atomic.Value. +func (p *pool) init() { + size := p.Size() + if size < 1 { + size = defaultPoolSize + p.poolSize.Store(size) } + slots := make([]atomic.Pointer[poolConn], size) + p.connSlots.Store(&slots) } -func (p *pool) grow(size uint64) { - if p == nil || p.Size() > size { - return +// getSlots returns the current connection slots slice. +func (p *pool) getSlots() *[]atomic.Pointer[poolConn] { + if v := p.connSlots.Load(); v != nil { + return v } - l := p.Len() - if l >= size { + return nil +} + +// grow increases the number of connection slots if the new size is larger. +func (p *pool) grow(newSize uint64) { + oldSlots := *p.getSlots() + currentLen := uint64(len(oldSlots)) + if currentLen >= newSize { return } - epool := make([]atomic.Pointer[poolConn], size-l) // expand pool - log.Debugf("growing pool size %d o %d", l, size) - p.pmu.Lock() - if uint64(len(p.pool)) != l { - epool = make([]atomic.Pointer[poolConn], size-uint64(len(p.pool))) // re-expand pool - } - p.pool = append(p.pool, epool...) - p.pmu.Unlock() - p.size.Store(size) + newSlots := make([]atomic.Pointer[poolConn], newSize) + copy(newSlots, oldSlots) + p.connSlots.Store(&newSlots) + p.poolSize.Store(newSize) } -func (p *pool) load(idx int) (pc *poolConn) { - if p == nil { +// load retrieves the poolConn at the specified index. +func (p *pool) load(idx uint64) *poolConn { + slots := *p.getSlots() + if slots == nil || idx < 0 || idx >= p.slotCount() { return nil } - p.pmu.RLock() - if p.pool != nil && p.Size() > uint64(idx) && len(p.pool) > idx { - pc = p.pool[idx].Load() - } - p.pmu.RUnlock() - return pc + return slots[idx].Load() } -func (p *pool) store(idx int, pc *poolConn) { - if p == nil { +// store sets the poolConn at the specified index. +func (p *pool) store(idx uint64, pc *poolConn) { + slots := *p.getSlots() + if slots == nil || idx < 0 || idx >= p.slotCount() { return } - p.init(false) - p.pmu.RLock() - if p.pool != nil && p.Size() > uint64(idx) && len(p.pool) > idx { - p.pool[idx].Store(pc) - } - p.pmu.RUnlock() + slots[idx].Store(pc) } +// loop iterates over each connection slot and applies the provided function. func (p *pool) loop( - ctx context.Context, fn func(ctx context.Context, idx int, pc *poolConn) bool, -) (err error) { - if p == nil || fn == nil { - return nil - } - p.init(false) - p.pmu.RLock() - defer p.pmu.RUnlock() - var cnt int - for idx, pool := range p.pool { + ctx context.Context, fn func(ctx context.Context, idx uint64, pc *poolConn) bool, +) error { + slots := *p.getSlots() + if slots == nil { + return errors.Errorf("connection slots not initialized") + } + var count uint64 + for idx := range slots { select { case <-ctx.Done(): return ctx.Err() default: - if p.Size() > uint64(idx) && len(p.pool) > idx { - cnt++ - if !fn(ctx, idx, pool.Load()) { - return nil - } + count++ + if !fn(ctx, uint64(idx), slots[idx].Load()) { + return nil } } } - if cnt == 0 { + if count == 0 { return errors.ErrGRPCPoolConnectionNotFound } return nil } -func (p *pool) len() int { - if p == nil { - return 0 - } - p.pmu.RLock() - defer p.pmu.RUnlock() - return len(p.pool) -} - -func (p *pool) cap() int { - if p == nil { - return 0 - } - p.pmu.RLock() - defer p.pmu.RUnlock() - return cap(p.pool) +// slotCount returns the number of connection slots. +func (p *pool) slotCount() uint64 { + return uint64(len(*p.getSlots())) } +// flush clears the connection slots. func (p *pool) flush() { - if p == nil { - return - } - p.pmu.Lock() - p.pool = nil - p.pmu.Unlock() + p.connSlots.Store(nil) } -func (p *pool) refreshConn(ctx context.Context, idx int, pc *poolConn, addr string) (err error) { - if pc != nil && pc.addr == addr && isHealthy(ctx, pc.conn) { +// refreshConn checks if the connection at slot idx is healthy for the given address. +// If not, it dials a new connection and updates the slot atomically. +// It also schedules graceful closure of any existing (old) connection. +func (p *pool) refreshConn(ctx context.Context, idx uint64, pc *poolConn, addr string) error { + if pc != nil && pc.addr == addr && p.isHealthy(ctx, pc.conn) { return nil } if pc != nil { @@ -294,164 +353,158 @@ func (p *pool) refreshConn(ctx context.Context, idx int, pc *poolConn, addr stri } else { log.Debugf("connection pool %d/%d is empty, establish new pool member connection to %s", idx+1, p.Size(), addr) } - conn, err := p.dial(ctx, addr) + newConn, err := p.dial(ctx, addr) if err != nil { - if pc != nil { - if isHealthy(ctx, pc.conn) { - log.Debugf("dialing new connection to %s failed,\terror: %v,\tbut existing connection to %s is healthy will keep existing connection", addr, err, pc.addr) + if pc != nil && p.isHealthy(ctx, pc.conn) { + return nil + } + if pc != nil && pc.conn != nil { + p.errGroup.Go(func() error { + log.Debugf("waiting for invalid connection to %s to be closed...", pc.addr) + err := pc.Close(ctx, p.oldConnCloseDelay) + if err != nil { + log.Debugf("failed to close connection pool addr = %s\terror = %v", pc.addr, err) + } return nil - } - if pc.conn != nil { - p.eg.Go(safety.RecoverFunc(func() error { - log.Debugf("waiting for invalid connection to %s to be closed...", pc.addr) - err = pc.Close(ctx, p.roccd) - if err != nil { - log.Debugf("failed to close connection pool addr = %s\terror = %v", pc.addr, err) - } - return nil - })) - } + }) } return errors.Join(err, errors.ErrInvalidGRPCClientConn(addr)) } - p.store(idx, &poolConn{ - conn: conn, - addr: addr, - }) - if pc != nil { - p.eg.Go(safety.RecoverFunc(func() error { + + p.store(idx, &poolConn{conn: newConn, addr: addr}) + + if pc != nil && pc.conn != nil { + p.errGroup.Go(func() error { log.Debugf("waiting for old connection to %s to be closed...", pc.addr) - err = pc.Close(ctx, p.roccd) + err := pc.Close(ctx, p.oldConnCloseDelay) if err != nil { log.Debugf("failed to close connection pool addr = %s\terror = %v", pc.addr, err) } return nil - })) + }) } return nil } -func (p *pool) Connect(ctx context.Context) (c Conn, err error) { - if p == nil || p.closing.Load() { +// Connect establishes connections for all slots. +// It uses DNS lookup if enabled; otherwise, it connects to the single target address. +func (p *pool) Connect(ctx context.Context) (Conn, error) { + if p.closing.Load() { return p, nil } + log.Debugf("Connecting: addr=%s, host=%s, port=%d, isIP=%t, enableDNS=%t", + p.addr, p.host, p.port, p.isIPAddr, p.enableDNSLookup) - p.init(false) - - if p.isIP || !p.resolveDNS { - return p.singleTargetConnect(ctx) + if p.isIPAddr || !p.enableDNSLookup { + return p.singleTargetConnect(ctx, p.addr) } ips, err := p.lookupIPAddr(ctx) if err != nil { - return p.singleTargetConnect(ctx) + return p.singleTargetConnect(ctx, p.addr) + } + if len(ips) == 1 { + target := net.JoinHostPort(ips[0], p.port) + return p.singleTargetConnect(ctx, target) } return p.connect(ctx, ips...) } +// connect establishes connections using multiple IP addresses. func (p *pool) connect(ctx context.Context, ips ...string) (c Conn, err error) { if uint64(len(ips)) > p.Size() { p.grow(uint64(len(ips))) } - - err = p.loop(ctx, func(ctx context.Context, idx int, pc *poolConn) bool { - addr := net.JoinHostPort(ips[idx%len(ips)], p.port) - ierr := p.refreshConn(ctx, idx, pc, addr) - if ierr != nil { - if !errors.Is(ierr, context.DeadlineExceeded) && - !errors.Is(ierr, context.Canceled) { - log.Warnf("An error occurred while dialing pool member connection to %s,\terror: %v", addr, ierr) + log.Debugf("Connecting to multiple IPs: %v on port %d", ips, p.port) + err = p.loop(ctx, func(ctx context.Context, idx uint64, pc *poolConn) bool { + target := net.JoinHostPort(ips[idx%uint64(len(ips))], p.port) + if err = p.refreshConn(ctx, idx, pc, target); err != nil { + if !errors.Is(err, context.DeadlineExceeded) && + !errors.Is(err, context.Canceled) { + log.Warnf("An error occurred while dialing pool slot %d connection to %s,\terror: %v", idx, target, err) } else { - log.Debugf("Connect loop operation canceled while dialing pool member connection to %s,\terror: %v", addr, ierr) + log.Debugf("Connect loop operation canceled while dialing pool slot %d connection to %s,\terror: %v", idx, target, err) return false } } return true }) - if !errors.Is(err, context.Canceled) && + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { return p, err } - hash := strings.Join(ips, "-") - p.reconnectHash.Store(&hash) + p.dnsHash.Store(&hash) + return p, err +} - return p, nil +// singleTargetConnect connects every slot to a single target address. +func (p *pool) singleTargetConnect(ctx context.Context, addr string) (Conn, error) { + log.Debugf("Connecting to single target: %s", addr) + failCount := uint64(0) + err := p.loop(ctx, func(ctx context.Context, idx uint64, pc *poolConn) bool { + if err := p.refreshConn(ctx, idx, pc, addr); err != nil { + if !errors.Is(err, context.DeadlineExceeded) && + !errors.Is(err, context.Canceled) { + log.Warnf("An error occurred while dialing pool member connection to %s,\terror: %v", p.addr, err) + failCount++ + if p.isIPAddr && (p.slotCount() <= 2 || failCount >= p.slotCount()/3) { + return false + } + } else { + log.Debugf("Connect loop operation canceled while dialing pool member connection to %s,\terror: %v", p.addr, err) + return false + } + } + return true + }) + p.dnsHash.Store(&p.host) + return p, err } -func (p *pool) Reconnect(ctx context.Context, force bool) (c Conn, err error) { - if p == nil || p.closing.Load() { +// Reconnect re-establishes connections if the pool is unhealthy or if forced. +func (p *pool) Reconnect(ctx context.Context, force bool) (Conn, error) { + if p.closing.Load() { return p, nil } - - hash := p.reconnectHash.Load() + hash := p.dnsHash.Load() if force || hash == nil || *hash == "" { return p.Connect(ctx) } - - healthy := p.IsHealthy(ctx) - if healthy { - if !p.isIP && p.resolveDNS && hash != nil && *hash != "" { + if p.IsHealthy(ctx) { + if !p.isIPAddr && p.enableDNSLookup && hash != nil && *hash != "" { ips, err := p.lookupIPAddr(ctx) if err != nil { return p, nil } + if len(ips) == 1 { + target := net.JoinHostPort(ips[0], p.port) + return p.singleTargetConnect(ctx, target) + } if *hash != strings.Join(ips, "-") { return p.connect(ctx, ips...) } } return p, nil } - return p.Connect(ctx) } -func (p *pool) singleTargetConnect(ctx context.Context) (c Conn, err error) { - if p == nil || p.closing.Load() { - return p, nil - } - - failCnt := 0 - err = p.loop(ctx, func(ctx context.Context, idx int, pc *poolConn) bool { - ierr := p.refreshConn(ctx, idx, pc, p.addr) - if ierr != nil { - if !errors.Is(ierr, context.DeadlineExceeded) && - !errors.Is(ierr, context.Canceled) { - log.Warnf("An error occurred while dialing pool member connection to %s,\terror: %v", p.addr, ierr) - failCnt++ - if p.isIP && (p.len() <= 2 || failCnt >= p.len()/3) { - return false - } - return true - } else { - log.Debugf("Connect loop operation canceled while dialing pool member connection to %s,\terror: %v", p.addr, ierr) - return false - } - } - return true - }) - if !errors.Is(err, context.Canceled) && - !errors.Is(err, context.DeadlineExceeded) { - return p, err - } - p.reconnectHash.Store(&p.host) - return p, nil -} - -func (p *pool) Disconnect() (err error) { - ctx := context.Background() +// Disconnect gracefully closes all connections in the pool. +func (p *pool) Disconnect(ctx context.Context) (err error) { + log.Debug("Disconnecting pool...") p.closing.Store(true) defer p.closing.Store(false) - emap := make(map[string]error, p.len()) - err = p.loop(ctx, func(ctx context.Context, _ int, pc *poolConn) bool { + emap := make(map[string]error, p.Size()) + err = p.loop(ctx, func(ctx context.Context, idx uint64, pc *poolConn) bool { if pc != nil && pc.conn != nil { - ierr := pc.conn.Close() - if ierr != nil { - if !errors.Is(ierr, context.DeadlineExceeded) && - !errors.Is(ierr, context.Canceled) { - log.Debugf("failed to close connection pool addr = %s\terror = %v", pc.addr, ierr) - emap[ierr.Error()] = err + log.Debugf("Closing slot %d (addr: %s)", idx, pc.addr) + if err := pc.Close(ctx, p.oldConnCloseDelay); err != nil { + if !errors.Is(err, context.DeadlineExceeded) && + !errors.Is(err, context.Canceled) { + log.Debugf("failed to close connection pool addr = %s\terror = %v", pc.addr, err) + emap[err.Error()] = err } else { - log.Debugf("Disconnect loop operation canceled while closing pool member connection to %s,\terror: %v", pc.addr, ierr) + log.Debugf("Disconnect loop operation canceled while closing pool member connection to %s,\terror: %v", pc.addr, err) return false } } @@ -465,22 +518,21 @@ func (p *pool) Disconnect() (err error) { return err } -func (p *pool) dial(ctx context.Context, addr string) (conn *ClientConn, err error) { - do := func() (conn *ClientConn, err error) { +// dial creates a new gRPC connection to the specified address. +// It uses a dial timeout and, if configured, a backoff strategy. +func (p *pool) dial(ctx context.Context, addr string) (*ClientConn, error) { + dialFunc := func(ctx context.Context) (*ClientConn, error) { ctx, cancel := context.WithTimeout(ctx, p.dialTimeout) defer cancel() - conn, err = grpc.NewClient(addr, p.dopts...) + log.Debugf("Dialing %s with timeout %s", addr, p.dialTimeout) + conn, err := grpc.NewClient(addr, p.dialOpts...) if err != nil { if conn != nil { - cerr := conn.Close() - if cerr != nil { - err = errors.Join(err, cerr) - } + _ = conn.Close() } - log.Debugf("failed to dial gRPC connection to %s,\terror: %v", addr, err) return nil, err } - if !isHealthy(ctx, conn) { + if !p.isHealthy(ctx, conn) { if conn != nil { err = conn.Close() if err != nil { @@ -490,199 +542,150 @@ func (p *pool) dial(ctx context.Context, addr string) (conn *ClientConn, err err } } log.Debugf("connection for %s is unhealthy: %v", addr, err) - return nil, err + return nil, errors.Wrapf(err, "connection to %s is unhealthy", addr) } return conn, nil } if p.bo != nil { - retry := 0 - _, err = p.bo.Do(ctx, func(ctx context.Context) (r any, ret bool, err error) { - log.Debugf("dialing to %s with backoff, retry: %d", addr, retry) - conn, err = do() - retry++ + var conn *ClientConn + _, err := p.bo.Do(ctx, func(ctx context.Context) (interface{}, bool, error) { + var err error + conn, err = dialFunc(ctx) return conn, err != nil, err }) + if err != nil && conn != nil { + _ = conn.Close() + return nil, errors.Join(err, conn.Close()) + } return conn, nil } - - log.Debugf("dialing to %s", addr) - return do() + return dialFunc(ctx) } -func (p *pool) IsHealthy(ctx context.Context) (healthy bool) { +// getHealthyConn retrieves a healthy connection from the pool using round-robin indexing. +// It attempts up to poolSize times. +func (p *pool) getHealthyConn(ctx context.Context) (idx uint64, pc *poolConn, ok bool) { if p == nil || p.closing.Load() { - return false + return 0, nil, false } - var cnt, unhealthy int - pl := p.len() - err := p.loop(ctx, func(ctx context.Context, idx int, pc *poolConn) bool { - if pc == nil || !isHealthy(ctx, pc.conn) { - if p.isIP { - if pc != nil && pc.addr != "" { - err := p.refreshConn(ctx, idx, pc, pc.addr) - if err != nil { - // target addr cannot re-connect so, connection is unhealthy - unhealthy++ - return false - } - return true - } - return false - } - addr := p.addr - if pc != nil { - addr = pc.addr - } - // re-connect to last connected addr - err := p.refreshConn(ctx, idx, pc, addr) - if err != nil { - if addr == p.addr { - unhealthy++ - return true - } - // last connect addr is not dns and cannot connect then try dns - err = p.refreshConn(ctx, idx, pc, p.addr) - // dns addr cannot connect so, connection is unhealthy - if err != nil { - unhealthy = pl - cnt - return false - } + sz := p.Size() + if sz == 0 { + return 0, nil, false + } + for i := uint64(0); i < sz; i++ { + idx = p.currentIndex.Add(1) % sz + pc = p.load(idx) + if pc != nil && p.isHealthy(ctx, pc.conn) { + return idx, pc, true + } + if err := p.refreshConn(ctx, idx, pc, p.addr); err == nil { + if pc = p.load(idx); pc != nil && p.isHealthy(ctx, pc.conn) { + return idx, pc, true } } - cnt++ - return true - }) - if err != nil { - log.Debugf("health check loop for addr=%s returned error: %v,\thealthy %d/%d", p.addr, err, pl-unhealthy, pl) - } - if cnt == 0 { - log.Debugf("no connection pool %d/%d found for %s,\thealthy %d/%d", cnt, pl, p.addr, pl-unhealthy, pl) - return false - } - if p.isIP { - // if ip pool connection, each connection target should be healthy - return unhealthy == 0 } - - // some pool target may unhealthy but pool client is healthy when unhealthy is less than pool length - return unhealthy < pl + return 0, nil, false } -func (p *pool) Do(ctx context.Context, f func(conn *ClientConn) error) (err error) { +// Do executes the provided function using a healthy connection. +// If an error indicating a closed connection is returned, it attempts to refresh the connection and retries. +func (p *pool) Do(ctx context.Context, f func(conn *ClientConn) error) error { if p == nil { return errors.ErrGRPCClientConnNotFound("*") } - idx, conn, ok := p.getHealthyConn(ctx, 0, p.Len()) - if !ok || conn == nil { + _, pc, ok := p.getHealthyConn(ctx) + if !ok || pc == nil || pc.conn == nil { return errors.ErrGRPCClientConnNotFound(p.addr) } - err = f(conn) - if errors.Is(err, grpc.ErrClientConnClosing) { - if conn != nil { - if cerr := conn.Close(); cerr != nil && !errors.Is(cerr, grpc.ErrClientConnClosing) { - log.Warnf("Failed to close connection: %v", cerr) - } - } - conn, err = p.dial(ctx, p.addr) - if err == nil && conn != nil && isHealthy(ctx, conn) { - p.store(idx, &poolConn{ - conn: conn, - addr: p.addr, - }) - if newErr := f(conn); newErr != nil { - return errors.Join(err, newErr) - } - return nil - } - } - return err + return f(pc.conn) } +// Get returns a healthy connection from the pool, if available. func (p *pool) Get(ctx context.Context) (conn *ClientConn, ok bool) { - _, conn, ok = p.getHealthyConn(ctx, 0, p.Len()) - return conn, ok + _, pc, ok := p.getHealthyConn(ctx) + if ok && pc != nil { + return pc.conn, true + } + return nil, false } -func (p *pool) getHealthyConn( - ctx context.Context, cnt, retry uint64, -) (idx int, conn *ClientConn, ok bool) { - if p == nil || p.closing.Load() { - return 0, nil, false - } - select { - case <-ctx.Done(): - return 0, nil, false - default: +// IsHealthy checks the overall health of the pool. +// For IP-based connections, all slots must be healthy; otherwise, at least one healthy slot is acceptable. +// Global metrics are updated accordingly. +func (p *pool) IsHealthy(ctx context.Context) bool { + sz := p.slotCount() + if sz == 0 { + return false } - pl := p.Len() - if retry <= 0 || retry > math.MaxUint64-pl || pl <= 0 { - if p.isIP { - log.Warnf("failed to find gRPC IP connection pool for %s.\tlen(pool): %d,\tretried: %d,\tseems IP %s is unhealthy will going to disconnect...", p.addr, pl, cnt, p.addr) - if err := p.Disconnect(); err != nil { - log.Debugf("failed to disconnect gRPC IP direct connection for %s,\terr: %v", p.addr, err) - } - return 0, nil, false - } - if pl > 0 { - idx = int(p.current.Add(1) % pl) - } - if pc := p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { - return idx, pc.conn, true + healthyCount := uint64(0) + err := p.loop(ctx, func(ctx context.Context, _ uint64, pc *poolConn) bool { + if pc != nil && p.isHealthy(ctx, pc.conn) { + healthyCount++ } - conn, err := p.dial(ctx, p.addr) - if err == nil && conn != nil && isHealthy(ctx, conn) { - p.store(idx, &poolConn{ - conn: conn, - addr: p.addr, - }) - return idx, conn, true - } - log.Warnf("failed to find gRPC connection pool for %s.\tlen(pool): %d,\tretried: %d,\terror: %v", p.addr, pl, cnt, err) - return idx, nil, false + return true + }) + metrics.Store(p.addr, healthyCount) + if err != nil { + log.Debugf("health check loop for addr=%s returned error: %v", p.addr, err) } - - if pl > 0 { - idx = int(p.current.Add(1) % pl) - if pc := p.load(idx); pc != nil && isHealthy(ctx, pc.conn) { - return idx, pc.conn, true - } + if healthyCount == 0 { + log.Debugf("no connection pool member is healthy for addr=%s", p.addr) + return false + } + if p.isIPAddr { + return healthyCount == uint64(sz) } - retry-- - cnt++ - return p.getHealthyConn(ctx, cnt, retry) + return healthyCount > 0 } +// Len returns the number of connection slots. func (p *pool) Len() uint64 { - return uint64(p.len()) + return p.slotCount() } +// Size returns the configured pool size. func (p *pool) Size() uint64 { - return p.size.Load() + return p.poolSize.Load() +} + +// IsIPConn indicates whether the pool is using direct IP connections. +func (p *pool) IsIPConn() bool { + return p.isIPAddr } -func (p *pool) lookupIPAddr(ctx context.Context) (ips []string, err error) { +// String returns a string representation of the pool's state. +func (p *pool) String() string { + hash := "" + if rh := p.dnsHash.Load(); rh != nil { + hash = *rh + } + return fmt.Sprintf("addr: %s, host: %s, port: %d, isIP: %t, enableDNS: %t, dnsHash: %s, slotCount: %d, poolSize: %d, currentIndex: %d, dialTimeout: %s, oldConnCloseDelay: %s, closing: %t", + p.addr, p.host, p.port, p.isIPAddr, p.enableDNSLookup, hash, p.slotCount(), p.Size(), p.currentIndex.Load(), + p.dialTimeout.String(), p.oldConnCloseDelay.String(), p.closing.Load()) +} + +// lookupIPAddr performs DNS lookup for the host and returns a list of reachable IP addresses. +// It also attempts a short TCP dial ("ping") for each IP. +func (p *pool) lookupIPAddr(ctx context.Context) ([]string, error) { addrs, err := net.DefaultResolver.LookupIPAddr(ctx, p.host) if err != nil { - log.Debugf("failed to lookup ip addr for %s \terr: %s", p.addr, err.Error()) + log.Debugf("Failed to lookup IP addresses for %s: %s", p.addr, err.Error()) return nil, err } - if len(addrs) == 0 { return nil, errors.ErrGRPCLookupIPAddrNotFound(p.host) } - - ips = make([]string, 0, len(addrs)) + var ips []string for _, ip := range addrs { ipStr := ip.String() - var conn net.Conn - addr := net.JoinHostPort(ipStr, p.port) - ctx, cancel := context.WithTimeout(ctx, time.Millisecond*10) - conn, err := net.DialContext(ctx, net.TCP.String(), addr) + target := net.JoinHostPort(ipStr, p.port) + pingCtx, cancel := context.WithTimeout(ctx, 10*time.Millisecond) + conn, err := net.DialContext(pingCtx, net.TCP.String(), target) cancel() - if err != nil || conn == nil { - log.Warnf("failed to initialize ping addr: %s,\terr: %s", addr, err.Error()) - } else { + if err == nil { ips = append(ips, ipStr) + } else { + log.Warnf("Failed to ping %s: %s", target, err.Error()) } if conn != nil { err = conn.Close() @@ -691,34 +694,32 @@ func (p *pool) lookupIPAddr(ctx context.Context) (ips []string, err error) { } } } - if len(ips) == 0 { return nil, errors.ErrGRPCLookupIPAddrNotFound(p.host) } - - slices.Sort(ips) - + // Sorting can be added here if needed. return ips, nil } +// scanGRPCPort scans ports from startPort to endPort for a valid gRPC endpoint. func (p *pool) scanGRPCPort(ctx context.Context) (port uint16, err error) { ports, err := net.ScanPorts(ctx, p.startPort, p.endPort, p.host) if err != nil { return 0, err } + log.Debugf("Scanning ports: %v", ports) var conn *ClientConn for _, port := range ports { select { case <-ctx.Done(): return 0, ctx.Err() default: - // try gRPC dialing to target port - conn, err = grpc.NewClient(net.JoinHostPort(p.host, port), p.dopts...) - if err == nil && isHealthy(ctx, conn) && conn.Close() == nil { - // if no error and healthy the port is ready for gRPC + conn, err = grpc.NewClient(net.JoinHostPort(p.host, port), p.dialOpts...) + if err == nil && p.isHealthy(ctx, conn) { + _ = conn.Close() + log.Debugf("Found valid gRPC port: %d", port) return port, nil } - if conn != nil { _ = conn.Close() } @@ -727,81 +728,25 @@ func (p *pool) scanGRPCPort(ctx context.Context) (port uint16, err error) { return 0, errors.ErrInvalidGRPCPort(p.addr, p.host, p.port) } -func (p *pool) IsIPConn() (isIP bool) { - return p.isIP -} - -func (p *pool) String() (str string) { - if p == nil { - return "" - } - var hash string - rh := p.reconnectHash.Load() - if rh != nil { - hash = *rh - } - return fmt.Sprintf("addr: %s, host: %s, port %d, isIP: %v, resolveDNS: %v, hash: %s, pool_size: %d, current_seek: %d, dopt_len: %d, dial_timeout: %v, roccd: %v, closing: %v", - p.addr, - p.host, - p.port, - p.isIP, - p.resolveDNS, - hash, - p.size.Load(), - p.current.Load(), - len(p.dopts), - p.dialTimeout, - p.roccd, - p.closing.Load()) -} - -func (pc *poolConn) Close(ctx context.Context, delay time.Duration) error { - tdelay := delay / 10 - if tdelay < time.Millisecond*200 { - tdelay = time.Millisecond * 200 - } else if tdelay > time.Minute { - tdelay = time.Second * 5 - } - tick := time.NewTicker(tdelay) - defer tick.Stop() - ctx, cancel := context.WithTimeout(ctx, delay) - defer cancel() - for { - select { - case <-ctx.Done(): - err := pc.conn.Close() - if err != nil && !errors.Is(err, grpc.ErrClientConnClosing) { - if ctx.Err() != nil && - !errors.Is(ctx.Err(), context.DeadlineExceeded) && - !errors.Is(ctx.Err(), context.Canceled) { - return errors.Join(err, ctx.Err()) - } - return err - } - if ctx.Err() != nil && - !errors.Is(ctx.Err(), context.DeadlineExceeded) && - !errors.Is(ctx.Err(), context.Canceled) { - return ctx.Err() - } - return nil - case <-tick.C: - switch pc.conn.GetState() { - case connectivity.Idle, connectivity.Connecting, connectivity.TransientFailure: - err := pc.conn.Close() - if err != nil && !errors.Is(err, grpc.ErrClientConnClosing) { - return err - } - return nil - case connectivity.Shutdown: - return nil - } +// Metrics returns a map of healthy connection counts per target address. +func Metrics(ctx context.Context) map[string]uint64 { + result := make(map[string]uint64) + metrics.Range(func(addr string, count uint64) bool { + if addr != "" { + result[addr] = count } + return true + }) + if len(result) == 0 { + return nil } + return result } -func isHealthy(ctx context.Context, conn *ClientConn) bool { +// p.isHealthy checks whether a given gRPC connection is healthy by examining its connectivity state. +func (p *pool) isHealthy(ctx context.Context, conn *ClientConn) bool { if conn == nil { - log.Warn("gRPC target connection is nil") + log.Warn("gRPC connection is nil") return false } state := conn.GetState() @@ -809,23 +754,23 @@ func isHealthy(ctx context.Context, conn *ClientConn) bool { case connectivity.Ready: return true case connectivity.Connecting: - log.Debugf("gRPC target %s's connection status will be Ready soon\tstatus: %s", conn.Target(), state.String()) return true case connectivity.Idle: - log.Debugf("gRPC target %s's connection status is waiting for target\tstatus: %s\ttrying to re-connect...", conn.Target(), state.String()) - conn.Connect() + // Trigger connection if idle. + p.errGroup.Go(func() error { + conn.Connect() + return nil + }) if conn.WaitForStateChange(ctx, state) { - state = conn.GetState() - if state == connectivity.Ready || state == connectivity.Connecting { - log.Debugf("gRPC target %s's connection status enabled for target\tstatus: %s", conn.Target(), state.String()) - return true - } + return p.isHealthy(ctx, conn) } + log.Errorf("Connection %s did not recover from idle", conn.Target()) return false case connectivity.Shutdown, connectivity.TransientFailure: - log.Errorf("gRPC target %s's connection status is unhealthy\tstatus: %s", conn.Target(), state.String()) + log.Errorf("Connection %s is unhealthy (state: %s)", conn.Target(), state.String()) + return false + default: + log.Errorf("Connection %s has unknown state: %s", conn.Target(), state.String()) return false } - log.Errorf("gRPC target %s's connection status is unknown\tstatus: %s", conn.Target(), state.String()) - return false } diff --git a/internal/net/grpc/pool/pool_bench_test.go b/internal/net/grpc/pool/pool_bench_test.go index b9effb3712..955f3a6e79 100644 --- a/internal/net/grpc/pool/pool_bench_test.go +++ b/internal/net/grpc/pool/pool_bench_test.go @@ -117,7 +117,7 @@ func Benchmark_ConnPool(b *testing.B) { b.ResetTimer() b.ReportAllocs() b.StartTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { conn, ok := pool.Get(ctx) if ok { do(b, conn) @@ -141,7 +141,7 @@ func Benchmark_StaticDial(b *testing.B) { b.ResetTimer() b.ReportAllocs() b.StartTimer() - for i := 0; i < b.N; i++ { + for b.Loop() { val, ok := conns.Load(DefaultServerAddr) if ok { do(b, val) diff --git a/internal/net/grpc/status/status.go b/internal/net/grpc/status/status.go index 4113d872ac..4e6bfdcd6f 100644 --- a/internal/net/grpc/status/status.go +++ b/internal/net/grpc/status/status.go @@ -45,6 +45,10 @@ func New(c codes.Code, msg string) *Status { return status.New(c, msg) } +func Is(err error, code codes.Code) bool { + return status.Code(err) == code +} + func newStatus(code codes.Code, msg string, err error, details ...any) (st *Status) { st = New(code, msg) return withDetails(st, err, details...) diff --git a/internal/net/http/client/client.go b/internal/net/http/client/client.go index 0fcb9953be..b9ada6637d 100644 --- a/internal/net/http/client/client.go +++ b/internal/net/http/client/client.go @@ -49,7 +49,6 @@ func NewWithTransport(rt http.RoundTripper, opts ...Option) (*http.Client, error tr.Transport = http.DefaultTransport.(*http.Transport).Clone() } for _, opt := range append(defaultOptions, opts...) { - // ... existing code ... if err := opt(tr); err != nil { werr := errors.ErrOptionFailed(err, reflect.ValueOf(opt)) e := new(errors.ErrCriticalOption) @@ -61,7 +60,8 @@ func NewWithTransport(rt http.RoundTripper, opts ...Option) (*http.Client, error } } - err := http2.ConfigureTransport(tr.Transport) + var err error + tr.Transport, err = http2.ConfigureTransports(tr.Transport) if err != nil { log.Warnf("Transport is already configured for HTTP2 error: %v", err) } diff --git a/internal/net/http/client/option_test.go b/internal/net/http/client/option_test.go index a62d036990..32ac647b67 100644 --- a/internal/net/http/client/option_test.go +++ b/internal/net/http/client/option_test.go @@ -257,7 +257,7 @@ func TestWithTLSHandshakeTimeout(t *testing.T) { err: errors.NewErrCriticalOption( "TLSHandshakeTimeout", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), }, }, @@ -766,7 +766,7 @@ func TestWithIdleConnTimeout(t *testing.T) { err: errors.NewErrCriticalOption( "idleConnTimeout", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), }, }, @@ -863,7 +863,7 @@ func TestWithResponseHeaderTimeout(t *testing.T) { err: errors.NewErrCriticalOption( "responseHeaderTimeout", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), }, }, @@ -960,7 +960,7 @@ func TestWithExpectContinueTimeout(t *testing.T) { err: errors.NewErrCriticalOption( "expectContinueTimeout", "dummy", - errors.Join(errors.New("time: invalid duration \"dummy\""), errors.ErrTimeoutParseFailed("dummy")), + errors.New("time: invalid duration \"dummy\""), ), }, }, diff --git a/internal/params/params.go b/internal/params/params.go index 4d2f0de2b1..9239038549 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -85,6 +85,9 @@ func New(opts ...Option) Parser { // Parse parses command-line argument and returns parsed data and whether there is a help option or not and error. func (p *parser) Parse() (Data, bool, error) { + if p == nil || p.f == nil { + return nil, false, errors.ErrArgumentParserNotFound + } d := new(data) for _, key := range p.filePath.keys { p.f.StringVar(&d.configFilePath, diff --git a/internal/params/params_test.go b/internal/params/params_test.go index 128b89eac8..c1341fb98b 100644 --- a/internal/params/params_test.go +++ b/internal/params/params_test.go @@ -18,15 +18,16 @@ package params import ( + "flag" "os" "reflect" - "syscall" "testing" "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/test/goleak" ) +// TestNew tests the New function for creating a new parser instance. func TestNew(t *testing.T) { type args struct { opts []Option @@ -42,15 +43,37 @@ func TestNew(t *testing.T) { beforeFunc func(args) afterFunc func(args) } + // Custom check function: compare only the essential fields. defaultCheckFunc := func(w want, got Parser) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + p, ok := got.(*parser) + if !ok { + return errors.Errorf("got is not *parser") + } + // Check filePath fields. + if !reflect.DeepEqual(p.filePath.keys, w.want.filePath.keys) { + return errors.Errorf("filePath.keys mismatch: got %v, want %v", p.filePath.keys, w.want.filePath.keys) + } + if p.filePath.defaultPath != w.want.filePath.defaultPath { + return errors.Errorf("filePath.defaultPath mismatch: got %v, want %v", p.filePath.defaultPath, w.want.filePath.defaultPath) + } + if p.filePath.description != w.want.filePath.description { + return errors.Errorf("filePath.description mismatch: got %v, want %v", p.filePath.description, w.want.filePath.description) + } + // Check version fields. + if !reflect.DeepEqual(p.version.keys, w.want.version.keys) { + return errors.Errorf("version.keys mismatch: got %v, want %v", p.version.keys, w.want.version.keys) + } + if p.version.defaultFlag != w.want.version.defaultFlag { + return errors.Errorf("version.defaultFlag mismatch: got %v, want %v", p.version.defaultFlag, w.want.version.defaultFlag) + } + if p.version.description != w.want.version.description { + return errors.Errorf("version.description mismatch: got %v, want %v", p.version.description, w.want.version.description) } return nil } tests := []test{ { - name: "returns *parser when opts is nil", + name: "should return a default parser when no options are provided", want: want{ want: &parser{ filePath: struct { @@ -58,12 +81,7 @@ func TestNew(t *testing.T) { defaultPath string description string }{ - keys: []string{ - "f", - "file", - "c", - "config", - }, + keys: []string{"f", "file", "c", "config"}, defaultPath: "/etc/server/config.yaml", description: "config file path", }, @@ -72,20 +90,15 @@ func TestNew(t *testing.T) { defaultFlag bool description string }{ - keys: []string{ - "v", - "ver", - "version", - }, + keys: []string{"v", "ver", "version"}, defaultFlag: false, description: "show server version", }, }, }, }, - { - name: "returns *parser when opts is not nil", + name: "should return a parser with additional config file keys when options are provided", args: args{ opts: []Option{ WithConfigFilePathKeys("t", "test"), @@ -98,14 +111,7 @@ func TestNew(t *testing.T) { defaultPath string description string }{ - keys: []string{ - "f", - "file", - "c", - "config", - "t", - "test", - }, + keys: []string{"f", "file", "c", "config", "t", "test"}, defaultPath: "/etc/server/config.yaml", description: "config file path", }, @@ -114,11 +120,7 @@ func TestNew(t *testing.T) { defaultFlag bool description string }{ - keys: []string{ - "v", - "ver", - "version", - }, + keys: []string{"v", "ver", "version"}, defaultFlag: false, description: "show server version", }, @@ -137,19 +139,19 @@ func TestNew(t *testing.T) { if test.afterFunc != nil { defer test.afterFunc(test.args) } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc + check := test.checkFunc + if check == nil { + check = defaultCheckFunc } - got := New(test.args.opts...) - if err := checkFunc(test.want, got); err != nil { + if err := check(test.want, got); err != nil { tt.Errorf("error = %v", err) } }) } } +// Test_parser_Parse tests the Parse method of the parser. func Test_parser_Parse(t *testing.T) { type fields struct { filePath struct { @@ -164,142 +166,243 @@ func Test_parser_Parse(t *testing.T) { } } type want struct { - want Data - want1 bool - err error + want Data // expected Data (may be nil) + help bool // indicates if help option was triggered + err error // expected error } type test struct { name string fields fields + args []string // custom os.Args (optional) want want checkFunc func(want, Data, bool, error) error beforeFunc func(*testing.T) afterFunc func(*testing.T) } - defaultCheckFunc := func(w want, got Data, got1 bool, err error) error { + // Custom check function: compare only the essential fields of Data. + defaultCheckFunc := func(w want, got Data, gotHelp bool, err error) error { if !errors.Is(err, w.err) { - return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + return errors.Errorf("got error: %#v, want: %#v", err, w.err) + } + if gotHelp != w.help { + return errors.Errorf("got help flag: %#v, want: %#v", gotHelp, w.help) + } + // If no expected data is provided, skip field comparison. + if w.want == nil { + return nil + } + d, ok := got.(*data) + if !ok { + return errors.Errorf("got is not *data") } - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + expected, ok := w.want.(*data) + if !ok { + return errors.Errorf("expected want is not *data") + } + // If expected configFilePath is non-empty, compare directly. + // Otherwise, ensure that got.ConfigFilePath() is not empty. + if expected.configFilePath != "" { + if d.configFilePath != expected.configFilePath { + return errors.Errorf("configFilePath mismatch: got %v, want %v", d.configFilePath, expected.configFilePath) + } + } else { + if d.configFilePath == "" { + return errors.Errorf("expected non-empty configFilePath, but got empty string") + } } - if !reflect.DeepEqual(got1, w.want1) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got1, w.want1) + if d.showVersion != expected.showVersion { + return errors.Errorf("showVersion mismatch: got %v, want %v", d.showVersion, expected.showVersion) } return nil } tests := []test{ { - name: "returns (d, false, nil) when parse succeed", + name: "should successfully parse valid config file and version flag false", fields: fields{ filePath: struct { keys []string defaultPath string description string }{ - keys: []string{ - "path", "p", - }, - defaultPath: "./params.go", - description: "sets file path", + keys: []string{"path", "p"}, + defaultPath: "", + description: "set file path", }, version: struct { keys []string defaultFlag bool description string }{ - keys: []string{ - "version", "v", - }, - defaultFlag: true, - description: "show version", + keys: []string{"version", "v"}, + defaultFlag: false, + description: "show version flag", }, }, beforeFunc: func(t *testing.T) { - t.Helper() - os.Args = []string{ - "test", "--path=./params.go", "--version=false", + // Create a temporary file to ensure config file existence. + tmpFile, err := os.CreateTemp("", "config-*.yaml") + if err != nil { + t.Fatal(err) } + // Ensure the temporary file is removed after the test. + t.Cleanup(func() { os.Remove(tmpFile.Name()) }) + tmpFile.Close() + os.Args = []string{"test", "--path=" + tmpFile.Name(), "--version=false"} }, afterFunc: func(t *testing.T) { - t.Helper() os.Args = nil }, want: want{ + // expected Data: only showVersion is checked directly. + // For configFilePath, we expect a non-empty string. want: &data{ - configFilePath: "./params.go", + configFilePath: "", // will be validated as non-empty showVersion: false, }, + help: false, + err: nil, + }, + checkFunc: func(w want, got Data, gotHelp bool, err error) error { + // Use the default check function for essential fields. + return defaultCheckFunc(w, got, gotHelp, err) }, }, - { - name: "returns (nil, true, nil) When parse fails but the help option is set", + name: "should parse and return valid data when version flag is true even if file does not exist", + fields: fields{ + filePath: struct { + keys []string + defaultPath string + description string + }{ + keys: []string{"path", "p"}, + defaultPath: "nonexistent.yaml", + description: "set file path", + }, + version: struct { + keys []string + defaultFlag bool + description string + }{ + keys: []string{"version", "v"}, + defaultFlag: false, + description: "show version flag", + }, + }, beforeFunc: func(t *testing.T) { - t.Helper() - os.Args = []string{ - "test", "--help", - } + os.Args = []string{"test", "--path=nonexistent.yaml", "--version=true"} }, afterFunc: func(t *testing.T) { - t.Helper() os.Args = nil }, want: want{ - want1: true, + want: &data{ + configFilePath: "nonexistent.yaml", + showVersion: true, + }, + help: false, + err: nil, }, }, - { - name: "returns (nil, true, nil) When parse fails but the help option is not set", + name: "should return help when --help flag is provided", + fields: fields{ + filePath: struct { + keys []string + defaultPath string + description string + }{ + keys: []string{"path", "p"}, + defaultPath: "/etc/server/config.yaml", + description: "set file path", + }, + version: struct { + keys []string + defaultFlag bool + description string + }{ + keys: []string{"version", "v"}, + defaultFlag: false, + description: "show version flag", + }, + }, beforeFunc: func(t *testing.T) { - t.Helper() - os.Args = []string{ - "test", "--name", - } + os.Args = []string{"test", "--help"} }, afterFunc: func(t *testing.T) { - t.Helper() os.Args = nil }, want: want{ - want1: false, - err: errors.ErrArgumentParseFailed(errors.New("flag provided but not defined: -name")), + help: true, + err: nil, }, }, - { - name: "returns (nil, true, error) When the configFilePath option is set but file dose not exist", + name: "should return parsing error for unknown flag", fields: fields{ filePath: struct { keys []string defaultPath string description string }{ - keys: []string{ - "path", "p", - }, - description: "sets file path", + keys: []string{"path", "p"}, + defaultPath: "", + description: "set file path", + }, + version: struct { + keys []string + defaultFlag bool + description string + }{ + keys: []string{"version", "v"}, + defaultFlag: false, + description: "show version flag", }, }, beforeFunc: func(t *testing.T) { - t.Helper() - os.Args = []string{ - "test", "--path=config.yaml", - } + os.Args = []string{"test", "--unknown"} }, afterFunc: func(t *testing.T) { - t.Helper() os.Args = nil }, want: want{ - want1: true, - err: &os.PathError{ - Op: "stat", - Path: "config.yaml", - Err: syscall.Errno(0x2), + help: false, + // The error message is wrapped by ErrArgumentParseFailed. + err: errors.ErrArgumentParseFailed(errors.New("flag provided but not defined: -unknown")), + }, + }, + { + name: "should return help when config file path is empty", + fields: fields{ + filePath: struct { + keys []string + defaultPath string + description string + }{ + keys: []string{"path", "p"}, + defaultPath: "", + description: "set file path", + }, + version: struct { + keys []string + defaultFlag bool + description string + }{ + keys: []string{"version", "v"}, + defaultFlag: false, + description: "show version flag", }, }, + beforeFunc: func(t *testing.T) { + os.Args = []string{"test", "--path=", "--version=false"} + }, + afterFunc: func(t *testing.T) { + os.Args = nil + }, + want: want{ + help: true, + err: errors.New("invalid argument"), + }, }, } @@ -313,23 +416,24 @@ func Test_parser_Parse(t *testing.T) { if test.afterFunc != nil { defer test.afterFunc(tt) } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } p := &parser{ filePath: test.fields.filePath, version: test.fields.version, + f: flag.NewFlagSet(os.Args[0], flag.ContinueOnError), } - - got, got1, err := p.Parse() - if err := checkFunc(test.want, got, got1, err); err != nil { + gotData, gotHelp, err := p.Parse() + if test.checkFunc != nil { + if err := test.checkFunc(test.want, gotData, gotHelp, err); err != nil { + tt.Errorf("error = %v", err) + } + } else if err := defaultCheckFunc(test.want, gotData, gotHelp, err); err != nil { tt.Errorf("error = %v", err) } }) } } +// Test_data_ConfigFilePath tests the ConfigFilePath getter of the data struct. func Test_data_ConfigFilePath(t *testing.T) { type fields struct { configFilePath string @@ -346,23 +450,22 @@ func Test_data_ConfigFilePath(t *testing.T) { afterFunc func(*testing.T) } defaultCheckFunc := func(w want, got string) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + if got != w.want { + return errors.Errorf("got: %v, want: %v", got, w.want) } return nil } tests := []test{ { - name: "returns `./path` when d.configFilePath is `./path`", + name: "should return the provided config file path", fields: fields{ - configFilePath: "./path", + configFilePath: "./path/to/config.yaml", }, want: want{ - want: "./path", + want: "./path/to/config.yaml", }, }, } - for _, tc := range tests { test := tc t.Run(test.name, func(tt *testing.T) { @@ -373,22 +476,22 @@ func Test_data_ConfigFilePath(t *testing.T) { if test.afterFunc != nil { defer test.afterFunc(tt) } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } d := &data{ configFilePath: test.fields.configFilePath, } - got := d.ConfigFilePath() - if err := checkFunc(test.want, got); err != nil { + if test.checkFunc != nil { + if err := test.checkFunc(test.want, got); err != nil { + tt.Errorf("error = %v", err) + } + } else if err := defaultCheckFunc(test.want, got); err != nil { tt.Errorf("error = %v", err) } }) } } +// Test_data_ShowVersion tests the ShowVersion getter of the data struct. func Test_data_ShowVersion(t *testing.T) { type fields struct { showVersion bool @@ -405,14 +508,14 @@ func Test_data_ShowVersion(t *testing.T) { afterFunc func(*testing.T) } defaultCheckFunc := func(w want, got bool) error { - if !reflect.DeepEqual(got, w.want) { - return errors.Errorf("got: \"%#v\",\n\t\t\t\twant: \"%#v\"", got, w.want) + if got != w.want { + return errors.Errorf("got: %v, want: %v", got, w.want) } return nil } tests := []test{ { - name: "returns true when d.showVersion is true", + name: "should return true when showVersion is set to true", fields: fields{ showVersion: true, }, @@ -420,8 +523,16 @@ func Test_data_ShowVersion(t *testing.T) { want: true, }, }, + { + name: "should return false when showVersion is set to false", + fields: fields{ + showVersion: false, + }, + want: want{ + want: false, + }, + }, } - for _, tc := range tests { test := tc t.Run(test.name, func(tt *testing.T) { @@ -432,258 +543,17 @@ func Test_data_ShowVersion(t *testing.T) { if test.afterFunc != nil { defer test.afterFunc(tt) } - checkFunc := test.checkFunc - if test.checkFunc == nil { - checkFunc = defaultCheckFunc - } d := &data{ showVersion: test.fields.showVersion, } - got := d.ShowVersion() - if err := checkFunc(test.want, got); err != nil { + if test.checkFunc != nil { + if err := test.checkFunc(test.want, got); err != nil { + tt.Errorf("error = %v", err) + } + } else if err := defaultCheckFunc(test.want, got); err != nil { tt.Errorf("error = %v", err) } }) } } - -// NOT IMPLEMENTED BELOW -// -// func Test_parser_Restore(t *testing.T) { -// type fields struct { -// overrideDefault bool -// name string -// filters []func(string) bool -// f *flag.FlagSet -// defaults *flag.FlagSet -// filePath struct { -// keys []string -// defaultPath string -// description string -// } -// version struct { -// keys []string -// defaultFlag bool -// description string -// } -// ErrorHandler ErrorHandling -// } -// type want struct{} -// type test struct { -// name string -// fields fields -// want want -// checkFunc func(want) error -// beforeFunc func(*testing.T) -// afterFunc func(*testing.T) -// } -// defaultCheckFunc := func(w want) error { -// return nil -// } -// tests := []test{ -// // TODO test cases -// /* -// { -// name: "test_case_1", -// fields: fields { -// overrideDefault:false, -// name:"", -// filters:nil, -// f:flag.FlagSet{}, -// defaults:flag.FlagSet{}, -// filePath:struct{keys []string; defaultPath string; description string}{}, -// version:struct{keys []string; defaultFlag bool; description string}{}, -// ErrorHandler:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T,) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T,) { -// t.Helper() -// }, -// }, -// */ -// -// // TODO test cases -// /* -// func() test { -// return test { -// name: "test_case_2", -// fields: fields { -// overrideDefault:false, -// name:"", -// filters:nil, -// f:flag.FlagSet{}, -// defaults:flag.FlagSet{}, -// filePath:struct{keys []string; defaultPath string; description string}{}, -// version:struct{keys []string; defaultFlag bool; description string}{}, -// ErrorHandler:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T,) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T,) { -// t.Helper() -// }, -// } -// }(), -// */ -// } -// -// for _, tc := range tests { -// test := tc -// t.Run(test.name, func(tt *testing.T) { -// tt.Parallel() -// defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) -// if test.beforeFunc != nil { -// test.beforeFunc(tt) -// } -// if test.afterFunc != nil { -// defer test.afterFunc(tt) -// } -// checkFunc := test.checkFunc -// if test.checkFunc == nil { -// checkFunc = defaultCheckFunc -// } -// p := &parser{ -// overrideDefault: test.fields.overrideDefault, -// name: test.fields.name, -// filters: test.fields.filters, -// f: test.fields.f, -// defaults: test.fields.defaults, -// filePath: test.fields.filePath, -// version: test.fields.version, -// ErrorHandler: test.fields.ErrorHandler, -// } -// -// p.Restore() -// if err := checkFunc(test.want); err != nil { -// tt.Errorf("error = %v", err) -// } -// }) -// } -// } -// -// func Test_parser_Override(t *testing.T) { -// type fields struct { -// overrideDefault bool -// name string -// filters []func(string) bool -// f *flag.FlagSet -// defaults *flag.FlagSet -// filePath struct { -// keys []string -// defaultPath string -// description string -// } -// version struct { -// keys []string -// defaultFlag bool -// description string -// } -// ErrorHandler ErrorHandling -// } -// type want struct{} -// type test struct { -// name string -// fields fields -// want want -// checkFunc func(want) error -// beforeFunc func(*testing.T) -// afterFunc func(*testing.T) -// } -// defaultCheckFunc := func(w want) error { -// return nil -// } -// tests := []test{ -// // TODO test cases -// /* -// { -// name: "test_case_1", -// fields: fields { -// overrideDefault:false, -// name:"", -// filters:nil, -// f:flag.FlagSet{}, -// defaults:flag.FlagSet{}, -// filePath:struct{keys []string; defaultPath string; description string}{}, -// version:struct{keys []string; defaultFlag bool; description string}{}, -// ErrorHandler:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T,) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T,) { -// t.Helper() -// }, -// }, -// */ -// -// // TODO test cases -// /* -// func() test { -// return test { -// name: "test_case_2", -// fields: fields { -// overrideDefault:false, -// name:"", -// filters:nil, -// f:flag.FlagSet{}, -// defaults:flag.FlagSet{}, -// filePath:struct{keys []string; defaultPath string; description string}{}, -// version:struct{keys []string; defaultFlag bool; description string}{}, -// ErrorHandler:nil, -// }, -// want: want{}, -// checkFunc: defaultCheckFunc, -// beforeFunc: func(t *testing.T,) { -// t.Helper() -// }, -// afterFunc: func(t *testing.T,) { -// t.Helper() -// }, -// } -// }(), -// */ -// } -// -// for _, tc := range tests { -// test := tc -// t.Run(test.name, func(tt *testing.T) { -// tt.Parallel() -// defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) -// if test.beforeFunc != nil { -// test.beforeFunc(tt) -// } -// if test.afterFunc != nil { -// defer test.afterFunc(tt) -// } -// checkFunc := test.checkFunc -// if test.checkFunc == nil { -// checkFunc = defaultCheckFunc -// } -// p := &parser{ -// overrideDefault: test.fields.overrideDefault, -// name: test.fields.name, -// filters: test.fields.filters, -// f: test.fields.f, -// defaults: test.fields.defaults, -// filePath: test.fields.filePath, -// version: test.fields.version, -// ErrorHandler: test.fields.ErrorHandler, -// } -// -// p.Override() -// if err := checkFunc(test.want); err != nil { -// tt.Errorf("error = %v", err) -// } -// }) -// } -// } diff --git a/internal/sync/errgroup/group.go b/internal/sync/errgroup/group.go index 1aae147b41..a361ee6486 100644 --- a/internal/sync/errgroup/group.go +++ b/internal/sync/errgroup/group.go @@ -118,10 +118,8 @@ func TryGo(f func() error) bool { // A negative value indicates no limit. // This must not be modified while any tasks are active. func (g *group) SetLimit(limit int) { - if limit <= 1 { - // For serial execution, do not use a semaphore. - g.sem = nil - g.limit.Store(int64(limit)) + g.limit.Store(int64(limit)) + if limit < 0 { return } // For concurrent execution, initialize or resize the semaphore. @@ -130,7 +128,6 @@ func (g *group) SetLimit(limit int) { } else { g.sem.Resize(int64(limit)) } - g.limit.Store(int64(limit)) } // exec executes the provided function inline (synchronously) when limit == 1. @@ -196,11 +193,6 @@ func (g *group) TryGo(f func() error) bool { if f == nil { return false } - // Execute inline if in serial mode. - if g.limit.Load() == 1 { - g.exec(f) - return true - } // In concurrent mode, try to acquire the semaphore without blocking. if g.sem != nil && !g.sem.TryAcquire(1) { return false diff --git a/internal/timeutil/time.go b/internal/timeutil/time.go index ce55c2b90c..d8af39c31e 100644 --- a/internal/timeutil/time.go +++ b/internal/timeutil/time.go @@ -16,11 +16,7 @@ package timeutil -import ( - "time" - - "github.com/vdaas/vald/internal/errors" -) +import "time" // ParseTime parses string to time.Duration. func Parse(t string) (time.Duration, error) { @@ -29,7 +25,7 @@ func Parse(t string) (time.Duration, error) { } dur, err := time.ParseDuration(t) if err != nil { - return 0, errors.Join(err, errors.ErrTimeoutParseFailed(t)) + return 0, err } return dur, nil } diff --git a/pkg/agent/core/ngt/service/ngt_test.go b/pkg/agent/core/ngt/service/ngt_test.go index 96325386d1..ae3d5cef7c 100644 --- a/pkg/agent/core/ngt/service/ngt_test.go +++ b/pkg/agent/core/ngt/service/ngt_test.go @@ -36,7 +36,6 @@ import ( core "github.com/vdaas/vald/internal/core/algorithm/ngt" "github.com/vdaas/vald/internal/errors" "github.com/vdaas/vald/internal/file" - "github.com/vdaas/vald/internal/k8s/vald" kvald "github.com/vdaas/vald/internal/k8s/vald" "github.com/vdaas/vald/internal/log" "github.com/vdaas/vald/internal/net/grpc" diff --git a/pkg/agent/core/ngt/service/option_test.go b/pkg/agent/core/ngt/service/option_test.go index 6141d58add..2fac2a42c4 100644 --- a/pkg/agent/core/ngt/service/option_test.go +++ b/pkg/agent/core/ngt/service/option_test.go @@ -316,7 +316,7 @@ func TestWithAutoIndexCheckDuration(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -401,7 +401,7 @@ func TestWithAutoIndexDurationLimit(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -486,7 +486,7 @@ func TestWithAutoSaveIndexDuration(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -647,7 +647,7 @@ func TestWithInitialDelayMaxDuration(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -732,7 +732,7 @@ func TestWithMinLoadIndexTimeout(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -817,7 +817,7 @@ func TestWithMaxLoadIndexTimeout(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -902,7 +902,7 @@ func TestWithLoadIndexTimeoutFactor(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } @@ -1362,7 +1362,7 @@ func TestWithExportIndexInfoDuration(t *testing.T) { }, want: want{ obj: &T{}, - err: errors.Join(errors.New("time: unknown unit \"ss\" in duration \"5ss\""), errors.ErrTimeoutParseFailed("5ss")), + err: errors.New("time: unknown unit \"ss\" in duration \"5ss\""), }, }, } diff --git a/tests/v2/e2e/assets/unary_crud.yaml b/tests/v2/e2e/assets/unary_crud.yaml index 30bf46ee72..c34f3f360d 100644 --- a/tests/v2/e2e/assets/unary_crud.yaml +++ b/tests/v2/e2e/assets/unary_crud.yaml @@ -16,7 +16,7 @@ time_zone: UTC logging: format: raw - level: info + level: debug logger: glg dataset: name: _E2E_DATASET_PATH_ diff --git a/tests/v2/e2e/crud/strategy_test.go b/tests/v2/e2e/crud/strategy_test.go index 03266b1890..c4c91217ad 100644 --- a/tests/v2/e2e/crud/strategy_test.go +++ b/tests/v2/e2e/crud/strategy_test.go @@ -208,10 +208,13 @@ func (r *runner) processExecution(t *testing.T, ctx context.Context, idx int, e config.OpExists: train, test, neighbors := getDatasetSlices(ttt, e) if e.BaseConfig != nil { - log.Infof("started execution name: %s, type: %s, mode: %s, execution: %d, num: %d, offset: %d", - e.Name, e.Type, e.Mode, idx, e.Num, e.Offset) - defer log.Infof("finished execution name: %s type: %s, mode: %s, execution: %d, num: %d, offset: %d", - e.Name, e.Type, e.Mode, idx, e.Num, e.Offset) + start := time.Now() + log.Infof("started %s execution at %s, type: %s, mode: %s, execution: %d, num: %d, offset: %d, parallelism: %d, qps: %d", + e.Name, start.Format("2006-01-02 15:04:05"), e.Type, e.Mode, idx, e.Num, e.Offset, e.Parallelism, e.QPS) + defer func() { + log.Infof("finished %s execution in %s, type: %s, mode: %s, execution: %d, num: %d, offset: %d, parallelism: %d, qps: %d", + e.Name, time.Since(start).String(), e.Type, e.Mode, idx, e.Num, e.Offset, e.Parallelism, e.QPS) + }() } switch e.Type { case config.OpSearch,