diff --git a/cmd/marrano-bot/downloader.go b/cmd/marrano-bot/downloader.go new file mode 100644 index 0000000..8d5e615 --- /dev/null +++ b/cmd/marrano-bot/downloader.go @@ -0,0 +1,79 @@ +package main + +import ( + "context" + "log/slog" + "os" + "path" + "strings" + "time" + + "github.com/moolite/bot/internal/db" + "github.com/moolite/bot/internal/telegram" +) + +const ( + FILE_DOWNLOAD_OK = iota + FILE_DOWNLOAD_ERR +) + +func fileDownloaderWorker(workerId int, filenames <-chan []string) { + for file := range filenames { + fileId := file[0] + filename := file[1] + + if err := telegram.DownloadFileId(Cfg, fileId, filename); err != nil { + slog.Error("File failed to download", "workerId", workerId, "fileId", fileId, "filename", filename, "err", err) + } else { + slog.Info("downloaded file", "workerId", workerId, "filename", filename) + } + } +} + +func SyncFolder(folder string) error { + ctx := context.Background() + dbResults, err := db.SelectAllMedia(ctx) + if err != nil { + return err + } + + slog.Debug("synchronizing files", "files", len(dbResults), "folder", folder) + + fileList, err := os.ReadDir(folder) + if err != nil { + return err + } + + var fileIds [][]string + for _, res := range dbResults { + skip := false + for _, f := range fileList { + skip = strings.Contains(f.Name(), res.Data) + } + + if !skip { + fileExtension := "jpg" + if res.Kind != "photo" { + fileExtension = "m4v" + } + filename := path.Join(folder, res.Data+"."+fileExtension) + fileIds = append(fileIds, []string{res.Data, filename}) + } + } + + slog.Debug("files to download", "files", len(fileIds)) + + limiter := time.Tick(100 * time.Millisecond) + jobs := make(chan []string, len(fileList)) + for w := 1; w <= 5; w++ { + go fileDownloaderWorker(w, jobs) + } + + for _, job := range fileIds { + jobs <- job + <-limiter + } + close(jobs) + + return nil +} diff --git a/cmd/marrano-bot/main.go b/cmd/marrano-bot/main.go index 49d3acf..1590bdc 100644 --- a/cmd/marrano-bot/main.go +++ b/cmd/marrano-bot/main.go @@ -17,13 +17,15 @@ var version string = "0.10.0" var Cfg *config.Config var ( - flagHelp bool - flagDebug bool - flagConfigPath string - flagInit bool - flagDump bool - flagExportDB bool - flagExportDBPath string + flagHelp bool + flagDebug bool + flagConfigPath string + flagInit bool + flagDump bool + flagExportDB bool + flagExportDBPath string + flagSyncMedia bool + flagSyncMediaFolder string ) func setupLogging() { @@ -83,6 +85,7 @@ func main() { pflag.BoolVarP(&flagDump, "dump", "D", false, "dump configuration object") pflag.BoolVarP(&flagExportDB, "export", "E", false, "export database data as csv (defaults to stdout)") pflag.StringVar(&flagExportDBPath, "export-dir", cwd, "folder to write database exported data csv files") + pflag.StringVarP(&flagSyncMediaFolder, "export-media", "M", "", "sync media files to the specified folder.") pflag.Parse() setupLogging() @@ -137,6 +140,16 @@ func main() { return } + if flagSyncMediaFolder != "" { + slog.Info("sync media to folder", "folder", flagSyncMediaFolder) + if err := SyncFolder(flagSyncMediaFolder); err != nil { + slog.Error("error syncronizing media folder", "folder", flagSyncMediaFolder, "err", err) + os.Exit(1) + } + os.Exit(0) + return + } + err = core.Listen(Cfg) if err != nil { slog.Error("server error", "err", err) diff --git a/internal/core/handler_test.go b/internal/core/handler_test.go index e36498b..939998f 100644 --- a/internal/core/handler_test.go +++ b/internal/core/handler_test.go @@ -98,8 +98,10 @@ func TestHandlerRemember(t *testing.T) { "message": { "chat": { "id": `+test.gid+` }, "text": "/ricorda `+test.description+`", - "`+test.kind+`": [{"file_id":"`+test.fileId+`"}, - {"file_id":"wrong"}] + "`+test.kind+`": [ + {"file_id":"`+test.fileId+`"}, + {"file_id":"wrong"} + ] } }`) assert.NilError(t, err) diff --git a/internal/db/media.go b/internal/db/media.go index 4ba6584..530cbee 100644 --- a/internal/db/media.go +++ b/internal/db/media.go @@ -62,8 +62,8 @@ func SelectOneMediaByData(ctx context.Context, m *Media) error { return row.Scan(&m.Data, &m.Description, &m.GID, &m.Kind) } -func SelectAllMedia(ctx context.Context, gid string) ([]Media, error) { - var results []Media +func SelectAllMediaGroup(ctx context.Context, gid string) ([]*Media, error) { + var results []*Media q, err := prepareStmt( `SELECT data,description,gid,kind FROM ` + mediaTable + ` WHERE gid=?`, ) @@ -78,11 +78,39 @@ func SelectAllMedia(ctx context.Context, gid string) ([]Media, error) { defer rows.Close() for rows.Next() { - var m *Media - err = rows.Scan(&m.Data, &m.Description, &m.GID, &m.Kind) - if err != nil { + m := new(Media) + + if err = rows.Scan(&m.Data, &m.Description, &m.GID, &m.Kind); err != nil { + return results, err + } + results = append(results, m) + } + + return results, nil +} + +func SelectAllMedia(ctx context.Context) ([]*Media, error) { + var results []*Media + q, err := prepareStmt( + `SELECT data,description,gid,kind FROM ` + mediaTable, + ) + if err != nil { + return results, err + } + + rows, err := q.Query() + if err != nil { + return results, err + } + defer rows.Close() + + for rows.Next() { + m := new(Media) + + if err = rows.Scan(&m.Data, &m.Description, &m.GID, &m.Kind); err != nil { return results, err } + results = append(results, m) } return results, nil @@ -109,7 +137,7 @@ func SelectRandomMedia(ctx context.Context, m *Media) error { func SearchMedia(ctx context.Context, gid, term string) (*Media, error) { likeTerm := fmt.Sprintf("%%%s%%", term) - var m *Media + m := new(Media) q, err := prepareStmt( `SELECT gid,kind,data,description FROM ` + mediaTable + ` diff --git a/internal/db/migrations/06_statistics.up.sql b/internal/db/migrations/06_statistics.up.sql index e685c9c..5e42d7e 100644 --- a/internal/db/migrations/06_statistics.up.sql +++ b/internal/db/migrations/06_statistics.up.sql @@ -18,8 +18,8 @@ CREATE TABLE IF NOT EXISTS statistics CREATE TRIGGER "statistics_date" AFTER INSERT ON "statistics" BEGIN - UPDATE statistics - SET date = datetime('now') - WHERE rowid = NEW.rowid +UPDATE statistics +SET date = datetime('now') +WHERE rowid = NEW.rowid ; END; diff --git a/internal/telegram/api.go b/internal/telegram/api.go index faf47e7..e12c262 100644 --- a/internal/telegram/api.go +++ b/internal/telegram/api.go @@ -4,28 +4,38 @@ import ( "bytes" "io" "net/http" + "time" ) var ( tgBaseApi string = "https://api.telegram.org/bot" ) -func apiRequest(token string, body []byte) ([]byte, error) { - bodyReader := bytes.NewReader(body) +func apiRequest(token, method string, body []byte) ([]byte, error) { + bodyReader := bytes.NewBuffer(body) req, err := http.NewRequest( "POST", - tgBaseApi+token, + tgBaseApi+token+"/"+method, bodyReader, ) if err != nil { return nil, err } - defer req.Body.Close() + req.Header.Add("Content-Type", "application/json") - body, err = io.ReadAll(req.Body) + client := http.Client{ + Timeout: 60 * time.Second, + } + res, err := client.Do(req) + if err != nil { + return nil, err + } + defer res.Body.Close() + + response, err := io.ReadAll(res.Body) if err != nil { return nil, err } - return body, nil + return response, nil } diff --git a/internal/telegram/media.go b/internal/telegram/media.go new file mode 100644 index 0000000..c497ca1 --- /dev/null +++ b/internal/telegram/media.go @@ -0,0 +1,83 @@ +package telegram + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + + "github.com/moolite/bot/internal/config" + "github.com/valyala/fastjson" +) + +var ( + tgBaseFileApi string = "https://api.telegram.org/file/bot" +) + +func GetLink(cfg *config.Config, id string) (string, error) { + body := map[string]string{ + "file_id": id, + } + + bodyJson, err := json.Marshal(body) + if err != nil { + return "", err + } + + resp, err := apiRequest(cfg.Telegram.Token, "getFile", bodyJson) + if err != nil { + return "", err + } + + slog.Debug("api request result", "json", string(resp)) + + p, err := fastjson.ParseBytes(resp) + if err != nil { + return "", err + } + + if ok := p.GetBool("ok"); ok != true { + return "", fmt.Errorf("error while fetching result: %s", resp) + } + + fileId := string(p.GetStringBytes("result", "file_path")) + if fileId == "" { + return "", fmt.Errorf("file_path not found in json") + } + + return tgBaseFileApi + cfg.Telegram.Token + "/" + fileId, nil +} + +func DownloadFileId(cfg *config.Config, id, filename string) error { + uri, err := GetLink(cfg, id) + if err != nil { + return err + } + + resp, err := http.Get(uri) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode == 404 { + return fmt.Errorf("404 file not found.") + } + + if resp.StatusCode > 201 { + return fmt.Errorf("error downloading file. StatusCode %d", resp.StatusCode) + } + + fd, err := os.Create(filename) + if err != nil { + return err + } + defer fd.Close() + + if _, err := io.Copy(fd, resp.Body); err != nil { + return err + } + return nil +}