Skip to content

Commit

Permalink
Separated "user" and "client" concepts for DTS.
Browse files Browse the repository at this point in the history
This PR breaks our `auth.UserInfo` into `auth.User` and `auth.Client` types in order to distinguish
between the rights/privileges of a DTS client used by one or more users, and each user. This allows
us to easily associate a specific user (via ORCID) with each transfer request, and allows us to retain
client-specific data (like active database proxies) that aren't tied to individual users.

For now, if no user ORCID is specified by a transfer request, the DTS falls back to the client's ORCID.
We'll remove this fallback when existing services have user-specific transfer ORCIDs in place within
transfer requests.

Closes #81
  • Loading branch information
jeff-cohere committed Dec 13, 2024
1 parent e1823b9 commit bfd8b36
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 97 deletions.
56 changes: 29 additions & 27 deletions auth/kbase_auth_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,18 @@ func NewKBaseAuthServer(accessToken string) (*KBaseAuthServer, error) {
AccessToken: accessToken,
}

// verify that the access token works (i.e. that the user is logged in)
userInfo, err := server.kbaseUserInfo()
// verify that the access token works (i.e. that the client is logged in)
kbaseUser, err := server.kbaseUser()
if err != nil {
return nil, err
}

// register the local username under all its ORCIDs with our KBase user
// federation mechanism
for _, pid := range userInfo.Idents {
for _, pid := range kbaseUser.Idents {
if pid.Provider == "OrcID" {
orcid := pid.UserName
err = SetKBaseLocalUsernameForOrcid(orcid, userInfo.Username)
err = SetKBaseLocalUsernameForOrcid(orcid, kbaseUser.Username)
if err != nil {
break
}
Expand All @@ -84,23 +84,25 @@ func NewKBaseAuthServer(accessToken string) (*KBaseAuthServer, error) {
}
}

// returns a normalized user info record for the current KBase user
func (server KBaseAuthServer) UserInfo() (UserInfo, error) {
kbUserInfo, err := server.kbaseUserInfo()
// returns a normalized user record for the current KBase user
func (server KBaseAuthServer) Client() (Client, error) {
kbUser, err := server.kbaseUser()
if err != nil {
return UserInfo{}, err
return Client{}, err
}
userInfo := UserInfo{
Name: kbUserInfo.Display,
Username: kbUserInfo.Username,
Email: kbUserInfo.Email,
client := Client{
Name: kbUser.Display,
Username: kbUser.Username,
Email: kbUser.Email,
}
for _, pid := range kbUserInfo.Idents {
for _, pid := range kbUser.Idents {
// grab the first ORCID associated with the user
if pid.Provider == "OrcID" {
userInfo.Orcid = pid.UserName
client.Orcid = pid.UserName
break
}
}
return userInfo, nil
return client, nil
}

//-----------
Expand All @@ -113,7 +115,7 @@ const (

// a record containing information about a user logged into the KBase Auth2
// server
type kbaseUserInfo struct {
type kbaseUser struct {
// KBase username
Username string `json:"user"`
// KBase user display name
Expand Down Expand Up @@ -193,38 +195,38 @@ func (server KBaseAuthServer) get(resource string) (*http.Response, error) {
}

// returns information for the current KBase user accessing the auth server
func (server KBaseAuthServer) kbaseUserInfo() (kbaseUserInfo, error) {
var userInfo kbaseUserInfo
func (server KBaseAuthServer) kbaseUser() (kbaseUser, error) {
var user kbaseUser
resp, err := server.get("me")
if err != nil {
return userInfo, err
return user, err
}
if resp.StatusCode != 200 {
err = kbaseAuthError(resp)
if err != nil {
return userInfo, err
return user, err
}
}
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return userInfo, err
return user, err
}
err = json.Unmarshal(body, &userInfo)
err = json.Unmarshal(body, &user)

// make sure we have at least one ORCID for this user
if len(userInfo.Idents) < 1 {
return userInfo, fmt.Errorf("KBase Auth2: No providers associated with this user!")
if len(user.Idents) < 1 {
return user, fmt.Errorf("KBase Auth2: No providers associated with this user!")
}
foundOrcid := false
for _, pid := range userInfo.Idents {
for _, pid := range user.Idents {
if pid.Provider == "OrcID" {
foundOrcid = true
break
}
}
if !foundOrcid {
return userInfo, fmt.Errorf("KBase Auth2: No ORCID IDs associated with this user!")
return user, fmt.Errorf("KBase Auth2: No ORCID IDs associated with this user!")
}
return userInfo, err
return user, err
}
12 changes: 6 additions & 6 deletions auth/kbase_auth_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ func TestInvalidToken(t *testing.T) {
}

// tests whether the authentication server can return information for the
// user associated with the specified developer token
func TestUserInfo(t *testing.T) {
// client (the user associated with the specified developer token)
func TestClient(t *testing.T) {
assert := assert.New(t)
devToken := os.Getenv("DTS_KBASE_DEV_TOKEN")
server, _ := NewKBaseAuthServer(devToken)
assert.NotNil(server)
userInfo, err := server.UserInfo()
client, err := server.Client()
assert.Nil(err)

assert.True(len(userInfo.Username) > 0)
assert.True(len(userInfo.Email) > 0)
assert.Equal(os.Getenv("DTS_KBASE_TEST_ORCID"), userInfo.Orcid)
assert.True(len(client.Username) > 0)
assert.True(len(client.Email) > 0)
assert.Equal(os.Getenv("DTS_KBASE_TEST_ORCID"), client.Orcid)
}
56 changes: 30 additions & 26 deletions auth/kbase_user_federation.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,36 @@ import (
// users who have made requests to the DTS. This prevents us from having to
// rely on a secondary data source for this information.

// associates the given KBase username with the given ORCID ID
func SetKBaseLocalUsernameForOrcid(orcid, username string) error {
if !kbaseUserFederationStarted {
// fire it up!
started := make(chan struct{})
go kbaseUserFederation(started)

// wait for it to start
<-started
}
kbaseOrcidUserChan <- [2]string{orcid, username}
err := <-kbaseErrorChan
return err
}

// returns the local KBase username associated with the given ORCID ID
func KBaseLocalUsernameForOrcid(orcid string) (string, error) {
if !kbaseUserFederationStarted { // no one's logged in!
return "", fmt.Errorf("KBase federated user table not available!")
}
kbaseOrcidChan <- orcid
username := <-kbaseUserChan
err := <-kbaseErrorChan
return username, err
}

//-----------
// Internals
//-----------

var kbaseUserFederationStarted = false
var kbaseOrcidChan chan string // passes ORCIDs in
var kbaseOrcidUserChan chan [2]string // passes (ORCIDs, username) pairs in
Expand Down Expand Up @@ -80,29 +110,3 @@ func kbaseUserFederation(started chan struct{}) {
}
}
}

// associates the given KBase username with the given ORCID ID
func SetKBaseLocalUsernameForOrcid(orcid, username string) error {
if !kbaseUserFederationStarted {
// fire it up!
started := make(chan struct{})
go kbaseUserFederation(started)

// wait for it to start
<-started
}
kbaseOrcidUserChan <- [2]string{orcid, username}
err := <-kbaseErrorChan
return err
}

// returns the local KBase username associated with the given ORCID ID
func KBaseLocalUsernameForOrcid(orcid string) (string, error) {
if !kbaseUserFederationStarted { // no one's logged in!
return "", fmt.Errorf("KBase federated user table not available!")
}
kbaseOrcidChan <- orcid
username := <-kbaseUserChan
err := <-kbaseErrorChan
return username, err
}
26 changes: 21 additions & 5 deletions auth/user_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,29 @@

package auth

// a record containing information about a DTS user
type UserInfo struct {
// user's name (human-readable and display-friendly)
// A record containing information about a DTS client. A DTS client is a KBase
// user whose KBase developer token is used to authorize with the DTS.
type Client struct {
// client name (human-readable and display-friendly)
Name string
// username used to access DTS
// KBase username used by client to access DTS
Username string
// user's email address
// client email address
Email string
// ORCID identifier associated with this client
Orcid string
// organization with which this client is affiliated
Organization string
}

// A record containing information about a DTS user using a DTS client to
// request file transfers. A DTS user need not have a KBase developer token
// (but should have a KBase account if they are requesting files be transferred
// to KBase).
type User struct {
// client name (human-readable and display-friendly)
Name string
// client email address
Email string
// ORCID identifier associated with this user
Orcid string
Expand Down
52 changes: 37 additions & 15 deletions services/prototype.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,32 @@ var version = fmt.Sprintf("%d.%d.%d", majorVersion, minorVersion, patchVersion)
// authorize clients for the DTS, returning information about the user
// corresponding to the token in the header (or an error describing any issue
// encountered)
func authorize(authorizationHeader string) (auth.UserInfo, error) {
func authorize(authorizationHeader string) (auth.Client, error) {
if !strings.Contains(authorizationHeader, "Bearer") {
return auth.UserInfo{}, fmt.Errorf("Invalid authorization header")
return auth.Client{}, fmt.Errorf("Invalid authorization header")
}
b64Token := authorizationHeader[len("Bearer "):]
accessTokenBytes, err := base64.StdEncoding.DecodeString(b64Token)
if err != nil {
return auth.UserInfo{}, huma.Error401Unauthorized(err.Error())
return auth.Client{}, huma.Error401Unauthorized(err.Error())
}
accessToken := strings.TrimSpace(string(accessTokenBytes))

// check the access token against the KBase auth server
// and return info about the corresponding user
authServer, err := auth.NewKBaseAuthServer(accessToken)
if err != nil {
return auth.UserInfo{}, huma.Error401Unauthorized(err.Error())
return auth.Client{}, huma.Error401Unauthorized(err.Error())
}
userInfo, err := authServer.UserInfo()
client, err := authServer.Client()
if err != nil {
return userInfo, huma.Error401Unauthorized(err.Error())
return client, huma.Error401Unauthorized(err.Error())
}
return userInfo, nil
// the client needs at least one associated ORCID
if client.Orcid == "" {
return client, huma.Error403Forbidden("The DTS client has no associated ORCID!")
}
return client, nil
}

type ServiceInfoOutput struct {
Expand Down Expand Up @@ -342,7 +346,7 @@ func (service *prototype) getDatabaseSearchParameters(ctx context.Context,
Database string `path:"db" example:"jdp" doc:"the abbreviated name of a database"`
}) (*SearchParametersOutput, error) {

userInfo, err := authorize(input.Authorization)
client, err := authorize(input.Authorization)
if err != nil {
return nil, err
}
Expand All @@ -352,7 +356,7 @@ func (service *prototype) getDatabaseSearchParameters(ctx context.Context,
if !ok {
return nil, fmt.Errorf("Database %s not found", input.Database)
}
db, err := databases.NewDatabase(userInfo.Orcid, input.Database)
db, err := databases.NewDatabase(client.Orcid, input.Database)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -407,7 +411,7 @@ func searchDatabase(_ context.Context,
input *SearchDatabaseInput,
specific map[string]json.RawMessage) (*SearchResultsOutput, error) {

userInfo, err := authorize(input.Authorization)
client, err := authorize(input.Authorization)
if err != nil {
return nil, err
}
Expand All @@ -432,7 +436,7 @@ func searchDatabase(_ context.Context,
}

slog.Info(fmt.Sprintf("Searching database %s for files...", input.Database))
db, err := databases.NewDatabase(userInfo.Orcid, input.Database)
db, err := databases.NewDatabase(client.Orcid, input.Database)
if err != nil {
return nil, databaseError(err)
}
Expand Down Expand Up @@ -508,7 +512,7 @@ func (service *prototype) fetchFileMetadata(ctx context.Context,
Limit int `json:"limit" query:"limit" example:"50" doc:"Limits the number of metadata records returned"`
}) (*FileMetadataOutput, error) {

userInfo, err := authorize(input.Authorization)
client, err := authorize(input.Authorization)
if err != nil {
return nil, err
}
Expand All @@ -527,7 +531,7 @@ func (service *prototype) fetchFileMetadata(ctx context.Context,

slog.Info(fmt.Sprintf("Fetching file metadata for %d files in database %s...",
len(ids), input.Database))
db, err := databases.NewDatabase(userInfo.Orcid, input.Database)
db, err := databases.NewDatabase(client.Orcid, input.Database)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -558,13 +562,31 @@ func (service *prototype) createTransfer(ctx context.Context,
ContentType string `header:"Content-Type" doc:"Content-Type header (must be application/json)"`
}) (*TransferOutput, error) {

userInfo, err := authorize(input.Authorization)
client, err := authorize(input.Authorization)
if err != nil {
return nil, err
}

// fetch information about the requesting user
var user auth.User
if input.Body.Orcid != "" {
// FIXME: we just extract the ORCID at the moment
// FIXME: we should get the other stuff from the ORCID public API
user.Orcid = input.Body.Orcid
} else {
// FIXME: for now, while we're in transition, we can fall back to the client's
// FIXME: info if a user ORCID is not provided
user = auth.User{
Name: client.Name,
Email: client.Email,
Orcid: client.Orcid,
Organization: client.Organization,
}
}

taskId, err := tasks.Create(tasks.Specification{
UserInfo: userInfo,
Client: client,
User: user,
Source: input.Body.Source,
Destination: input.Body.Destination,
FileIds: input.Body.FileIds,
Expand Down
2 changes: 2 additions & 0 deletions services/transfer_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type FileMetadataResponse struct {

// a request for a file transfer (POST)
type TransferRequest struct {
// user ORCID
Orcid string `json:"orcid" example:"0000-0002-9227-8514" doc:"ORCID for user requesting transfer"`
// name of source database
Source string `json:"source" example:"jdp" doc:"source database identifier"`
// identifiers for files to be transferred
Expand Down
Loading

0 comments on commit bfd8b36

Please sign in to comment.