From dd4cc8f4d450a8d5b0089c6457bfe1e24837ff4d Mon Sep 17 00:00:00 2001 From: "Jeffrey N. Johnson" Date: Thu, 1 Feb 2024 14:01:34 -0800 Subject: [PATCH] Improved logging and error checking for prototype. --- main.go | 30 ++++++++++++++------------- services/prototype.go | 48 +++++++++++++++++++++++++------------------ 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/main.go b/main.go index f33f4429..11fa76f7 100644 --- a/main.go +++ b/main.go @@ -45,7 +45,7 @@ import ( // gives it an endpoint prefix of "docs". To enable these endpoints, you must // use the "docs" build: go build -tags docs -// Prints usage info. +// prints usage info func usage() { fmt.Fprintf(os.Stderr, "%s: usage:\n", os.Args[0]) fmt.Fprintf(os.Stderr, "%s \n", os.Args[0]) @@ -97,16 +97,8 @@ func main() { log.Panicf("Couldn't create the service: %s\n", err.Error()) } - // Start the service in a goroutine so it doesn't block. - go func() { - err = service.Start(config.Service.Port) - if err != nil { - log.Println(err.Error()) - } - }() - - // Intercept the SIGINT, SIGHUP, SIGTERM, and SIGQUIT signals, shutting down - // the service as gracefully as possible if they are encountered. + // intercept the SIGINT, SIGHUP, SIGTERM, and SIGQUIT signals so we can shut + // down the service gracefully if they are encountered sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, @@ -114,14 +106,24 @@ func main() { syscall.SIGTERM, syscall.SIGQUIT) - // Block till we receive one of the above signals. + // start the service in a goroutine so it doesn't block + go func() { + err = service.Start(config.Service.Port) + if err != nil { // on error, log the error message and issue a SIGINT + log.Println(err.Error()) + thisProcess, _ := os.FindProcess(os.Getpid()) + thisProcess.Signal(os.Interrupt) + } + }() + + // block till we receive one of the above signals <-sigChan - // Create a deadline to wait for. + // create a deadline to wait for ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - // Wait for connections to close until the deadline elapses. + // wait for connections to close until the deadline elapses service.Shutdown(ctx) log.Println("Shutting down") os.Exit(0) diff --git a/services/prototype.go b/services/prototype.go index 5654c4d4..0d77d599 100644 --- a/services/prototype.go +++ b/services/prototype.go @@ -7,7 +7,7 @@ import ( "encoding/json" "fmt" "io" - "log" + "log/slog" "net" "net/http" "slices" @@ -98,12 +98,12 @@ func (service *prototype) getRoot(w http.ResponseWriter, _, _, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } - log.Printf("Querying root endpoint...") + slog.Info("Querying root endpoint...") data := RootResponse{ Name: service.Name, Version: service.Version, @@ -129,12 +129,12 @@ func (service *prototype) getDatabases(w http.ResponseWriter, _, _, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } - log.Printf("Querying organizational databases...") + slog.Info("Querying organizational databases...") dbs := make([]dbMetadata, 0) for dbName, db := range config.Databases { dbs = append(dbs, dbMetadata{ @@ -156,7 +156,7 @@ func (service *prototype) getDatabase(w http.ResponseWriter, _, _, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } @@ -164,11 +164,11 @@ func (service *prototype) getDatabase(w http.ResponseWriter, vars := mux.Vars(r) dbName := vars["db"] - log.Printf("Querying database %s...", dbName) + slog.Info(fmt.Sprintf("Querying database %s...", dbName)) db, ok := config.Databases[dbName] if !ok { errStr := fmt.Sprintf("Database %s not found", dbName) - log.Print(errStr) + slog.Error(errStr) writeError(w, errStr, http.StatusNotFound) } else { data, _ := json.Marshal(dbMetadata{ @@ -231,7 +231,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter, _, orcid, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } @@ -243,7 +243,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter, _, ok := config.Databases[dbName] if !ok { errStr := fmt.Sprintf("Database %s not found", dbName) - log.Print(errStr) + slog.Error(errStr) writeError(w, errStr, http.StatusNotFound) return } @@ -255,7 +255,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter, return } - log.Printf("Searching database %s for files...", dbName) + slog.Info(fmt.Sprintf("Searching database %s for files...", dbName)) db, err := databases.NewDatabase(orcid, dbName) if err != nil { writeError(w, err.Error(), http.StatusNotFound) @@ -334,6 +334,7 @@ func (service *prototype) createTransfer(w http.ResponseWriter, } return } + slog.Info(fmt.Sprintf("Transfer requested: %s", taskId.String())) jsonData, _ := json.Marshal(TransferResponse{Id: taskId}) writeJson(w, jsonData, http.StatusCreated) } @@ -363,7 +364,7 @@ func (service *prototype) getTransferStatus(w http.ResponseWriter, _, _, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } @@ -402,7 +403,7 @@ func (service *prototype) deleteTransfer(w http.ResponseWriter, _, _, err := getAuthInfo(r.Header) if err != nil { - log.Print(err.Error()) + slog.Error(err.Error()) writeError(w, err.Error(), http.StatusUnauthorized) return } @@ -482,8 +483,8 @@ func NewDTSPrototype() (TransferService, error) { // starts the prototype data transfer service func (service *prototype) Start(port int) error { - log.Printf("Starting %s service on port %d...", service.Name, port) - log.Printf("(Accepting up to %d connections)", config.Service.MaxConnections) + slog.Info(fmt.Sprintf("Starting %s service on port %d...", service.Name, port)) + slog.Info(fmt.Sprintf("(Accepting up to %d connections)", config.Service.MaxConnections)) service.StartTime = time.Now() @@ -497,7 +498,10 @@ func (service *prototype) Start(port int) error { listener = netutil.LimitListener(listener, config.Service.MaxConnections) // start tasks processing - tasks.Start() + err = tasks.Start() + if err != nil { + return err + } // start the server service.Server = &http.Server{ @@ -507,19 +511,23 @@ func (service *prototype) Start(port int) error { // we don't report the server closing as an error if err != http.ErrServerClosed { return err - } else { - return nil } + return nil } // gracefully shuts down the service without interrupting active connections func (service *prototype) Shutdown(ctx context.Context) error { tasks.Stop() - return service.Server.Shutdown(ctx) + if service.Server != nil { + return service.Server.Shutdown(ctx) + } + return nil } // closes down the service abruptly, freeing all resources func (service *prototype) Close() { tasks.Stop() - service.Server.Close() + if service.Server != nil { + service.Server.Close() + } }