-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathctx.go
157 lines (135 loc) · 3.66 KB
/
ctx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package mq
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net"
	"os"
	"sync"
	"time"

	"github.com/oarkflow/errors"
	"github.com/oarkflow/xid"

	"github.com/oarkflow/mq/consts"
	"github.com/oarkflow/mq/storage"
	"github.com/oarkflow/mq/storage/memory"
)
// Handler is the callback signature that receives a context and a *Task and
// produces a Result.
type Handler func(ctx context.Context, task *Task) Result
// IsClosed reports whether conn can no longer be used for I/O.
//
// It probes the connection with a one-byte Read bounded by a 1ms read
// deadline, so it never blocks on an idle connection:
//   - a read that times out means the connection is open but idle -> false;
//   - a successful read means the connection is alive -> false;
//   - any other read error (io.EOF after the peer hung up, net.ErrClosed after
//     a local Close, a reset, ...) means the connection is unusable -> true;
//   - failing to set the deadline (e.g. the conn is already closed) -> true;
//   - a nil conn -> true.
//
// Caveats: net.Conn has no portable "peek", so if the peer has already sent
// data the probe consumes one byte of it; only probe connections whose
// protocol does not expect unsolicited data. The read deadline is cleared
// (reset to the zero time) afterwards, overriding any deadline the caller had
// installed.
func IsClosed(conn net.Conn) bool {
	if conn == nil {
		return true
	}
	// Without a deadline the Read below would block until the peer sends
	// something, which for an idle pooled connection may be forever.
	if err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil {
		return true
	}
	defer func() {
		// Best effort: on a dead connection this fails, which is harmless.
		_ = conn.SetReadDeadline(time.Time{})
	}()

	var probe [1]byte
	_, err := conn.Read(probe[:])
	if err == nil {
		return false
	}
	// Errors are usually wrapped (e.g. *net.OpError), so classify by the
	// net.Error Timeout behaviour instead of comparing sentinels with ==.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		return false
	}
	return true
}
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
hd, _ := GetHeaders(ctx)
if hd == nil {
hd = memory.New[string, string]()
}
for key, val := range headers {
hd.Set(key, val)
}
return context.WithValue(ctx, consts.HeaderKey, hd)
}
// WithHeaders returns a new map that combines the headers stored in ctx with
// headers; entries in headers win on key collisions.
//
// ctx is only read, never written: the merge happens in a freshly allocated
// map, so the extra headers do not leak into the header map stored in ctx (or
// into any other context sharing that map).
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
	var stored map[string]string
	if hd, ok := GetHeaders(ctx); ok && hd != nil {
		stored = hd.AsMap()
	}
	merged := make(map[string]string, len(stored)+len(headers))
	for key, val := range stored {
		merged[key] = val
	}
	for key, val := range headers {
		merged[key] = val
	}
	return merged
}
// GetHeaders fetches the header map stored under consts.HeaderKey in ctx.
// The boolean is false when no value is stored under that key or when the
// stored value is not a storage.IMap[string, string].
func GetHeaders(ctx context.Context) (storage.IMap[string, string], bool) {
	raw := ctx.Value(consts.HeaderKey)
	hd, found := raw.(storage.IMap[string, string])
	return hd, found
}
// GetHeader looks up a single header by key in the header map carried by ctx.
// It returns ("", false) when ctx carries no header map; otherwise it returns
// whatever the map's Get reports for key.
func GetHeader(ctx context.Context, key string) (string, bool) {
	if hd, found := GetHeaders(ctx); found {
		return hd.Get(key)
	}
	return "", false
}
// GetContentType returns the consts.ContentType header from ctx and whether it
// was present.
func GetContentType(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.ContentType)
	return value, found
}

// GetQueue returns the consts.QueueKey header from ctx and whether it was
// present.
func GetQueue(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.QueueKey)
	return value, found
}

// GetConsumerID returns the consts.ConsumerKey header from ctx and whether it
// was present.
func GetConsumerID(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.ConsumerKey)
	return value, found
}

// GetTriggerNode returns the consts.TriggerNode header from ctx and whether it
// was present.
func GetTriggerNode(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.TriggerNode)
	return value, found
}

// GetAwaitResponse returns the consts.AwaitResponseKey header from ctx and
// whether it was present.
func GetAwaitResponse(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.AwaitResponseKey)
	return value, found
}

