diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 34f4acc..7a15f8e 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "math/rand" + "net" "os" "strconv" "testing" @@ -19,6 +20,12 @@ func randDuration() time.Duration { return time.Duration(randSecs) * time.Second } +func getNetworkInterfaceName(t *testing.T) string { + interfaces, err := net.Interfaces() + assert.NoError(t, err) + return interfaces[0].Name +} + func TestFlags(t *testing.T) { defer func() { config.Reset() @@ -27,6 +34,7 @@ func TestFlags(t *testing.T) { discoverInterval := randDuration() pausedInterval := randDuration() playingInterval := randDuration() + networkInterface := getNetworkInterfaceName(t) var cmd *cobra.Command if !assert.NotPanics(t, func() { @@ -36,7 +44,7 @@ func TestFlags(t *testing.T) { } cmd.SetArgs([]string{ "--log-level=debug", - "--network-interface=eno1", + "--network-interface=" + networkInterface, "--discover-interval=" + discoverInterval.String(), "--paused-interval=" + pausedInterval.String(), "--playing-interval=" + playingInterval.String(), @@ -52,7 +60,7 @@ func TestFlags(t *testing.T) { } assert.Equal(t, "debug", config.Default.LogLevel) - assert.Equal(t, "eno1", config.Default.NetworkInterface) + assert.Equal(t, networkInterface, config.Default.NetworkInterfaceName) assert.Equal(t, discoverInterval, config.Default.DiscoverInterval) assert.Equal(t, pausedInterval, config.Default.PausedInterval) assert.Equal(t, playingInterval, config.Default.PlayingInterval) @@ -70,6 +78,7 @@ func TestEnvs(t *testing.T) { discoverInterval := randDuration() pausedInterval := randDuration() playingInterval := randDuration() + networkInterface := getNetworkInterfaceName(t) defer func() { _ = os.Unsetenv("CSS_LOG_LEVEL") @@ -82,7 +91,7 @@ func TestEnvs(t *testing.T) { _ = os.Unsetenv("CSS_MUTE_ADS") }() _ = os.Setenv("CSS_LOG_LEVEL", "warn") - _ = os.Setenv("CSS_NETWORK_INTERFACE", "eno1") + _ = os.Setenv("CSS_NETWORK_INTERFACE", networkInterface) _ = os.Setenv("CSS_DISCOVER_INTERVAL", discoverInterval.String()) _ = os.Setenv("CSS_PAUSED_INTERVAL", pausedInterval.String()) _ = os.Setenv("CSS_PLAYING_INTERVAL", playingInterval.String()) @@ -104,7 +113,7 @@ func TestEnvs(t *testing.T) { } assert.Equal(t, "warn", config.Default.LogLevel) - assert.Equal(t, "eno1", config.Default.NetworkInterface) + assert.Equal(t, networkInterface, config.Default.NetworkInterfaceName) assert.Equal(t, discoverInterval, config.Default.DiscoverInterval) assert.Equal(t, pausedInterval, config.Default.PausedInterval) assert.Equal(t, playingInterval, config.Default.PlayingInterval) diff --git a/internal/config/config.go b/internal/config/config.go index 2099e12..b2ee47d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "net" "strings" "time" @@ -24,7 +25,7 @@ func Reset() { PlayingInterval: 500 * time.Millisecond, SkipDelay: 0, - NetworkInterface: "", + NetworkInterface: nil, Categories: []string{"sponsor"}, ActionTypes: []string{"skip", "mute"}, @@ -44,7 +45,8 @@ type Config struct { PlayingInterval time.Duration `mapstructure:"playing-interval"` SkipDelay time.Duration `mapstructure:"skip-delay"` - NetworkInterface string `mapstructure:"network-interface"` + NetworkInterfaceName string `mapstructure:"network-interface"` + NetworkInterface *net.Interface Categories []string ActionTypes []string `mapstructure:"action-types"` @@ -87,5 +89,16 @@ func (c *Config) Load() error { } } - return c.viper.Unmarshal(c) + if err := c.viper.Unmarshal(c); err != nil { + return err + } + + if c.NetworkInterfaceName != "" { + var err error + if c.NetworkInterface, err = net.InterfaceByName(c.NetworkInterfaceName); err != nil { + return err + } + } + + return nil } diff --git a/internal/config/network.go b/internal/config/network.go index b70e418..19d3eb1 100644 --- a/internal/config/network.go +++ b/internal/config/network.go @@ -9,7 +9,7 @@ import ( func (c *Config) RegisterNetworkInterface(cmd *cobra.Command) { key := "network-interface" - cmd.PersistentFlags().StringP(key, "i", Default.NetworkInterface, "Network interface to use for multicast dns discovery. (default all interfaces)") + cmd.PersistentFlags().StringP(key, "i", Default.NetworkInterfaceName, "Network interface to use for multicast dns discovery. (default all interfaces)") if err := c.viper.BindPFlag(key, cmd.PersistentFlags().Lookup(key)); err != nil { panic(err) } diff --git a/internal/device/dns.go b/internal/device/dns.go index fdb3f77..8810e4a 100644 --- a/internal/device/dns.go +++ b/internal/device/dns.go @@ -15,19 +15,10 @@ import ( var ErrDeviceNotFound = errors.New("device not found") func DiscoverCastDNSEntryByUuid(ctx context.Context, uuid string) (castdns.CastEntry, error) { - var iface *net.Interface - if config.Default.NetworkInterface != "" { - var err error - iface, err = net.InterfaceByName(config.Default.NetworkInterface) - if err != nil { - return castdns.CastEntry{}, err - } - } - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - entries, err := castdns.DiscoverCastDNSEntries(ctx, iface) + entries, err := castdns.DiscoverCastDNSEntries(ctx, config.Default.NetworkInterface) if err != nil { return castdns.CastEntry{}, err } @@ -64,14 +55,8 @@ func DiscoverCastDNSEntries(ctx context.Context, iface *net.Interface, ch chan c } func BeginDiscover(ctx context.Context) (<-chan castdns.CastEntry, error) { - var iface *net.Interface - if config.Default.NetworkInterface != "" { - var err error - iface, err = net.InterfaceByName(config.Default.NetworkInterface) - if err != nil { - return nil, err - } - slog.Info("Searching for devices...", "interface", config.Default.NetworkInterface) + if config.Default.NetworkInterface != nil { + slog.Info("Searching for devices...", "interface", config.Default.NetworkInterfaceName) } else { slog.Info("Searching for devices...") } @@ -87,7 +72,7 @@ func BeginDiscover(ctx context.Context) (<-chan castdns.CastEntry, error) { return } - if err := DiscoverCastDNSEntries(ctx, iface, ch); err != nil { + if err := DiscoverCastDNSEntries(ctx, config.Default.NetworkInterface, ch); err != nil { slog.Error("Failed to discover devices.", "error", err.Error()) continue } diff --git a/internal/device/watch.go b/internal/device/watch.go index 943c71b..0a010d5 100644 --- a/internal/device/watch.go +++ b/internal/device/watch.go @@ -222,6 +222,7 @@ func (d *Device) connect(opts ...application.ApplicationOption) error { opts, application.WithSkipadSleep(config.Default.PlayingInterval), application.WithSkipadRetries(int(time.Minute/config.Default.PlayingInterval)), + application.WithIface(config.Default.NetworkInterface), ) d.app = application.NewApplication(opts...) d.app.AddMessageFunc(d.onMessage)