diff --git a/auth/auth.go b/auth/auth.go index 74e7799f..25229ec2 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -22,14 +22,26 @@ import ( ) type Wrapper struct { + credentials []Credential +} + +type Credential struct { username []byte password []byte } +func NewWrapperMultiple(users map[string]string) *Wrapper { + var cs []Credential + for k, v := range users { + u := sha256.Sum256([]byte(k)) + p := sha256.Sum256([]byte(v)) + cs = append(cs, Credential{username: u[:], password: p[:]}) + } + return &Wrapper{credentials: cs} +} + func NewWrapper(username, password string) *Wrapper { - u := sha256.Sum256([]byte(username)) - p := sha256.Sum256([]byte(password)) - return &Wrapper{username: u[:], password: p[:]} + return NewWrapperMultiple(map[string]string{username: password}) } const notAuthorized = "Not Authorized" @@ -58,9 +70,19 @@ func (wrapper *Wrapper) WrapFunc(handlerFunc http.HandlerFunc) http.HandlerFunc func authorized(wrapper *Wrapper, r *http.Request) bool { username, password, isOk := r.BasicAuth() - u := sha256.Sum256([]byte(username)) - p := sha256.Sum256([]byte(password)) - return isOk && - subtle.ConstantTimeCompare(wrapper.username, u[:]) == 1 && - subtle.ConstantTimeCompare(wrapper.password, p[:]) == 1 + if isOk { + u := sha256.Sum256([]byte(username)) + p := sha256.Sum256([]byte(password)) + for _, c := range wrapper.credentials { + if c.isAuthorized(u, p) { + return true + } + } + } + return false +} + +func (c Credential) isAuthorized(uChecksum [32]byte, pChecksum [32]byte) bool { + return subtle.ConstantTimeCompare(c.username, uChecksum[:]) == 1 && + subtle.ConstantTimeCompare(c.password, pChecksum[:]) == 1 } diff --git a/auth/auth_test.go b/auth/auth_test.go index c813ebea..812be5ab 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -80,6 +80,57 @@ var _ = Describe("Auth Wrapper", func() { }) }) + Describe("wrapped multiple handler", func() { + var ( + username2 string + password2 string + credentials map[string]string + wrappedHandler http.Handler + ) + BeforeEach(func() { + username2 = "username2" + password2 = "password2" + credentials = make(map[string]string) + credentials[username] = password + credentials[username2] = password2 + }) + + BeforeEach(func() { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + wrappedHandler = auth.NewWrapperMultiple(credentials).Wrap(handler) + }) + + It("works when the credentials are correct", func() { + request := newRequest(username, password) + wrappedHandler.ServeHTTP(httpRecorder, request) + Expect(httpRecorder.Code).To(Equal(http.StatusCreated)) + + request = newRequest(username2, password2) + wrappedHandler.ServeHTTP(httpRecorder, request) + Expect(httpRecorder.Code).To(Equal(http.StatusCreated)) + }) + + It("fails when the username is empty", func() { + request := newRequest("", password) + wrappedHandler.ServeHTTP(httpRecorder, request) + Expect(httpRecorder.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("fails when the password is empty", func() { + request := newRequest(username, "") + wrappedHandler.ServeHTTP(httpRecorder, request) + Expect(httpRecorder.Code).To(Equal(http.StatusUnauthorized)) + }) + + It("fails when the credentials are wrong", func() { + request := newRequest("thats", "apar") + wrappedHandler.ServeHTTP(httpRecorder, request) + Expect(httpRecorder.Code).To(Equal(http.StatusUnauthorized)) + }) + }) + Describe("wrapped handlerFunc", func() { var wrappedHandlerFunc http.HandlerFunc