Skip to content

Commit

Permalink
add basic SLO support
Browse files Browse the repository at this point in the history
  • Loading branch information
BeryJu committed Feb 21, 2022
1 parent 0bf99f3 commit ae03e21
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This tool is full configured using environment variables.
- `http://localhost:9009/health`: Healthcheck URL, used by the docker healtcheck.
- `http://localhost:9009/saml/acs`: SAML ACS URL, needed to configure your IdP.
- `http://localhost:9009/saml/metadata`: SAML Metadata URL, needed to configure your IdP.
- `http://localhost:9009/saml/logout`: SAML Logout URL.
- `http://localhost:9009/`: Test URL, redirects to SAML SSO URL.

## Configuration
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module beryju.org/saml-test-sp

go 1.13
go 1.16

require (
github.com/crewjam/saml v0.4.6
Expand Down
79 changes: 66 additions & 13 deletions pkg/server/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@ import (
log "github.com/sirupsen/logrus"

"beryju.org/saml-test-sp/pkg/helpers"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
)

func hello(w http.ResponseWriter, r *http.Request) {
s := samlsp.SessionFromContext(r.Context())
type Server struct {
m *samlsp.Middleware
h *http.ServeMux
l *log.Entry
b string
}

func (s *Server) hello(w http.ResponseWriter, r *http.Request) {
sa := samlsp.SessionFromContext(r.Context())
if s == nil {
http.Error(w, "No Session", http.StatusInternalServerError)
return
}
sa, ok := s.(samlsp.SessionWithAttributes)
sa, ok := sa.(samlsp.SessionWithAttributes)
if !ok {
http.Error(w, "Session has no attributes", http.StatusInternalServerError)
return
Expand All @@ -32,18 +40,56 @@ func hello(w http.ResponseWriter, r *http.Request) {
w.Write(data)
}

func health(w http.ResponseWriter, r *http.Request) {
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprint(w, "hello :)")
}

func logRequest(handler http.Handler) http.Handler {
func (s *Server) logRequest(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.WithField("remoteAddr", r.RemoteAddr).WithField("method", r.Method).Info(r.URL)
s.l.WithField("remoteAddr", r.RemoteAddr).WithField("method", r.Method).Info(r.URL)
handler.ServeHTTP(w, r)
})
}

func (s *Server) logout(w http.ResponseWriter, r *http.Request) {
nameID := samlsp.AttributeFromContext(r.Context(), "urn:oasis:names:tc:SAML:attribute:subject-id")
var binding *saml.Endpoint
for _, desc := range s.m.ServiceProvider.IDPMetadata.IDPSSODescriptors {
for _, slo := range desc.SingleLogoutServices {
s.l.WithField("slo", slo.Binding).Info("found SLO binding")
binding = &slo
break
}
}
if binding == nil {
s.l.Warning("no SLO descriptors found, aborting")
w.WriteHeader(400)
return
}

if binding.Binding == saml.HTTPRedirectBinding {
url, err := s.m.ServiceProvider.MakeRedirectLogoutRequest(nameID, s.b+"/health")
if err != nil {
s.l.WithError(err).Warning("failed to make redirect logout")
}
http.Redirect(w, r, url.String(), http.StatusFound)
} else if binding.Binding == saml.HTTPPostBinding {
res, err := s.m.ServiceProvider.MakePostLogoutRequest(nameID, s.b+"/health")
if err != nil {
s.l.WithError(err).Warning("failed to make post logout")
}
w.Header().Set("Content-Type", "text/html")
w.Write(res)
} else {
http.Error(w, "invalid binding", 500)
}
err := s.m.Session.DeleteSession(w, r)
if err != nil {
s.l.WithError(err).Warning("failed to delete session")
}
}

func RunServer() {
config := helpers.LoadConfig()

Expand All @@ -52,22 +98,29 @@ func RunServer() {
if err != nil {
panic(err)
}
http.Handle("/", samlSP.RequireAccount(http.HandlerFunc(hello)))
http.Handle("/saml/", samlSP)
http.HandleFunc("/health", health)
server := Server{
m: samlSP,
h: http.NewServeMux(),
l: log.WithField("component", "server"),
}
server.h.Handle("/", samlSP.RequireAccount(http.HandlerFunc(server.hello)))
server.h.Handle("/saml/logout", samlSP.RequireAccount(http.HandlerFunc(server.logout)))
server.h.Handle("/saml/", samlSP)
server.h.HandleFunc("/health", server.health)

listen := helpers.Env("SP_BIND", "localhost:9009")
log.Infof("Server listening on '%s'", listen)
log.Infof("ACS URL is '%s'", samlSP.ServiceProvider.AcsURL.String())
server.b = listen
server.l.Infof("Server listening on '%s'", listen)
server.l.Infof("ACS URL is '%s'", samlSP.ServiceProvider.AcsURL.String())

if _, set := os.LookupEnv("SP_SSL_CERT"); set {
// SP_SSL_CERT set, so we run SSL mode
err := http.ListenAndServeTLS(listen, os.Getenv("SP_SSL_CERT"), os.Getenv("SP_SSL_KEY"), logRequest(http.DefaultServeMux))
err := http.ListenAndServeTLS(listen, os.Getenv("SP_SSL_CERT"), os.Getenv("SP_SSL_KEY"), server.logRequest(server.h))
if err != nil {
panic(err)
}
} else {
err = http.ListenAndServe(listen, logRequest(http.DefaultServeMux))
err = http.ListenAndServe(listen, server.logRequest(server.h))
if err != nil {
panic(err)
}
Expand Down

0 comments on commit ae03e21

Please sign in to comment.