Skip to content

Commit d8652d4

Browse files
committed
Add options to set custom replay detector
Support using custom implementation of replay detector.
1 parent def59cc commit d8652d4

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

option.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ func SRTCPNoReplayProtection() ContextOption {
5050
}
5151
}
5252

53+
// SRTPReplayDetectorFactory sets custom SRTP replay detector.
54+
func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { // nolint:revive
55+
return func(c *Context) error {
56+
c.newSRTPReplayDetector = fn
57+
return nil
58+
}
59+
}
60+
61+
// SRTCPReplayDetectorFactory sets custom SRTCP replay detector.
62+
func SRTCPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption {
63+
return func(c *Context) error {
64+
c.newSRTCPReplayDetector = fn
65+
return nil
66+
}
67+
}
68+
5369
type nopReplayDetector struct{}
5470

5571
func (s *nopReplayDetector) Check(uint64) (func(), bool) {

srtcp_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"testing"
1111

1212
"github.com/pion/rtcp"
13+
"github.com/pion/transport/v2/replaydetector"
1314
"github.com/stretchr/testify/assert"
1415
)
1516

@@ -570,3 +571,26 @@ func TestRTCPMaxPackets(t *testing.T) {
570571
})
571572
}
572573
}
574+
575+
func TestRTCPReplayDetectorFactory(t *testing.T) {
576+
assert := assert.New(t)
577+
testCase := rtcpTestCases()["AEAD_AES_128_GCM"]
578+
data := testCase.packets[0]
579+
580+
var cntFactory int
581+
decryptContext, err := CreateContext(
582+
testCase.masterKey, testCase.masterSalt, testCase.algo,
583+
SRTCPReplayDetectorFactory(func() replaydetector.ReplayDetector {
584+
cntFactory++
585+
return &nopReplayDetector{}
586+
}),
587+
)
588+
if err != nil {
589+
t.Fatal(err)
590+
}
591+
592+
if _, err := decryptContext.DecryptRTCP(nil, data.encrypted, nil); err != nil {
593+
t.Fatal(err)
594+
}
595+
assert.Equal(1, cntFactory)
596+
}

srtp_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/pion/rtp"
12+
"github.com/pion/transport/v2/replaydetector"
1213
"github.com/stretchr/testify/assert"
1314
)
1415

@@ -462,6 +463,37 @@ func TestRTPReplayProtection(t *testing.T) {
462463
t.Run("GCM", func(t *testing.T) { testRTPReplayProtection(t, profileGCM) })
463464
}
464465

466+
func TestRTPReplayDetectorFactory(t *testing.T) {
467+
assert := assert.New(t)
468+
profile := profileCTR
469+
data := rtpTestCases()[0]
470+
471+
var cntFactory int
472+
decryptContext, err := buildTestContext(
473+
profile, SRTPReplayDetectorFactory(func() replaydetector.ReplayDetector {
474+
cntFactory++
475+
return &nopReplayDetector{}
476+
}),
477+
)
478+
if err != nil {
479+
t.Fatal(err)
480+
}
481+
482+
pkt := &rtp.Packet{
483+
Payload: data.encrypted(profile),
484+
Header: rtp.Header{SequenceNumber: data.sequenceNumber},
485+
}
486+
in, err := pkt.Marshal()
487+
if err != nil {
488+
t.Fatal(err)
489+
}
490+
491+
if _, err := decryptContext.DecryptRTP(nil, in, nil); err != nil {
492+
t.Fatal(err)
493+
}
494+
assert.Equal(1, cntFactory)
495+
}
496+
465497
func benchmarkEncryptRTP(b *testing.B, profile ProtectionProfile, size int) {
466498
encryptContext, err := buildTestContext(profile)
467499
if err != nil {

0 commit comments

Comments
 (0)