Skip to content

Commit 0b102d0

Browse files
committed
NOISSUE - Add support for interceptor (#51)
Signed-off-by: Dusan Borovcanin <[email protected]>
1 parent 0ffbc4f commit 0b102d0

File tree

7 files changed

+83
-54
lines changed

7 files changed

+83
-54
lines changed

cmd/main.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,31 +267,31 @@ func loadConfig() config {
267267

268268
func proxyMQTTWS(cfg WSMQTTConfig, logger mglog.Logger, handler session.Handler, errs chan error) {
269269
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
270-
wp := websocket.New(target, cfg.targetPath, cfg.targetScheme, handler, logger)
270+
wp := websocket.New(target, cfg.targetPath, cfg.targetScheme, handler, nil, logger)
271271
http.Handle(cfg.path, wp.Handler())
272272

273273
errs <- wp.Listen(cfg.port)
274274
}
275275

276276
func proxyMQTTWSS(cfg config, tlsCfg *tls.Config, logger mglog.Logger, handler session.Handler, errs chan error) {
277277
target := fmt.Sprintf("%s:%s", cfg.wsMQTTConfig.targetHost, cfg.wsMQTTConfig.targetPort)
278-
wp := websocket.New(target, cfg.wsMQTTConfig.targetPath, cfg.wsMQTTConfig.targetScheme, handler, logger)
278+
wp := websocket.New(target, cfg.wsMQTTConfig.targetPath, cfg.wsMQTTConfig.targetScheme, handler, nil, logger)
279279
http.Handle(cfg.wsMQTTConfig.wssPath, wp.Handler())
280280
errs <- wp.ListenTLS(tlsCfg, cfg.serverCert, cfg.serverKey, cfg.wsMQTTConfig.wssPort)
281281
}
282282

283283
func proxyMQTT(ctx context.Context, cfg MQTTConfig, logger mglog.Logger, handler session.Handler, errs chan error) {
284284
address := fmt.Sprintf("%s:%s", cfg.host, cfg.port)
285285
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
286-
mp := mqtt.New(address, target, handler, logger)
286+
mp := mqtt.New(address, target, handler, nil, logger)
287287

288288
errs <- mp.Listen(ctx)
289289
}
290290

291291
func proxyMQTTS(ctx context.Context, cfg MQTTConfig, tlsCfg *tls.Config, logger mglog.Logger, handler session.Handler, errs chan error) {
292292
address := fmt.Sprintf("%s:%s", cfg.host, cfg.mqttsPort)
293293
target := fmt.Sprintf("%s:%s", cfg.targetHost, cfg.targetPort)
294-
mp := mqtt.New(address, target, handler, logger)
294+
mp := mqtt.New(address, target, handler, nil, logger)
295295

296296
errs <- mp.ListenTLS(ctx, tlsCfg)
297297
}

go.mod

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ toolchain go1.21.4
77
require (
88
github.com/absmach/magistrala v0.11.1-0.20231220185538-1fe2e74a741f
99
github.com/eclipse/paho.mqtt.golang v1.4.3
10-
github.com/google/uuid v1.4.0
11-
github.com/gorilla/websocket v1.5.0
12-
golang.org/x/sync v0.4.0
10+
github.com/google/uuid v1.5.0
11+
github.com/gorilla/websocket v1.5.1
12+
golang.org/x/sync v0.6.0
1313
)
1414

1515
require (
1616
github.com/go-kit/log v0.2.1 // indirect
1717
github.com/go-logfmt/logfmt v0.6.0 // indirect
18-
golang.org/x/net v0.17.0 // indirect
18+
golang.org/x/net v0.20.0 // indirect
1919
)

go.sum

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@ github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU=
88
github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0=
99
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
1010
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
11-
github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
12-
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
13-
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
14-
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
11+
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
12+
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
13+
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
14+
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
1515
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
1616
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
1717
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
1818
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
19-
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
20-
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
21-
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
22-
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
19+
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
20+
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
21+
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
22+
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
2323
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
2424
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

pkg/mqtt/mqtt.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,24 @@ import (
1212
mptls "github.com/absmach/mproxy/pkg/tls"
1313
)
1414

15-
// Proxy is main MQTT proxy struct
15+
// Proxy is main MQTT proxy struct.
1616
type Proxy struct {
17-
address string
18-
target string
19-
handler session.Handler
20-
logger logger.Logger
21-
dialer net.Dialer
17+
address string
18+
target string
19+
handler session.Handler
20+
interceptor session.Interceptor
21+
logger logger.Logger
22+
dialer net.Dialer
2223
}
2324

24-
// New returns a new mqtt Proxy instance.
25-
func New(address, target string, handler session.Handler, logger logger.Logger) *Proxy {
25+
// New returns a new MQTT Proxy instance.
26+
func New(address, target string, handler session.Handler, interceptor session.Interceptor, logger logger.Logger) *Proxy {
2627
return &Proxy{
27-
address: address,
28-
target: target,
29-
handler: handler,
30-
logger: logger,
28+
address: address,
29+
target: target,
30+
handler: handler,
31+
logger: logger,
32+
interceptor: interceptor,
3133
}
3234
}
3335

@@ -59,7 +61,7 @@ func (p Proxy) handle(ctx context.Context, inbound net.Conn) {
5961
return
6062
}
6163

62-
if err = session.Stream(ctx, inbound, outbound, p.handler, clientCert); err != io.EOF {
64+
if err = session.Stream(ctx, inbound, outbound, p.handler, p.interceptor, clientCert); err != io.EOF {
6365
p.logger.Warn(err.Error())
6466
}
6567
}
@@ -79,7 +81,7 @@ func (p Proxy) Listen(ctx context.Context) error {
7981
return nil
8082
}
8183

82-
// ListenTLS - version of Listen with TLS encryption
84+
// ListenTLS - version of Listen with TLS encryption.
8385
func (p Proxy) ListenTLS(ctx context.Context, tlsCfg *tls.Config) error {
8486
l, err := tls.Listen("tcp", p.address, tlsCfg)
8587
if err != nil {

pkg/mqtt/websocket/websocket.go

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,24 @@ import (
1616

1717
// Proxy represents WS Proxy.
1818
type Proxy struct {
19-
target string
20-
path string
21-
scheme string
22-
event session.Handler
19+
target string
20+
path string
21+
scheme string
22+
handler session.Handler
23+
interceptor session.Interceptor
24+
2325
logger logger.Logger
2426
}
2527

2628
// New - creates new WS proxy
27-
func New(target, path, scheme string, event session.Handler, logger logger.Logger) *Proxy {
29+
func New(target, path, scheme string, handler session.Handler, interceptor session.Interceptor, logger logger.Logger) *Proxy {
2830
return &Proxy{
29-
target: target,
30-
path: path,
31-
scheme: scheme,
32-
event: event,
33-
logger: logger,
31+
target: target,
32+
path: path,
33+
scheme: scheme,
34+
handler: handler,
35+
interceptor: interceptor,
36+
logger: logger,
3437
}
3538
}
3639

@@ -94,7 +97,7 @@ func (p Proxy) pass(ctx context.Context, in *websocket.Conn) {
9497
return
9598
}
9699

97-
err = session.Stream(ctx, inboundConn, outboundConn, p.event, clientCert)
100+
err = session.Stream(ctx, inboundConn, outboundConn, p.handler, p.interceptor, clientCert)
98101
errc <- err
99102
p.logger.Warn("Broken connection for client with error: " + err.Error())
100103
}

pkg/session/interceptor.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package session
2+
3+
import (
4+
"context"
5+
6+
"github.com/eclipse/paho.mqtt.golang/packets"
7+
)
8+
9+
// Interceptor is an interface for mProxy intercept hook.
10+
type Interceptor interface {
11+
// Intercept is called on every packet flowing through the Proxy.
12+
// Packets can be modified before being sent to the broker or the client.
13+
// If the interceptor returns a non-nil packet, the modified packet is sent.
14+
// The error indicates unsuccessful interception and mProxy is cancelling the packet.
15+
Intercept(ctx context.Context, pkt packets.ControlPacket, dir Direction) (packets.ControlPacket, error)
16+
}

pkg/session/stream.go

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ import (
1010
"github.com/eclipse/paho.mqtt.golang/packets"
1111
)
1212

13-
type direction int
13+
type Direction int
1414

1515
const (
16-
up direction = iota
17-
down
16+
Up Direction = iota
17+
Down
1818
)
1919

2020
const unknownID = "unknown"
@@ -25,26 +25,26 @@ var (
2525
)
2626

2727
// Stream starts proxy between client and broker.
28-
func Stream(ctx context.Context, inbound, outbound net.Conn, handler Handler, cert x509.Certificate) error {
28+
func Stream(ctx context.Context, in, out net.Conn, h Handler, ic Interceptor, cert x509.Certificate) error {
2929
s := Session{
3030
Cert: cert,
3131
}
3232
ctx = NewContext(ctx, &s)
3333
errs := make(chan error, 2)
3434

35-
go stream(ctx, up, inbound, outbound, handler, errs)
36-
go stream(ctx, down, outbound, inbound, handler, errs)
35+
go stream(ctx, Up, in, out, h, ic, errs)
36+
go stream(ctx, Down, out, in, h, ic, errs)
3737

3838
// Handle whichever error happens first.
3939
// The other routine won't be blocked when writing
4040
// to the errors channel because it is buffered.
4141
err := <-errs
4242

43-
handler.Disconnect(ctx)
43+
h.Disconnect(ctx)
4444
return err
4545
}
4646

47-
func stream(ctx context.Context, dir direction, r, w net.Conn, h Handler, errs chan error) {
47+
func stream(ctx context.Context, dir Direction, r, w net.Conn, h Handler, ic Interceptor, errs chan error) {
4848
for {
4949
// Read from one connection.
5050
pkt, err := packets.ReadPacket(r)
@@ -53,20 +53,28 @@ func stream(ctx context.Context, dir direction, r, w net.Conn, h Handler, errs c
5353
return
5454
}
5555

56-
if dir == up {
56+
if dir == Up {
5757
if err = authorize(ctx, pkt, h); err != nil {
5858
errs <- wrap(ctx, err, dir)
5959
return
6060
}
6161
}
62+
if ic != nil {
63+
pkt, err = ic.Intercept(ctx, pkt, dir)
64+
if err != nil {
65+
errs <- wrap(ctx, err, dir)
66+
return
67+
}
68+
}
6269

6370
// Send to another.
6471
if err := pkt.Write(w); err != nil {
6572
errs <- wrap(ctx, err, dir)
6673
return
6774
}
6875

69-
if dir == up {
76+
// Notify only for packets sent from client to broker (incoming packets).
77+
if dir == Up {
7078
if err := notify(ctx, pkt, h); err != nil {
7179
errs <- wrap(ctx, err, dir)
7280
}
@@ -118,7 +126,7 @@ func notify(ctx context.Context, pkt packets.ControlPacket, h Handler) error {
118126
}
119127
}
120128

121-
func wrap(ctx context.Context, err error, dir direction) error {
129+
func wrap(ctx context.Context, err error, dir Direction) error {
122130
if err == io.EOF {
123131
return err
124132
}
@@ -127,9 +135,9 @@ func wrap(ctx context.Context, err error, dir direction) error {
127135
cid = s.ID
128136
}
129137
switch dir {
130-
case up:
138+
case Up:
131139
return fmt.Errorf(errClient, cid, err.Error())
132-
case down:
140+
case Down:
133141
return fmt.Errorf(errBroker, cid, err.Error())
134142
default:
135143
return err

0 commit comments

Comments
 (0)