From e6d274ea803a75e509663ea613fcf9abbc9aeafd Mon Sep 17 00:00:00 2001 From: dyhkwong <50692134+dyhkwong@users.noreply.github.com> Date: Sun, 8 Sep 2024 21:25:46 +0800 Subject: [PATCH] hysteria2: fix dialer reuse --- transport/internet/hysteria2/dialer.go | 47 ++++++++++++++------------ 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/transport/internet/hysteria2/dialer.go b/transport/internet/hysteria2/dialer.go index 2e4f25265f5..66a5704b246 100644 --- a/transport/internet/hysteria2/dialer.go +++ b/transport/internet/hysteria2/dialer.go @@ -15,7 +15,12 @@ import ( "github.com/v2fly/v2ray-core/v5/transport/internet/tls" ) -var RunningClient map[net.Addr](hyClient.Client) +type dialerConf struct { + net.Destination + *internet.MemoryStreamConfig +} + +var RunningClient map[dialerConf](hyClient.Client) var ClientMutex sync.Mutex var MBps uint64 = 1000000 / 8 // MByte @@ -61,12 +66,17 @@ func (f *connFactory) New(addr net.Addr) (net.PacketConn, error) { return f.NewFunc(addr) } -func NewHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) { +func NewHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) { tlsConfig, err := GetClientTLSConfig(streamSettings) if err != nil { return nil, err } + serverAddr, err := ResolveAddress(dest) + if err != nil { + return nil, err + } + config := streamSettings.ProtocolSettings.(*Config) client, _, err := hyClient.NewClient(&hyClient.Config{ Auth: config.GetPassword(), @@ -93,36 +103,36 @@ func NewHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfi return client, nil } -func CloseHyClient(serverAddr net.Addr) error { +func CloseHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) error { ClientMutex.Lock() defer ClientMutex.Unlock() - client, found := RunningClient[serverAddr] + client, found := RunningClient[dialerConf{dest, streamSettings}] if found { - delete(RunningClient, serverAddr) + delete(RunningClient, dialerConf{dest, streamSettings}) return client.Close() } return nil } -func GetHyClient(serverAddr net.Addr, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) { +func GetHyClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) (hyClient.Client, error) { var err error var client hyClient.Client ClientMutex.Lock() - client, found := RunningClient[serverAddr] + client, found := RunningClient[dialerConf{dest, streamSettings}] ClientMutex.Unlock() if !found || !CheckHyClientHealthy(client) { if found { // retry - CloseHyClient(serverAddr) + CloseHyClient(dest, streamSettings) } - client, err = NewHyClient(serverAddr, streamSettings) + client, err = NewHyClient(dest, streamSettings) if err != nil { return nil, err } ClientMutex.Lock() - RunningClient[serverAddr] = client + RunningClient[dialerConf{dest, streamSettings}] = client ClientMutex.Unlock() } return client, nil @@ -144,14 +154,9 @@ func CheckHyClientHealthy(client hyClient.Client) bool { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) { config := streamSettings.ProtocolSettings.(*Config) - serverAddr, err := ResolveAddress(dest) - if err != nil { - return nil, err - } - - client, err := GetHyClient(serverAddr, streamSettings) + client, err := GetHyClient(dest, streamSettings) if err != nil { - CloseHyClient(serverAddr) + CloseHyClient(dest, streamSettings) return nil, err } @@ -172,7 +177,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me conn.IsServer = false conn.ClientUDPSession, err = client.UDP() if err != nil { - CloseHyClient(serverAddr) + CloseHyClient(dest, streamSettings) return nil, err } return conn, nil @@ -180,7 +185,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me conn.stream, err = client.OpenStream() if err != nil { - CloseHyClient(serverAddr) + CloseHyClient(dest, streamSettings) return nil, err } @@ -190,13 +195,13 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me hyProtocol.VarintPut(buf, hyProtocol.FrameTypeTCPRequest) _, err = conn.stream.Write(buf) if err != nil { - CloseHyClient(serverAddr) + CloseHyClient(dest, streamSettings) return nil, err } return conn, nil } func init() { - RunningClient = make(map[net.Addr]hyClient.Client) + RunningClient = make(map[dialerConf]hyClient.Client) common.Must(internet.RegisterTransportDialer(protocolName, Dial)) }