Skip to content

Commit

Permalink
Persistent database states now seem to work. Probably still need more…
Browse files Browse the repository at this point in the history
… testing.
  • Loading branch information
jeff-cohere committed Nov 28, 2024
1 parent 055fbad commit af101c3
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 38 deletions.
29 changes: 19 additions & 10 deletions databases/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,20 @@ type Database interface {
Load(state DatabaseSaveState) error
}

// data representing a saved database state -- useful for service restarts
// represents a saved database state (for service restarts)
type DatabaseSaveState struct {
// database name
Name string
// serialized database in bytes
Data []byte
}

// represents a collection of saved database states
type DatabaseSaveStates struct {
// mapping of orcid/database keys to database save states
Data map[string]DatabaseSaveState
}

// parameters that define a search for files
type SearchParameters struct {
// ElasticSearch query string
Expand Down Expand Up @@ -149,27 +155,30 @@ func NewDatabase(orcid, dbName string) (Database, error) {

// saves the internal states of all resident databases, returning a map to
// their save states
func Save() (map[string]DatabaseSaveState, error) {
states := make(map[string]DatabaseSaveState)
func Save() (DatabaseSaveStates, error) {
states := DatabaseSaveStates{
Data: make(map[string]DatabaseSaveState),
}
for key, db := range allDatabases_ {
saveState, err := db.Save()
if err != nil {
return nil, err
return states, err
}
states[key] = saveState
states.Data[key] = saveState
}
return states, nil
}

// loads a previously saved map of save states for all databases, restoring
// their previous states
func Load(states map[string]DatabaseSaveState) error {
for key, state := range states {
orcidIndex := strings.LastIndex(key, "db: ") + 3
if orcidIndex == 2 {
func Load(states DatabaseSaveStates) error {
for key, state := range states.Data {
start := strings.Index(key, "orcid: ") + 8
end := strings.Index(key, "db: ") - 1
if start == 7 || end == -2 {
return fmt.Errorf("Couldn't load saved state for database '%s'", state.Name)
}
orcid := key[orcidIndex:]
orcid := key[start:end]
db, err := NewDatabase(orcid, state.Name)
if err != nil {
return err
Expand Down
48 changes: 34 additions & 14 deletions databases/jdp/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"slices"
"strconv"
"strings"
"time"
"unicode"

"github.com/google/uuid"
Expand All @@ -59,8 +60,14 @@ type Database struct {
// SSO token used for interim JDP access
SsoToken string
// mapping from staging UUIDs to JDP restoration request ID
// FIXME: not persistent between service restarts!!
StagingIds map[uuid.UUID]int
StagingRequests map[uuid.UUID]StagingRequest
}

type StagingRequest struct {
// JDP staging request ID
Id int
// time of staging request (for purging)
Time time.Time
}

func NewDatabase(orcid string) (databases.Database, error) {
Expand All @@ -86,11 +93,11 @@ func NewDatabase(orcid string) (databases.Database, error) {
}

return &Database{
Id: "jdp",
Orcid: orcid,
Secret: secret,
SsoToken: os.Getenv("DTS_JDP_SSO_TOKEN"),
StagingIds: make(map[uuid.UUID]int),
Id: "jdp",
Orcid: orcid,
Secret: secret,
SsoToken: os.Getenv("DTS_JDP_SSO_TOKEN"),
StagingRequests: make(map[uuid.UUID]StagingRequest),
}, nil
}

Expand Down Expand Up @@ -279,7 +286,10 @@ func (db *Database) StageFiles(fileIds []string) (uuid.UUID, error) {
slog.Debug(fmt.Sprintf("Requested %d archived files from JDP (request ID: %d)",
len(fileIds), jdpResp.RequestId))
xferId = uuid.New()
db.StagingIds[xferId] = jdpResp.RequestId
db.StagingRequests[xferId] = StagingRequest{
Id: jdpResp.RequestId,
Time: time.Now(),
}
return xferId, err
case 404:
return xferId, databases.ResourceNotFoundError{
Expand All @@ -292,9 +302,9 @@ func (db *Database) StageFiles(fileIds []string) (uuid.UUID, error) {
}

func (db *Database) StagingStatus(id uuid.UUID) (databases.StagingStatus, error) {
// FIXME: db.StagingIds is not persistent between service restarts!!
if restoreId, found := db.StagingIds[id]; found {
resource := fmt.Sprintf("request_archived_files/requests/%d", restoreId)
db.pruneStagingRequests()
if request, found := db.StagingRequests[id]; found {
resource := fmt.Sprintf("request_archived_files/requests/%d", request.Id)
resp, err := db.get(resource, url.Values{})
if err != nil {
return databases.StagingStatusUnknown, err
Expand Down Expand Up @@ -336,19 +346,19 @@ func (db *Database) LocalUser(orcid string) (string, error) {
func (db Database) Save() (databases.DatabaseSaveState, error) {
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
err := enc.Encode(db.StagingIds)
err := enc.Encode(db.StagingRequests)
if err != nil {
return databases.DatabaseSaveState{}, err
}
return databases.DatabaseSaveState{
Name: "NMDC",
Name: "jdp",
Data: buffer.Bytes(),
}, nil
}

func (db *Database) Load(state databases.DatabaseSaveState) error {
enc := gob.NewDecoder(bytes.NewReader(state.Data))
return enc.Decode(&db.StagingIds)
return enc.Decode(&db.StagingRequests)
}

//--------------------
Expand Down Expand Up @@ -853,3 +863,13 @@ func (db Database) addSpecificSearchParameters(params map[string]json.RawMessage
}
return nil
}

func (db *Database) pruneStagingRequests() {
deleteAfter := time.Duration(config.Service.DeleteAfter) * time.Second
for uuid, request := range db.StagingRequests {
requestAge := time.Since(request.Time)
if requestAge > deleteAfter {
delete(db.StagingRequests, uuid)
}
}
}
12 changes: 12 additions & 0 deletions databases/kbase/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,15 @@ func (db *Database) LocalUser(orcid string) (string, error) {
// auth server proxy
return auth.KBaseLocalUsernameForOrcid(orcid)
}

func (db Database) Save() (databases.DatabaseSaveState, error) {
// so far, this database has no internal state
return databases.DatabaseSaveState{
Name: "kbase",
}, nil
}

func (db *Database) Load(state databases.DatabaseSaveState) error {
// no internal state -> nothing to do
return nil
}
2 changes: 1 addition & 1 deletion databases/nmdc/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (db Database) LocalUser(orcid string) (string, error) {
func (db Database) Save() (databases.DatabaseSaveState, error) {
// so far, this database has no internal state
return databases.DatabaseSaveState{
Name: "NMDC",
Name: "nmdc",
}, nil
}

Expand Down
23 changes: 10 additions & 13 deletions tasks/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,17 @@ func createOrLoadTasks(dataFile string) map[uuid.UUID]transferTask {
defer file.Close()
enc := gob.NewDecoder(file)
var tasks map[uuid.UUID]transferTask
err = enc.Decode(&tasks)
if err == nil {
var databaseStates map[string]DatabaseSaveState
var databaseStates databases.DatabaseSaveStates
if err = enc.Decode(&tasks); err == nil {
err = enc.Decode(&databaseStates)
if err == nil {
err = databases.Load(databaseStates)
}
}
if err != nil { // file not readable
slog.Error(fmt.Sprintf("Reading task file %s: %s", dataFile, err.Error()))
return make(map[uuid.UUID]transferTask)
}
if err = databases.Load(databaseStates); err != nil {
slog.Error(fmt.Sprintf("Restoring database states: %s", err.Error()))
}
slog.Debug(fmt.Sprintf("Restored %d tasks from %s", len(tasks), dataFile))
return tasks
}
Expand All @@ -282,10 +281,11 @@ func saveTasks(tasks map[uuid.UUID]transferTask, dataFile string) error {
return fmt.Errorf("Opening task file %s: %s", dataFile, err.Error())
}
enc := gob.NewEncoder(file)
err = enc.Encode(tasks)
if err == nil {
databaseStates := databases.Save()
err = enc.Encode(databaseStates)
if err = enc.Encode(tasks); err == nil {
var databaseStates databases.DatabaseSaveStates
if databaseStates, err = databases.Save(); err == nil {
err = enc.Encode(databaseStates)
}
}
if err != nil {
file.Close()
Expand Down Expand Up @@ -375,7 +375,6 @@ func processTasks() {
case <-pollChan: // time to move things along
for taskId, task := range tasks {
if !task.Completed() {
slog.Debug(fmt.Sprintf("Task %s is incomplete, proceeding...", taskId.String()))
oldStatus := task.Status
err := task.Update()
if err != nil {
Expand Down Expand Up @@ -404,8 +403,6 @@ func processTasks() {
slog.Info(fmt.Sprintf("Task %s: failed", task.Id.String()))
}
}
} else {
slog.Debug(fmt.Sprintf("Task %s is complete.", taskId.String()))
}

// if the task completed a long enough time go, delete its entry
Expand Down

0 comments on commit af101c3

Please sign in to comment.