Skip to content

Commit

Permalink
fix(device): Use configured network interface
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe565 committed Sep 11, 2023
1 parent 4bb37e2 commit 62e3214
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 26 deletions.
17 changes: 13 additions & 4 deletions cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"math/rand"
"net"
"os"
"strconv"
"testing"
Expand All @@ -19,6 +20,12 @@ func randDuration() time.Duration {
return time.Duration(randSecs) * time.Second
}

func getNetworkInterfaceName(t *testing.T) string {
interfaces, err := net.Interfaces()
assert.NoError(t, err)
return interfaces[0].Name
}

func TestFlags(t *testing.T) {
defer func() {
config.Reset()
Expand All @@ -27,6 +34,7 @@ func TestFlags(t *testing.T) {
discoverInterval := randDuration()
pausedInterval := randDuration()
playingInterval := randDuration()
networkInterface := getNetworkInterfaceName(t)

var cmd *cobra.Command
if !assert.NotPanics(t, func() {
Expand All @@ -36,7 +44,7 @@ func TestFlags(t *testing.T) {
}
cmd.SetArgs([]string{
"--log-level=debug",
"--network-interface=eno1",
"--network-interface=" + networkInterface,
"--discover-interval=" + discoverInterval.String(),
"--paused-interval=" + pausedInterval.String(),
"--playing-interval=" + playingInterval.String(),
Expand All @@ -52,7 +60,7 @@ func TestFlags(t *testing.T) {
}

assert.Equal(t, "debug", config.Default.LogLevel)
assert.Equal(t, "eno1", config.Default.NetworkInterface)
assert.Equal(t, networkInterface, config.Default.NetworkInterfaceName)
assert.Equal(t, discoverInterval, config.Default.DiscoverInterval)
assert.Equal(t, pausedInterval, config.Default.PausedInterval)
assert.Equal(t, playingInterval, config.Default.PlayingInterval)
Expand All @@ -70,6 +78,7 @@ func TestEnvs(t *testing.T) {
discoverInterval := randDuration()
pausedInterval := randDuration()
playingInterval := randDuration()
networkInterface := getNetworkInterfaceName(t)

defer func() {
_ = os.Unsetenv("CSS_LOG_LEVEL")
Expand All @@ -82,7 +91,7 @@ func TestEnvs(t *testing.T) {
_ = os.Unsetenv("CSS_MUTE_ADS")
}()
_ = os.Setenv("CSS_LOG_LEVEL", "warn")
_ = os.Setenv("CSS_NETWORK_INTERFACE", "eno1")
_ = os.Setenv("CSS_NETWORK_INTERFACE", networkInterface)
_ = os.Setenv("CSS_DISCOVER_INTERVAL", discoverInterval.String())
_ = os.Setenv("CSS_PAUSED_INTERVAL", pausedInterval.String())
_ = os.Setenv("CSS_PLAYING_INTERVAL", playingInterval.String())
Expand All @@ -104,7 +113,7 @@ func TestEnvs(t *testing.T) {
}

assert.Equal(t, "warn", config.Default.LogLevel)
assert.Equal(t, "eno1", config.Default.NetworkInterface)
assert.Equal(t, networkInterface, config.Default.NetworkInterfaceName)
assert.Equal(t, discoverInterval, config.Default.DiscoverInterval)
assert.Equal(t, pausedInterval, config.Default.PausedInterval)
assert.Equal(t, playingInterval, config.Default.PlayingInterval)
Expand Down
19 changes: 16 additions & 3 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"fmt"
"net"
"strings"
"time"

Expand All @@ -24,7 +25,7 @@ func Reset() {
PlayingInterval: 500 * time.Millisecond,
SkipDelay: 0,

NetworkInterface: "",
NetworkInterface: nil,

Categories: []string{"sponsor"},
ActionTypes: []string{"skip", "mute"},
Expand All @@ -44,7 +45,8 @@ type Config struct {
PlayingInterval time.Duration `mapstructure:"playing-interval"`
SkipDelay time.Duration `mapstructure:"skip-delay"`

NetworkInterface string `mapstructure:"network-interface"`
NetworkInterfaceName string `mapstructure:"network-interface"`
NetworkInterface *net.Interface

Categories []string
ActionTypes []string `mapstructure:"action-types"`
Expand Down Expand Up @@ -87,5 +89,16 @@ func (c *Config) Load() error {
}
}

return c.viper.Unmarshal(c)
if err := c.viper.Unmarshal(c); err != nil {
return err
}

if c.NetworkInterfaceName != "" {
var err error
if c.NetworkInterface, err = net.InterfaceByName(c.NetworkInterfaceName); err != nil {
return err
}
}

return nil
}
2 changes: 1 addition & 1 deletion internal/config/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

func (c *Config) RegisterNetworkInterface(cmd *cobra.Command) {
key := "network-interface"
cmd.PersistentFlags().StringP(key, "i", Default.NetworkInterface, "Network interface to use for multicast dns discovery. (default all interfaces)")
cmd.PersistentFlags().StringP(key, "i", Default.NetworkInterfaceName, "Network interface to use for multicast dns discovery. (default all interfaces)")
if err := c.viper.BindPFlag(key, cmd.PersistentFlags().Lookup(key)); err != nil {
panic(err)
}
Expand Down
21 changes: 3 additions & 18 deletions internal/device/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,10 @@ import (
var ErrDeviceNotFound = errors.New("device not found")

func DiscoverCastDNSEntryByUuid(ctx context.Context, uuid string) (castdns.CastEntry, error) {
var iface *net.Interface
if config.Default.NetworkInterface != "" {
var err error
iface, err = net.InterfaceByName(config.Default.NetworkInterface)
if err != nil {
return castdns.CastEntry{}, err
}
}

ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()

entries, err := castdns.DiscoverCastDNSEntries(ctx, iface)
entries, err := castdns.DiscoverCastDNSEntries(ctx, config.Default.NetworkInterface)
if err != nil {
return castdns.CastEntry{}, err
}
Expand Down Expand Up @@ -64,13 +55,7 @@ func DiscoverCastDNSEntries(ctx context.Context, iface *net.Interface, ch chan c
}

func BeginDiscover(ctx context.Context) (<-chan castdns.CastEntry, error) {
var iface *net.Interface
if config.Default.NetworkInterface != "" {
var err error
iface, err = net.InterfaceByName(config.Default.NetworkInterface)
if err != nil {
return nil, err
}
if config.Default.NetworkInterface != nil {
slog.Info("Searching for devices...", "interface", config.Default.NetworkInterface)
} else {
slog.Info("Searching for devices...")
Expand All @@ -87,7 +72,7 @@ func BeginDiscover(ctx context.Context) (<-chan castdns.CastEntry, error) {
return
}

if err := DiscoverCastDNSEntries(ctx, iface, ch); err != nil {
if err := DiscoverCastDNSEntries(ctx, config.Default.NetworkInterface, ch); err != nil {
slog.Error("Failed to discover devices.", "error", err.Error())
continue
}
Expand Down
1 change: 1 addition & 0 deletions internal/device/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (d *Device) connect(opts ...application.ApplicationOption) error {
opts,
application.WithSkipadSleep(config.Default.PlayingInterval),
application.WithSkipadRetries(int(time.Minute/config.Default.PlayingInterval)),
application.WithIface(config.Default.NetworkInterface),
)
d.app = application.NewApplication(opts...)
d.app.AddMessageFunc(d.onMessage)
Expand Down

0 comments on commit 62e3214

Please sign in to comment.