diff --git a/README.md b/README.md index 5283244..d9f048f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ This tool is full configured using environment variables. - `http://localhost:9009/health`: Healthcheck URL, used by the docker healtcheck. - `http://localhost:9009/saml/acs`: SAML ACS URL, needed to configure your IdP. - `http://localhost:9009/saml/metadata`: SAML Metadata URL, needed to configure your IdP. +- `http://localhost:9009/saml/logout`: SAML Logout URL. - `http://localhost:9009/`: Test URL, redirects to SAML SSO URL. ## Configuration diff --git a/go.mod b/go.mod index 5a17f7c..a822833 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module beryju.org/saml-test-sp -go 1.13 +go 1.16 require ( github.com/crewjam/saml v0.4.6 diff --git a/pkg/server/root.go b/pkg/server/root.go index d445456..c00c632 100644 --- a/pkg/server/root.go +++ b/pkg/server/root.go @@ -9,16 +9,24 @@ import ( log "github.com/sirupsen/logrus" "beryju.org/saml-test-sp/pkg/helpers" + "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" ) -func hello(w http.ResponseWriter, r *http.Request) { - s := samlsp.SessionFromContext(r.Context()) +type Server struct { + m *samlsp.Middleware + h *http.ServeMux + l *log.Entry + b string +} + +func (s *Server) hello(w http.ResponseWriter, r *http.Request) { + sa := samlsp.SessionFromContext(r.Context()) if s == nil { http.Error(w, "No Session", http.StatusInternalServerError) return } - sa, ok := s.(samlsp.SessionWithAttributes) + sa, ok := sa.(samlsp.SessionWithAttributes) if !ok { http.Error(w, "Session has no attributes", http.StatusInternalServerError) return @@ -32,18 +40,56 @@ func hello(w http.ResponseWriter, r *http.Request) { w.Write(data) } -func health(w http.ResponseWriter, r *http.Request) { +func (s *Server) health(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) fmt.Fprint(w, "hello :)") } -func logRequest(handler http.Handler) http.Handler { +func (s *Server) logRequest(handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.WithField("remoteAddr", r.RemoteAddr).WithField("method", r.Method).Info(r.URL) + s.l.WithField("remoteAddr", r.RemoteAddr).WithField("method", r.Method).Info(r.URL) handler.ServeHTTP(w, r) }) } +func (s *Server) logout(w http.ResponseWriter, r *http.Request) { + nameID := samlsp.AttributeFromContext(r.Context(), "urn:oasis:names:tc:SAML:attribute:subject-id") + var binding *saml.Endpoint + for _, desc := range s.m.ServiceProvider.IDPMetadata.IDPSSODescriptors { + for _, slo := range desc.SingleLogoutServices { + s.l.WithField("slo", slo.Binding).Info("found SLO binding") + binding = &slo + break + } + } + if binding == nil { + s.l.Warning("no SLO descriptors found, aborting") + w.WriteHeader(400) + return + } + + if binding.Binding == saml.HTTPRedirectBinding { + url, err := s.m.ServiceProvider.MakeRedirectLogoutRequest(nameID, s.b+"/health") + if err != nil { + s.l.WithError(err).Warning("failed to make redirect logout") + } + http.Redirect(w, r, url.String(), http.StatusFound) + } else if binding.Binding == saml.HTTPPostBinding { + res, err := s.m.ServiceProvider.MakePostLogoutRequest(nameID, s.b+"/health") + if err != nil { + s.l.WithError(err).Warning("failed to make post logout") + } + w.Header().Set("Content-Type", "text/html") + w.Write(res) + } else { + http.Error(w, "invalid binding", 500) + } + err := s.m.Session.DeleteSession(w, r) + if err != nil { + s.l.WithError(err).Warning("failed to delete session") + } +} + func RunServer() { config := helpers.LoadConfig() @@ -52,22 +98,29 @@ func RunServer() { if err != nil { panic(err) } - http.Handle("/", samlSP.RequireAccount(http.HandlerFunc(hello))) - http.Handle("/saml/", samlSP) - http.HandleFunc("/health", health) + server := Server{ + m: samlSP, + h: http.NewServeMux(), + l: log.WithField("component", "server"), + } + server.h.Handle("/", samlSP.RequireAccount(http.HandlerFunc(server.hello))) + server.h.Handle("/saml/logout", samlSP.RequireAccount(http.HandlerFunc(server.logout))) + server.h.Handle("/saml/", samlSP) + server.h.HandleFunc("/health", server.health) listen := helpers.Env("SP_BIND", "localhost:9009") - log.Infof("Server listening on '%s'", listen) - log.Infof("ACS URL is '%s'", samlSP.ServiceProvider.AcsURL.String()) + server.b = listen + server.l.Infof("Server listening on '%s'", listen) + server.l.Infof("ACS URL is '%s'", samlSP.ServiceProvider.AcsURL.String()) if _, set := os.LookupEnv("SP_SSL_CERT"); set { // SP_SSL_CERT set, so we run SSL mode - err := http.ListenAndServeTLS(listen, os.Getenv("SP_SSL_CERT"), os.Getenv("SP_SSL_KEY"), logRequest(http.DefaultServeMux)) + err := http.ListenAndServeTLS(listen, os.Getenv("SP_SSL_CERT"), os.Getenv("SP_SSL_KEY"), server.logRequest(server.h)) if err != nil { panic(err) } } else { - err = http.ListenAndServe(listen, logRequest(http.DefaultServeMux)) + err = http.ListenAndServe(listen, server.logRequest(server.h)) if err != nil { panic(err) }