Skip to content

Commit

Permalink
Simulcast: Allow layer selection for publisher (#477)
Browse files Browse the repository at this point in the history
* Simulcast: Allow layer selection for publisher

Changed ion-sfu channel messaging to be peer based and added aditional types

Addressing PR comments

Rename IonSfuMessage to ChannelAPIMessage. Use interface{} instead of string, fixes double marshalling issue

* Rebase on new interface changes
  • Loading branch information
dreamerns committed May 12, 2021
1 parent 7add00c commit 6386736
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 36 deletions.
161 changes: 127 additions & 34 deletions pkg/middlewares/datachannel/subscriberapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,92 @@ package datachannel
import (
"context"
"encoding/json"
"fmt"

"github.com/pion/ion-sfu/pkg/sfu"
"github.com/pion/webrtc/v3"
)

const (
highValue = "high"
mediumValue = "medium"
lowValue = "low"
mutedValue = "none"
highValue = "high"
mediumValue = "medium"
lowValue = "low"
mutedValue = "none"
ActiveLayerMethod = "activeLayer"
)

type setRemoteMedia struct {
StreamID string `json:"streamId"`
Video string `json:"video"`
Framerate string `json:"framerate"`
Audio bool `json:"audio"`
StreamID string `json:"streamId"`
Video string `json:"video"`
Framerate string `json:"framerate"`
Audio bool `json:"audio"`
Layers []string `json:"layers"`
}

type activeLayerMessage struct {
StreamID string `json:"streamId"`
ActiveLayer string `json:"activeLayer"`
AvailableLayers []string `json:"availableLayers"`
}

func layerStrToInt(layer string) (int, error) {
switch layer {
case highValue:
return 2, nil
case mediumValue:
return 1, nil
case lowValue:
return 0, nil
default:
// unknown value
return -1, fmt.Errorf("Unknown value")
}
}

func layerIntToStr(layer int) (string, error) {
switch layer {
case 0:
return lowValue, nil
case 1:
return mediumValue, nil
case 2:
return highValue, nil
default:
return "", fmt.Errorf("Unknown value: %d", layer)
}
}

func transformLayers(layers []string) ([]uint16, error) {
res := make([]uint16, len(layers))
for _, layer := range layers {
if l, err := layerStrToInt(layer); err == nil {
res = append(res, uint16(l))
} else {
return nil, fmt.Errorf("Unknown layer value: %v", layer)
}
}
return res, nil
}

func sendMessage(streamID string, peer sfu.Peer, layers []string, activeLayer int) {
al, _ := layerIntToStr(activeLayer)
payload := activeLayerMessage{
StreamID: streamID,
ActiveLayer: al,
AvailableLayers: layers,
}
msg := sfu.ChannelAPIMessage{
Method: ActiveLayerMethod,
Params: payload,
}
bytes, err := json.Marshal(msg)
if err != nil {
sfu.Logger.Error(err, "unable to marshal active layer message")
}

if err := peer.SendAPIChannelMessage(&bytes); err != nil {
sfu.Logger.Error(err, "unable to send ActiveLayerMessage to peer", "peer_id", peer.ID())
}
}

func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor {
Expand All @@ -28,35 +97,59 @@ func SubscriberAPI(next sfu.MessageProcessor) sfu.MessageProcessor {
if err := json.Unmarshal(args.Message.Data, srm); err != nil {
return
}
downTracks := args.Peer.Subscriber().GetDownTracks(srm.StreamID)
for _, dt := range downTracks {
switch dt.Kind() {
case webrtc.RTPCodecTypeAudio:
dt.Mute(!srm.Audio)
case webrtc.RTPCodecTypeVideo:
switch srm.Video {
case highValue:
dt.Mute(false)
dt.SwitchSpatialLayer(2, true)
case mediumValue:
dt.Mute(false)
dt.SwitchSpatialLayer(1, true)
case lowValue:
dt.Mute(false)
dt.SwitchSpatialLayer(0, true)
case mutedValue:
dt.Mute(true)
}
switch srm.Framerate {
case highValue:
dt.SwitchTemporalLayer(2, true)
case mediumValue:
dt.SwitchTemporalLayer(1, true)
case lowValue:
dt.SwitchTemporalLayer(0, true)
// Publisher changing active layers
if srm.Layers != nil && len(srm.Layers) > 0 {
layers, err := transformLayers(srm.Layers)
if err != nil {
sfu.Logger.Error(err, "error reading layers")
next.Process(ctx, args)
return
}

session := args.Peer.Session()
peers := session.Peers()
for _, peer := range peers {
if peer.ID() != args.Peer.ID() {
downTracks := peer.Subscriber().GetDownTracks(srm.StreamID)
for _, dt := range downTracks {
if dt.Kind() == webrtc.RTPCodecTypeVideo {
newLayer, _ := dt.UptrackLayersChange(layers)
sendMessage(srm.StreamID, peer, srm.Layers, int(newLayer))
}
}
}
}
} else {
downTracks := args.Peer.Subscriber().GetDownTracks(srm.StreamID)
for _, dt := range downTracks {
switch dt.Kind() {
case webrtc.RTPCodecTypeAudio:
dt.Mute(!srm.Audio)
case webrtc.RTPCodecTypeVideo:
switch srm.Video {
case highValue:
dt.Mute(false)
dt.SwitchSpatialLayer(2, true)
case mediumValue:
dt.Mute(false)
dt.SwitchSpatialLayer(1, true)
case lowValue:
dt.Mute(false)
dt.SwitchSpatialLayer(0, true)
case mutedValue:
dt.Mute(true)
}
switch srm.Framerate {
case highValue:
dt.SwitchTemporalLayer(2, true)
case mediumValue:
dt.SwitchTemporalLayer(1, true)
case lowValue:
dt.SwitchTemporalLayer(0, true)
}
}

}
}
next.Process(ctx, args)
})
Expand Down
35 changes: 35 additions & 0 deletions pkg/sfu/downtrack.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sfu

import (
"fmt"
"strings"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -211,6 +212,40 @@ func (d *DownTrack) SwitchSpatialLayer(targetLayer int64, setAsMax bool) {
}
}

func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) (int64, error) {
if d.trackType == SimulcastDownTrack {
currentLayer := uint16(atomic.LoadInt32(&d.spatialLayer))
maxLayer := uint16(atomic.LoadInt64(&d.maxSpatialLayer))

var maxFound uint16 = 0
layerFound := false
var minFound uint16 = 0
for _, target := range availableLayers {
if target <= maxLayer {
if target > maxFound {
maxFound = target
layerFound = true
}
} else {
if minFound > target {
minFound = target
}
}
}
var targetLayer uint16
if layerFound {
targetLayer = maxFound
} else {
targetLayer = minFound
}
if currentLayer != targetLayer {
d.SwitchSpatialLayer(int64(targetLayer), false)
}
return int64(targetLayer), nil
}
return -1, fmt.Errorf("Downtrack %s does not support simulcast", d.id)
}

func (d *DownTrack) SwitchTemporalLayer(targetLayer int64, setAsMax bool) {
if d.trackType == SimulcastDownTrack {
layer := atomic.LoadInt32(&d.temporalLayer)
Expand Down
22 changes: 22 additions & 0 deletions pkg/sfu/peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Peer interface {
Publisher() *Publisher
Subscriber() *Subscriber
Close() error
SendAPIChannelMessage(msg *[]byte) error
}

// JoinConfig allow adding more control to the peers joining a SessionLocal.
Expand All @@ -48,6 +49,11 @@ type SessionProvider interface {
GetSession(sid string) (Session, WebRTCTransportConfig)
}

type ChannelAPIMessage struct {
Method string `json:"method"`
Params interface{} `json:"params,omitempty"`
}

// PeerLocal represents a pair peer connection
type PeerLocal struct {
sync.Mutex
Expand Down Expand Up @@ -241,6 +247,22 @@ func (p *PeerLocal) Trickle(candidate webrtc.ICECandidateInit, target int) error
return nil
}

func (p *PeerLocal) SendAPIChannelMessage(msg *[]byte) error {
if p.subscriber == nil {
return fmt.Errorf("No subscriber for this peer")
}
dc := p.subscriber.DataChannel(APIChannelLabel)

if dc == nil {
return fmt.Errorf("Data channel %s doesn't exist", APIChannelLabel)
}

if err := dc.SendText(string(*msg)); err != nil {
return fmt.Errorf("Failed to send message: %v", err)
}
return nil
}

// Close shuts down the peer connection and sends true to the done channel
func (p *PeerLocal) Close() error {
p.Lock()
Expand Down
13 changes: 11 additions & 2 deletions pkg/sfu/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Session interface {
GetDCMiddlewares() []*Datachannel
GetDataChannelLabels() []string
GetDataChannels(origin, label string) (dcs []*webrtc.DataChannel)
Peers() []Peer
}

type SessionLocal struct {
Expand All @@ -35,6 +36,10 @@ type SessionLocal struct {
onCloseHandler func()
}

const (
AudioLevelsMethod = "audioLevels"
)

// NewSession creates a new SessionLocal
func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) Session {
s := &SessionLocal{
Expand All @@ -45,7 +50,6 @@ func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) Sessio
}
go s.audioLevelObserver(cfg.Router.AudioLevelInterval)
return s

}

// ID return SessionLocal id
Expand Down Expand Up @@ -251,7 +255,12 @@ func (s *SessionLocal) audioLevelObserver(audioLevelInterval int) {
continue
}

l, err := json.Marshal(&levels)
msg := ChannelAPIMessage{
Method: AudioLevelsMethod,
Params: levels,
}

l, err := json.Marshal(&msg)
if err != nil {
Logger.Error(err, "Marshaling audio levels err")
continue
Expand Down

0 comments on commit 6386736

Please sign in to comment.