From b182a214ce07ccaf1899b38e4b24e4ee5c5703ae Mon Sep 17 00:00:00 2001
From: Gabe Cook <gabe565@gmail.com>
Date: Wed, 1 Nov 2023 01:10:03 -0500
Subject: [PATCH] fix(device): Fix search not triggering segment download

Ref #27
---
 internal/device/query_state.go       | 12 ++++
 internal/device/querystate_string.go | 26 ++++++++
 internal/device/video_meta.go        | 28 ++++++++
 internal/device/watch.go             | 97 ++++++++++++++++------------
 4 files changed, 121 insertions(+), 42 deletions(-)
 create mode 100644 internal/device/query_state.go
 create mode 100644 internal/device/querystate_string.go
 create mode 100644 internal/device/video_meta.go

diff --git a/internal/device/query_state.go b/internal/device/query_state.go
new file mode 100644
index 0000000..7f4f2e5
--- /dev/null
+++ b/internal/device/query_state.go
@@ -0,0 +1,12 @@
+package device
+
+//go:generate stringer -type QueryState -linecomment
+
+type QueryState uint8
+
+const (
+	QueryNone    QueryState = iota // none
+	QueryStarted                   // started
+	QuerySuccess                   // success
+	QueryFailed                    // failed
+)
diff --git a/internal/device/querystate_string.go b/internal/device/querystate_string.go
new file mode 100644
index 0000000..10d38c9
--- /dev/null
+++ b/internal/device/querystate_string.go
@@ -0,0 +1,26 @@
+// Code generated by "stringer -type QueryState -linecomment"; DO NOT EDIT.
+
+package device
+
+import "strconv"
+
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+	var x [1]struct{}
+	_ = x[QueryNone-0]
+	_ = x[QueryStarted-1]
+	_ = x[QuerySuccess-2]
+	_ = x[QueryFailed-3]
+}
+
+const _QueryState_name = "nonestartedsuccessfailed"
+
+var _QueryState_index = [...]uint8{0, 4, 11, 18, 24}
+
+func (i QueryState) String() string {
+	if i >= QueryState(len(_QueryState_index)-1) {
+		return "QueryState(" + strconv.FormatInt(int64(i), 10) + ")"
+	}
+	return _QueryState_name[_QueryState_index[i]:_QueryState_index[i+1]]
+}
diff --git a/internal/device/video_meta.go b/internal/device/video_meta.go
new file mode 100644
index 0000000..7e49a8d
--- /dev/null
+++ b/internal/device/video_meta.go
@@ -0,0 +1,28 @@
+package device
+
+type VideoMeta struct {
+	CurrVideoId string
+	CurrArtist  string
+	CurrTitle   string
+
+	PrevVideoId string
+	PrevArtist  string
+	PrevTitle   string
+}
+
+func (v *VideoMeta) Clear() {
+	v.CurrVideoId = ""
+	v.CurrArtist = ""
+	v.CurrTitle = ""
+	v.PrevVideoId = ""
+	v.PrevArtist = ""
+	v.PrevTitle = ""
+}
+
+func (v VideoMeta) Empty() bool {
+	return v.CurrArtist == "" || v.CurrTitle == ""
+}
+
+func (v VideoMeta) SameVideo() bool {
+	return v.CurrArtist == v.PrevArtist && v.CurrTitle == v.PrevTitle
+}
diff --git a/internal/device/watch.go b/internal/device/watch.go
index dc6e875..f58c6ac 100644
--- a/internal/device/watch.go
+++ b/internal/device/watch.go
@@ -45,9 +45,8 @@ type Device struct {
 	tickInterval time.Duration
 	ticker       *time.Ticker
 
-	prevVideoId    string
-	prevArtist     string
-	prevTitle      string
+	meta           VideoMeta
+	queryState     QueryState
 	mediaSessionId int
 	segments       []sponsorblock.Segment
 	mutedSegmentId int
@@ -180,18 +179,29 @@ func (d *Device) tick() error {
 	case StateAd:
 		d.muteAd(castVol)
 	default:
-		if castMedia.Media.ContentId == "" {
-			d.queryVideoId(castMedia)
+		if castMedia.Media.Metadata.Artist != "" {
+			d.meta.CurrArtist = castMedia.Media.Metadata.Artist
+		} else {
+			d.meta.CurrArtist = castMedia.Media.Metadata.Subtitle
+		}
+		d.meta.CurrTitle = castMedia.Media.Metadata.Title
+
+		if castMedia.Media.ContentId != "" {
+			d.meta.CurrVideoId = castMedia.Media.ContentId
+			d.queryState = QueryNone
+		} else {
+			d.queryVideoId()
+			break
 		}
 
-		if castMedia.Media.ContentId != d.prevVideoId {
-			if castMedia.Media.ContentId != "" {
-				d.logger.Info("Detected video stream.", "video_id", castMedia.Media.ContentId)
+		if d.meta.CurrVideoId != d.meta.PrevVideoId {
+			d.segments = nil
+			if d.meta.CurrVideoId != "" {
+				d.logger.Info("Detected video stream.", "video_id", d.meta.CurrVideoId)
+				d.meta.PrevVideoId = d.meta.CurrVideoId
+				go d.querySegments(castMedia)
 			}
-			d.prevVideoId = castMedia.Media.ContentId
 			d.unmuteSegment()
-			d.segments = nil
-			go d.querySegments(castMedia)
 			break
 		}
 
@@ -279,7 +289,7 @@ func (d *Device) onMessage(msg *api.CastMessage) {
 	case "CLOSE":
 		d.unmuteSegment()
 		d.segments = nil
-		d.prevTitle, d.prevArtist, d.prevVideoId = "", "", ""
+		d.meta.Clear()
 		d.mediaSessionId = 0
 	}
 }
@@ -304,38 +314,41 @@ func (d *Device) update() error {
 	return err
 }
 
-func (d *Device) queryVideoId(castMedia *cast.Media) {
-	var currArtist string
-	if castMedia.Media.Metadata.Artist != "" {
-		currArtist = castMedia.Media.Metadata.Artist
-	} else {
-		currArtist = castMedia.Media.Metadata.Subtitle
+func (d *Device) queryVideoId() {
+	switch d.queryState {
+	case QueryStarted:
+		return
+	case QueryFailed, QuerySuccess:
+		if d.meta.Empty() || d.meta.SameVideo() {
+			return
+		}
 	}
-	currTitle := castMedia.Media.Metadata.Title
 
-	if currArtist == d.prevArtist && currTitle == d.prevTitle {
-		castMedia.Media.ContentId = d.prevVideoId
+	if config.Default.YouTubeAPIKey == "" {
+		d.logger.Warn("Video ID not found. Please set a YouTube API key.")
 	} else {
-		if config.Default.YouTubeAPIKey == "" {
-			d.logger.Warn("Video ID not found. Please set a YouTube API key.")
-		} else {
-			d.logger.Info("Video ID not found. Searching for video on YouTube...")
-			d.prevArtist = currArtist
-			d.prevTitle = currTitle
-			go func() {
-				if err := util.Retry(d.ctx, 3, time.Second, func(try uint) (err error) {
-					castMedia.Media.ContentId, err = youtube.QueryVideoId(d.ctx, currArtist, currTitle)
-					if err != nil {
-						d.logger.Error("YouTube search failed.", "error", err.Error())
-					}
+		d.logger.Info("Video ID not found. Searching for video on YouTube...")
+		d.queryState = QueryStarted
+		d.meta.PrevArtist = d.meta.CurrArtist
+		d.meta.PrevTitle = d.meta.CurrTitle
+		go func() {
+			if err := util.Retry(d.ctx, 3, time.Second, func(try uint) (err error) {
+				contentId, err := youtube.QueryVideoId(d.ctx, d.meta.CurrArtist, d.meta.CurrTitle)
+				if err != nil {
+					d.logger.Error("YouTube search failed.", "error", err.Error())
+					d.queryState = QueryFailed
 					return err
-				}); err != nil {
-					d.logger.Debug("Halting YouTube search retries.")
-					return
 				}
-				d.logger.Debug("YouTube search found video ID", "video_id", castMedia.Media.ContentId)
-			}()
-		}
+
+				d.meta.CurrVideoId = contentId
+				d.queryState = QuerySuccess
+				return nil
+			}); err != nil {
+				d.logger.Debug("Halting YouTube search retries.")
+				return
+			}
+			d.logger.Debug("YouTube search found video ID", "video_id", d.meta.CurrVideoId)
+		}()
 	}
 }
 
@@ -399,16 +412,16 @@ func (d *Device) unmuteSegment() {
 }
 
 func (d *Device) querySegments(castMedia *cast.Media) {
-	if castMedia.Media.ContentId == "" {
+	if d.meta.CurrVideoId == "" {
 		return
 	}
 
 	if err := util.Retry(d.ctx, 10, 500*time.Millisecond, func(try uint) (err error) {
-		d.segments, err = sponsorblock.QuerySegments(d.ctx, castMedia.Media.ContentId)
+		d.segments, err = sponsorblock.QuerySegments(d.ctx, d.meta.CurrVideoId)
 		return err
 	}); err == nil {
 		if len(d.segments) == 0 {
-			d.logger.Info("No segments found for video.", "video_id", castMedia.Media.ContentId)
+			d.logger.Info("No segments found for video.", "video_id", d.meta.CurrVideoId)
 		} else {
 			d.logger.Info("Found segments for video.", "segments", len(d.segments))
 		}