Skip to content

Commit

Permalink
Improved logging and error checking for prototype.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeff-cohere committed Feb 1, 2024
1 parent 37f5a14 commit dd4cc8f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
30 changes: 16 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import (
// gives it an endpoint prefix of "docs". To enable these endpoints, you must
// use the "docs" build: go build -tags docs

// Prints usage info.
// prints usage info
func usage() {
fmt.Fprintf(os.Stderr, "%s: usage:\n", os.Args[0])
fmt.Fprintf(os.Stderr, "%s <config_file>\n", os.Args[0])
Expand Down Expand Up @@ -97,31 +97,33 @@ func main() {
log.Panicf("Couldn't create the service: %s\n", err.Error())
}

// Start the service in a goroutine so it doesn't block.
go func() {
err = service.Start(config.Service.Port)
if err != nil {
log.Println(err.Error())
}
}()

// Intercept the SIGINT, SIGHUP, SIGTERM, and SIGQUIT signals, shutting down
// the service as gracefully as possible if they are encountered.
// intercept the SIGINT, SIGHUP, SIGTERM, and SIGQUIT signals so we can shut
// down the service gracefully if they are encountered
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan,
syscall.SIGINT,
syscall.SIGHUP,
syscall.SIGTERM,
syscall.SIGQUIT)

// Block till we receive one of the above signals.
// start the service in a goroutine so it doesn't block
go func() {
err = service.Start(config.Service.Port)
if err != nil { // on error, log the error message and issue a SIGINT
log.Println(err.Error())
thisProcess, _ := os.FindProcess(os.Getpid())
thisProcess.Signal(os.Interrupt)
}
}()

// block till we receive one of the above signals
<-sigChan

// Create a deadline to wait for.
// create a deadline to wait for
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Wait for connections to close until the deadline elapses.
// wait for connections to close until the deadline elapses
service.Shutdown(ctx)
log.Println("Shutting down")
os.Exit(0)
Expand Down
48 changes: 28 additions & 20 deletions services/prototype.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"log/slog"
"net"
"net/http"
"slices"
Expand Down Expand Up @@ -98,12 +98,12 @@ func (service *prototype) getRoot(w http.ResponseWriter,

_, _, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}

log.Printf("Querying root endpoint...")
slog.Info("Querying root endpoint...")
data := RootResponse{
Name: service.Name,
Version: service.Version,
Expand All @@ -129,12 +129,12 @@ func (service *prototype) getDatabases(w http.ResponseWriter,

_, _, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}

log.Printf("Querying organizational databases...")
slog.Info("Querying organizational databases...")
dbs := make([]dbMetadata, 0)
for dbName, db := range config.Databases {
dbs = append(dbs, dbMetadata{
Expand All @@ -156,19 +156,19 @@ func (service *prototype) getDatabase(w http.ResponseWriter,

_, _, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}

vars := mux.Vars(r)
dbName := vars["db"]

log.Printf("Querying database %s...", dbName)
slog.Info(fmt.Sprintf("Querying database %s...", dbName))
db, ok := config.Databases[dbName]
if !ok {
errStr := fmt.Sprintf("Database %s not found", dbName)
log.Print(errStr)
slog.Error(errStr)
writeError(w, errStr, http.StatusNotFound)
} else {
data, _ := json.Marshal(dbMetadata{
Expand Down Expand Up @@ -231,7 +231,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter,

_, orcid, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}
Expand All @@ -243,7 +243,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter,
_, ok := config.Databases[dbName]
if !ok {
errStr := fmt.Sprintf("Database %s not found", dbName)
log.Print(errStr)
slog.Error(errStr)
writeError(w, errStr, http.StatusNotFound)
return
}
Expand All @@ -255,7 +255,7 @@ func (service *prototype) searchDatabase(w http.ResponseWriter,
return
}

log.Printf("Searching database %s for files...", dbName)
slog.Info(fmt.Sprintf("Searching database %s for files...", dbName))
db, err := databases.NewDatabase(orcid, dbName)
if err != nil {
writeError(w, err.Error(), http.StatusNotFound)
Expand Down Expand Up @@ -334,6 +334,7 @@ func (service *prototype) createTransfer(w http.ResponseWriter,
}
return
}
slog.Info(fmt.Sprintf("Transfer requested: %s", taskId.String()))
jsonData, _ := json.Marshal(TransferResponse{Id: taskId})
writeJson(w, jsonData, http.StatusCreated)
}
Expand Down Expand Up @@ -363,7 +364,7 @@ func (service *prototype) getTransferStatus(w http.ResponseWriter,

_, _, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down Expand Up @@ -402,7 +403,7 @@ func (service *prototype) deleteTransfer(w http.ResponseWriter,

_, _, err := getAuthInfo(r.Header)
if err != nil {
log.Print(err.Error())
slog.Error(err.Error())
writeError(w, err.Error(), http.StatusUnauthorized)
return
}
Expand Down Expand Up @@ -482,8 +483,8 @@ func NewDTSPrototype() (TransferService, error) {

// starts the prototype data transfer service
func (service *prototype) Start(port int) error {
log.Printf("Starting %s service on port %d...", service.Name, port)
log.Printf("(Accepting up to %d connections)", config.Service.MaxConnections)
slog.Info(fmt.Sprintf("Starting %s service on port %d...", service.Name, port))
slog.Info(fmt.Sprintf("(Accepting up to %d connections)", config.Service.MaxConnections))

service.StartTime = time.Now()

Expand All @@ -497,7 +498,10 @@ func (service *prototype) Start(port int) error {
listener = netutil.LimitListener(listener, config.Service.MaxConnections)

// start tasks processing
tasks.Start()
err = tasks.Start()
if err != nil {
return err
}

// start the server
service.Server = &http.Server{
Expand All @@ -507,19 +511,23 @@ func (service *prototype) Start(port int) error {
// we don't report the server closing as an error
if err != http.ErrServerClosed {
return err
} else {
return nil
}
return nil
}

// gracefully shuts down the service without interrupting active connections
func (service *prototype) Shutdown(ctx context.Context) error {
tasks.Stop()
return service.Server.Shutdown(ctx)
if service.Server != nil {
return service.Server.Shutdown(ctx)
}
return nil
}

// closes down the service abruptly, freeing all resources
func (service *prototype) Close() {
tasks.Stop()
service.Server.Close()
if service.Server != nil {
service.Server.Close()
}
}

0 comments on commit dd4cc8f

Please sign in to comment.