// GetPublisherID returns the consts.PublisherKey header from ctx and whether it
// was present.
func GetPublisherID(ctx context.Context) (string, bool) {
	value, found := GetHeader(ctx, consts.PublisherKey)
	return value, found
}
// NewID generates a fresh identifier using xid and returns its string form.
func NewID() string {
	id := xid.New()
	return id.String()
}
// createTLSConnection dials addr over TCP and performs a TLS handshake. It
// presents the client certificate/key pair loaded from certPath and keyPath.
//
// caPath is optional, and only its first element is used:
//   - When it names a PEM file, the server certificate is verified against the
//     CA certificates in that file. This is the standard chain plus hostname
//     verification; tls.Dial derives the expected server name from addr.
//   - When it is omitted or empty, server verification is skipped (see the
//     security note below).
//
// It returns an error if:
//   - the key pair cannot be loaded;
//   - the CA file cannot be read or contains no PEM certificates;
//   - the dial or handshake fails (including a failed server verification).
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load client cert/key: %w", err)
	}
	// ClientAuth/ClientCAs were dropped: they are server-side settings and have
	// no effect on a client configuration.
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
	}
	if len(caPath) > 0 && caPath[0] != "" {
		caCert, err := os.ReadFile(caPath[0])
		if err != nil {
			return nil, fmt.Errorf("failed to load CA cert: %w", err)
		}
		caCertPool := x509.NewCertPool()
		// An unparsable CA file would otherwise yield an empty pool and a
		// confusing "unknown authority" error at handshake time.
		if !caCertPool.AppendCertsFromPEM(caCert) {
			return nil, fmt.Errorf("failed to load CA cert: no PEM certificates found in %q", caPath[0])
		}
		tlsConfig.RootCAs = caCertPool
	} else {
		// SECURITY NOTE(review): with no CA file the server certificate is NOT
		// verified, so the connection is open to man-in-the-middle attacks.
		// This is kept only for backward compatibility with callers that pass
		// no CA path; prefer always supplying one.
		tlsConfig.InsecureSkipVerify = true
	}
	conn, err := tls.Dial("tcp", addr, tlsConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
	}
	return conn, nil
}
// connPool is the package-wide connection cache used by GetConnection. It maps
// a pool key (address + TLS flag) to a net.Conn. sync.Map makes the
// concurrent Load/Store/Delete calls below safe.
var connPool sync.Map

// GetConnection returns a connection to addr, reusing a pooled one when
// possible.
//
// If the pool holds a connection for (addr, config.UseTLS) and IsClosed
// reports it usable, that connection is returned as-is. Note that the same
// net.Conn is handed to every caller. Otherwise any dead pooled connection is
// closed and evicted. A new connection is then dialed (TLS via
// createTLSConnection when config.UseTLS is set, plain TCP otherwise), stored
// in the pool, and returned. Dial errors are returned unchanged.
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
	// NOTE(review): the key ignores CertPath/KeyPath/CAPath, so callers using
	// different TLS identities for the same addr share one connection. Kept
	// unchanged in case other code relies on this key format.
	key := fmt.Sprintf("%s_%t", addr, config.UseTLS)
	if cached, ok := connPool.Load(key); ok {
		pooled := cached.(net.Conn)
		if !IsClosed(pooled) {
			return pooled, nil
		}
		// Release the dead connection's file descriptor instead of merely
		// forgetting it; the Close error is irrelevant for a dead conn.
		_ = pooled.Close()
		connPool.Delete(key)
	}

	var (
		conn net.Conn
		err  error
	)
	if config.UseTLS {
		conn, err = createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
	} else {
		conn, err = net.Dial("tcp", addr)
	}
	if err != nil {
		return nil, err
	}

	// Another goroutine may have dialed and stored a connection for the same
	// key in the meantime. Keep the one already pooled and close ours, rather
	// than overwriting it and orphaning a connection the pool no longer tracks.
	if existing, loaded := connPool.LoadOrStore(key, conn); loaded {
		_ = conn.Close()
		return existing.(net.Conn), nil
	}
	return conn, nil
}
// WrapError annotates err with msg and op by delegating to errors.Wrap from
// github.com/oarkflow/errors, and returns the wrapped error.
func WrapError(err error, msg, op string) error {
	wrapped := errors.Wrap(err, msg, op)
	return wrapped
}