diff --git a/Makefile b/Makefile index b2a106e..81a3012 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,13 @@ BINDIR := $(CURDIR)/bin INSTALL_PATH ?= /usr/local/bin -BINNAME ?= rdpgw -BINNAME2 ?= rdpgw-auth + +ifneq ($(GOOS),windows) + BINNAME ?= rdpgw + BINNAME2 ?= rdpgw-auth +else + BINNAME ?= rdpgw.exe + BINNAME2 ?= rdpgw-auth.exe +endif # Rebuild the binary if any of these files change SRC := $(shell find . -type f -name '*.go' -print) go.mod go.sum diff --git a/cmd/auth/auth.go b/cmd/auth/auth.go index ede837a..521070f 100644 --- a/cmd/auth/auth.go +++ b/cmd/auth/auth.go @@ -2,31 +2,13 @@ package main import ( "context" - "errors" - "fmt" "github.com/bolkedebruin/rdpgw/cmd/auth/config" "github.com/bolkedebruin/rdpgw/cmd/auth/database" "github.com/bolkedebruin/rdpgw/cmd/auth/ntlm" "github.com/bolkedebruin/rdpgw/shared/auth" - "github.com/msteinert/pam/v2" - "github.com/thought-machine/go-flags" - "google.golang.org/grpc" "log" - "net" - "os" - "syscall" ) -const ( - protocol = "unix" -) - -var opts struct { - ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"` - SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"` - ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"` -} - type AuthServiceImpl struct { auth.UnimplementedAuthenticateServer @@ -45,49 +27,6 @@ func NewAuthService(serviceName string, database database.Database) auth.Authent return s } -func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) { - t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) { - switch s { - case pam.PromptEchoOff: - return message.Password, nil - case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo: - return "", nil - } - return "", errors.New("unrecognized PAM message style") - }) - - r := &auth.AuthResponse{} - r.Authenticated = false - - if err != nil { - log.Printf("Error authenticating user: %s due to: %s", message.Username, err) - r.Error = err.Error() - return r, err - } - defer func() { - err := t.End() - if err != nil { - fmt.Fprintf(os.Stderr, "end: %v\n", err) - os.Exit(1) - } - }() - if err = t.Authenticate(0); err != nil { - log.Printf("Authentication for user: %s failed due to: %s", message.Username, err) - r.Error = err.Error() - return r, nil - } - - if err = t.AcctMgmt(0); err != nil { - log.Printf("Account authorization for user: %s failed due to %s", message.Username, err) - r.Error = err.Error() - return r, nil - } - - log.Printf("User: %s authenticated", message.Username) - r.Authenticated = true - return r, nil -} - func (s *AuthServiceImpl) NTLM(ctx context.Context, message *auth.NtlmRequest) (*auth.NtlmResponse, error) { r, err := s.ntlm.Authenticate(message) @@ -101,41 +40,3 @@ func (s *AuthServiceImpl) NTLM(ctx context.Context, message *auth.NtlmRequest) ( return r, err } - -func main() { - _, err := flags.Parse(&opts) - if err != nil { - var fErr *flags.Error - if errors.As(err, &fErr) { - if fErr.Type == flags.ErrHelp { - fmt.Printf("Acknowledgements:\n") - fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. (go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n") - } - } - return - } - - conf = config.Load(opts.ConfigFile) - - log.Printf("Starting auth server on %s", opts.SocketAddr) - cleanup := func() { - if _, err := os.Stat(opts.SocketAddr); err == nil { - if err := os.RemoveAll(opts.SocketAddr); err != nil { - log.Fatal(err) - } - } - } - cleanup() - - oldUmask := syscall.Umask(0) - listener, err := net.Listen(protocol, opts.SocketAddr) - syscall.Umask(oldUmask) - if err != nil { - log.Fatal(err) - } - server := grpc.NewServer() - db := database.NewConfig(conf.Users) - service := NewAuthService(opts.ServiceName, db) - auth.RegisterAuthenticateServer(server, service) - server.Serve(listener) -} diff --git a/cmd/auth/auth_unix.go b/cmd/auth/auth_unix.go new file mode 100644 index 0000000..a312771 --- /dev/null +++ b/cmd/auth/auth_unix.go @@ -0,0 +1,110 @@ +// +build !windows + +package main + +import ( + "context" + "errors" + "fmt" + "github.com/bolkedebruin/rdpgw/cmd/auth/config" + "github.com/bolkedebruin/rdpgw/cmd/auth/database" + "github.com/bolkedebruin/rdpgw/shared/auth" + "github.com/thought-machine/go-flags" + "github.com/msteinert/pam/v2" + "google.golang.org/grpc" + "log" + "net" + "os" + "syscall" +) + +const ( + protocol = "unix" +) + +var opts struct { + ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"` + SocketAddr string `short:"s" long:"socket" default:"/tmp/rdpgw-auth.sock" description:"the location of the socket"` + ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"` +} + +func (s *AuthServiceImpl) Authenticate(ctx context.Context, message *auth.UserPass) (*auth.AuthResponse, error) { + t, err := pam.StartFunc(s.serviceName, message.Username, func(s pam.Style, msg string) (string, error) { + switch s { + case pam.PromptEchoOff: + return message.Password, nil + case pam.PromptEchoOn, pam.ErrorMsg, pam.TextInfo: + return "", nil + } + return "", errors.New("unrecognized PAM message style") + }) + + r := &auth.AuthResponse{} + r.Authenticated = false + + if err != nil { + log.Printf("Error authenticating user: %s due to: %s", message.Username, err) + r.Error = err.Error() + return r, err + } + defer func() { + err := t.End() + if err != nil { + fmt.Fprintf(os.Stderr, "end: %v\n", err) + os.Exit(1) + } + }() + if err = t.Authenticate(0); err != nil { + log.Printf("Authentication for user: %s failed due to: %s", message.Username, err) + r.Error = err.Error() + return r, nil + } + + if err = t.AcctMgmt(0); err != nil { + log.Printf("Account authorization for user: %s failed due to %s", message.Username, err) + r.Error = err.Error() + return r, nil + } + + log.Printf("User: %s authenticated", message.Username) + r.Authenticated = true + return r, nil +} + +func main() { + _, err := flags.Parse(&opts) + if err != nil { + var fErr *flags.Error + if errors.As(err, &fErr) { + if fErr.Type == flags.ErrHelp { + fmt.Printf("Acknowledgements:\n") + fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. (go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n") + } + } + return + } + + conf = config.Load(opts.ConfigFile) + + log.Printf("Starting auth server on %s", opts.SocketAddr) + cleanup := func() { + if _, err := os.Stat(opts.SocketAddr); err == nil { + if err := os.RemoveAll(opts.SocketAddr); err != nil { + log.Fatal(err) + } + } + } + cleanup() + + oldUmask := syscall.Umask(0) + listener, err := net.Listen(protocol, opts.SocketAddr) + syscall.Umask(oldUmask) + if err != nil { + log.Fatal(err) + } + server := grpc.NewServer() + db := database.NewConfig(conf.Users) + service := NewAuthService(opts.ServiceName, db) + auth.RegisterAuthenticateServer(server, service) + server.Serve(listener) +} diff --git a/cmd/auth/auth_windows.go b/cmd/auth/auth_windows.go new file mode 100644 index 0000000..79e78af --- /dev/null +++ b/cmd/auth/auth_windows.go @@ -0,0 +1,63 @@ +// +build windows + +package main + +import ( + "errors" + "fmt" + "github.com/bolkedebruin/rdpgw/cmd/auth/config" + "github.com/bolkedebruin/rdpgw/cmd/auth/database" + "github.com/bolkedebruin/rdpgw/shared/auth" + "github.com/thought-machine/go-flags" + "google.golang.org/grpc" + "log" + "net" + "os" +) + +const ( + protocol = "tcp" +) + +var opts struct { + ServiceName string `short:"n" long:"name" default:"rdpgw" description:"the PAM service name to use"` + SocketAddr string `short:"s" long:"socket" default:"127.0.0.1:3000" description:"the location of the socket"` + ConfigFile string `short:"c" long:"conf" default:"rdpgw-auth.yaml" description:"users config file for NTLM (yaml)"` +} + + +func main() { + _, err := flags.Parse(&opts) + if err != nil { + var fErr *flags.Error + if errors.As(err, &fErr) { + if fErr.Type == flags.ErrHelp { + fmt.Printf("Acknowledgements:\n") + fmt.Printf(" - This product includes software developed by the Thomson Reuters Global Resources. (go-ntlm - https://github.com/m7913d/go-ntlm - BSD-4 License)\n") + } + } + return + } + + conf = config.Load(opts.ConfigFile) + + log.Printf("Starting auth server on %s", opts.SocketAddr) + cleanup := func() { + if _, err := os.Stat(opts.SocketAddr); err == nil { + if err := os.RemoveAll(opts.SocketAddr); err != nil { + log.Fatal(err) + } + } + } + cleanup() + + listener, err := net.Listen(protocol, opts.SocketAddr) + if err != nil { + log.Fatal(err) + } + server := grpc.NewServer() + db := database.NewConfig(conf.Users) + service := NewAuthService(opts.ServiceName, db) + auth.RegisterAuthenticateServer(server, service) + server.Serve(listener) +} diff --git a/cmd/rdpgw/config/configuration.go b/cmd/rdpgw/config/configuration.go index ead4140..d8a89a2 100644 --- a/cmd/rdpgw/config/configuration.go +++ b/cmd/rdpgw/config/configuration.go @@ -1,14 +1,6 @@ package config import ( - "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" - "github.com/knadh/koanf/parsers/yaml" - "github.com/knadh/koanf/providers/confmap" - "github.com/knadh/koanf/providers/env" - "github.com/knadh/koanf/providers/file" - "github.com/knadh/koanf/v2" - "log" - "os" "strings" ) @@ -135,117 +127,6 @@ func ToCamel(s string) string { var Conf Configuration -func Load(configFile string) Configuration { - - var k = koanf.New(".") - - k.Load(confmap.Provider(map[string]interface{}{ - "Server.Tls": "auto", - "Server.Port": 443, - "Server.SessionStore": "cookie", - "Server.HostSelection": "roundrobin", - "Server.Authentication": "openid", - "Server.AuthSocket": "/tmp/rdpgw-auth.sock", - "Server.BasicAuthTimeout": 5, - "Client.NetworkAutoDetect": 1, - "Client.BandwidthAutoDetect": 1, - "Security.VerifyClientIp": true, - "Caps.TokenAuth": true, - }, "."), nil) - - if _, err := os.Stat(configFile); os.IsNotExist(err) { - log.Printf("Config file %s not found, using defaults and environment", configFile) - } else { - if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil { - log.Fatalf("Error loading config from file: %v", err) - } - } - - if err := k.Load(env.ProviderWithValue("RDPGW_", ".", func(s string, v string) (string, interface{}) { - key := strings.Replace(strings.ToLower(strings.TrimPrefix(s, "RDPGW_")), "__", ".", -1) - key = ToCamel(key) - - v = strings.Trim(v, " ") - - // handle lists - if strings.Contains(v, " ") { - return key, strings.Split(v, " ") - } - return key, v - - }), nil); err != nil { - log.Fatalf("Error loading config from environment: %v", err) - } - - koanfTag := koanf.UnmarshalConf{Tag: "koanf"} - k.UnmarshalWithConf("Server", &Conf.Server, koanfTag) - k.UnmarshalWithConf("OpenId", &Conf.OpenId, koanfTag) - k.UnmarshalWithConf("Caps", &Conf.Caps, koanfTag) - k.UnmarshalWithConf("Security", &Conf.Security, koanfTag) - k.UnmarshalWithConf("Client", &Conf.Client, koanfTag) - k.UnmarshalWithConf("Kerberos", &Conf.Kerberos, koanfTag) - - if len(Conf.Security.PAATokenEncryptionKey) != 32 { - Conf.Security.PAATokenEncryptionKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.paatokenencryptionkey` specified (empty or not 32 characters). Setting to random") - } - - if len(Conf.Security.PAATokenSigningKey) != 32 { - Conf.Security.PAATokenSigningKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") - } - - if Conf.Security.EnableUserToken { - if len(Conf.Security.UserTokenEncryptionKey) != 32 { - Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") - } - - if len(Conf.Security.UserTokenSigningKey) != 32 { - Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") - } - } - - if len(Conf.Server.SessionKey) != 32 { - Conf.Server.SessionKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `server.sessionkey` specified (empty or not 32 characters). Setting to random") - } - - if len(Conf.Server.SessionEncryptionKey) != 32 { - Conf.Server.SessionEncryptionKey, _ = security.GenerateRandomString(32) - log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random") - } - - if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 { - log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") - } - - if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { - log.Fatalf("basicauth=local and tls=disable are mutually exclusive") - } - - if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() { - log.Fatalf("ntlm and kerberos authentication are not stackable") - } - - if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { - log.Fatalf("openid is configured but tokenauth disabled") - } - - if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" { - log.Fatalf("kerberos is configured but no keytab was specified") - } - - // prepend '//' if required for URL parsing - if !strings.Contains(Conf.Server.GatewayAddress, "//") { - Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress - } - - return Conf - -} - func (s *ServerConfig) OpenIDEnabled() bool { return s.matchAuth("openid") } diff --git a/cmd/rdpgw/config/configuration_unix.go b/cmd/rdpgw/config/configuration_unix.go new file mode 100644 index 0000000..735e999 --- /dev/null +++ b/cmd/rdpgw/config/configuration_unix.go @@ -0,0 +1,126 @@ +// +build !windows + +package config + +import ( + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" + "github.com/knadh/koanf/parsers/yaml" + "github.com/knadh/koanf/providers/confmap" + "github.com/knadh/koanf/providers/env" + "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/v2" + "log" + "os" + "strings" +) + +func Load(configFile string) Configuration { + + var k = koanf.New(".") + + k.Load(confmap.Provider(map[string]interface{}{ + "Server.Tls": "auto", + "Server.Port": 443, + "Server.SessionStore": "cookie", + "Server.HostSelection": "roundrobin", + "Server.Authentication": "openid", + "Server.AuthSocket": "/tmp/rdpgw-auth.sock", + "Server.BasicAuthTimeout": 5, + "Client.NetworkAutoDetect": 1, + "Client.BandwidthAutoDetect": 1, + "Security.VerifyClientIp": true, + "Caps.TokenAuth": true, + }, "."), nil) + + if _, err := os.Stat(configFile); os.IsNotExist(err) { + log.Printf("Config file %s not found, using defaults and environment", configFile) + } else { + if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil { + log.Fatalf("Error loading config from file: %v", err) + } + } + + if err := k.Load(env.ProviderWithValue("RDPGW_", ".", func(s string, v string) (string, interface{}) { + key := strings.Replace(strings.ToLower(strings.TrimPrefix(s, "RDPGW_")), "__", ".", -1) + key = ToCamel(key) + + v = strings.Trim(v, " ") + + // handle lists + if strings.Contains(v, " ") { + return key, strings.Split(v, " ") + } + return key, v + + }), nil); err != nil { + log.Fatalf("Error loading config from environment: %v", err) + } + + koanfTag := koanf.UnmarshalConf{Tag: "koanf"} + k.UnmarshalWithConf("Server", &Conf.Server, koanfTag) + k.UnmarshalWithConf("OpenId", &Conf.OpenId, koanfTag) + k.UnmarshalWithConf("Caps", &Conf.Caps, koanfTag) + k.UnmarshalWithConf("Security", &Conf.Security, koanfTag) + k.UnmarshalWithConf("Client", &Conf.Client, koanfTag) + k.UnmarshalWithConf("Kerberos", &Conf.Kerberos, koanfTag) + + if len(Conf.Security.PAATokenEncryptionKey) != 32 { + Conf.Security.PAATokenEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.paatokenencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Security.PAATokenSigningKey) != 32 { + Conf.Security.PAATokenSigningKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") + } + + if Conf.Security.EnableUserToken { + if len(Conf.Security.UserTokenEncryptionKey) != 32 { + Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Security.UserTokenSigningKey) != 32 { + Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") + } + } + + if len(Conf.Server.SessionKey) != 32 { + Conf.Server.SessionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `server.sessionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Server.SessionEncryptionKey) != 32 { + Conf.Server.SessionEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 { + log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") + } + + if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { + log.Fatalf("basicauth=local and tls=disable are mutually exclusive") + } + + if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() { + log.Fatalf("ntlm and kerberos authentication are not stackable") + } + + if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { + log.Fatalf("openid is configured but tokenauth disabled") + } + + if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" { + log.Fatalf("kerberos is configured but no keytab was specified") + } + + // prepend '//' if required for URL parsing + if !strings.Contains(Conf.Server.GatewayAddress, "//") { + Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress + } + + return Conf + +} diff --git a/cmd/rdpgw/config/configuration_windows.go b/cmd/rdpgw/config/configuration_windows.go new file mode 100644 index 0000000..4707eef --- /dev/null +++ b/cmd/rdpgw/config/configuration_windows.go @@ -0,0 +1,126 @@ +// +build windows + +package config + +import ( + "github.com/bolkedebruin/rdpgw/cmd/rdpgw/security" + "github.com/knadh/koanf/parsers/yaml" + "github.com/knadh/koanf/providers/confmap" + "github.com/knadh/koanf/providers/env" + "github.com/knadh/koanf/providers/file" + "github.com/knadh/koanf/v2" + "log" + "os" + "strings" +) + +func Load(configFile string) Configuration { + + var k = koanf.New(".") + + k.Load(confmap.Provider(map[string]interface{}{ + "Server.Tls": "auto", + "Server.Port": 443, + "Server.SessionStore": "cookie", + "Server.HostSelection": "roundrobin", + "Server.Authentication": "openid", + "Server.AuthSocket": "127.0.0.1:3000", + "Server.BasicAuthTimeout": 5, + "Client.NetworkAutoDetect": 1, + "Client.BandwidthAutoDetect": 1, + "Security.VerifyClientIp": true, + "Caps.TokenAuth": true, + }, "."), nil) + + if _, err := os.Stat(configFile); os.IsNotExist(err) { + log.Printf("Config file %s not found, using defaults and environment", configFile) + } else { + if err := k.Load(file.Provider(configFile), yaml.Parser()); err != nil { + log.Fatalf("Error loading config from file: %v", err) + } + } + + if err := k.Load(env.ProviderWithValue("RDPGW_", ".", func(s string, v string) (string, interface{}) { + key := strings.Replace(strings.ToLower(strings.TrimPrefix(s, "RDPGW_")), "__", ".", -1) + key = ToCamel(key) + + v = strings.Trim(v, " ") + + // handle lists + if strings.Contains(v, " ") { + return key, strings.Split(v, " ") + } + return key, v + + }), nil); err != nil { + log.Fatalf("Error loading config from environment: %v", err) + } + + koanfTag := koanf.UnmarshalConf{Tag: "koanf"} + k.UnmarshalWithConf("Server", &Conf.Server, koanfTag) + k.UnmarshalWithConf("OpenId", &Conf.OpenId, koanfTag) + k.UnmarshalWithConf("Caps", &Conf.Caps, koanfTag) + k.UnmarshalWithConf("Security", &Conf.Security, koanfTag) + k.UnmarshalWithConf("Client", &Conf.Client, koanfTag) + k.UnmarshalWithConf("Kerberos", &Conf.Kerberos, koanfTag) + + if len(Conf.Security.PAATokenEncryptionKey) != 32 { + Conf.Security.PAATokenEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.paatokenencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Security.PAATokenSigningKey) != 32 { + Conf.Security.PAATokenSigningKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.paatokensigningkey` specified (empty or not 32 characters). Setting to random") + } + + if Conf.Security.EnableUserToken { + if len(Conf.Security.UserTokenEncryptionKey) != 32 { + Conf.Security.UserTokenEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokenencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Security.UserTokenSigningKey) != 32 { + Conf.Security.UserTokenSigningKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `security.usertokensigningkey` specified (empty or not 32 characters). Setting to random") + } + } + + if len(Conf.Server.SessionKey) != 32 { + Conf.Server.SessionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `server.sessionkey` specified (empty or not 32 characters). Setting to random") + } + + if len(Conf.Server.SessionEncryptionKey) != 32 { + Conf.Server.SessionEncryptionKey, _ = security.GenerateRandomString(32) + log.Printf("No valid `server.sessionencryptionkey` specified (empty or not 32 characters). Setting to random") + } + + if Conf.Server.HostSelection == "signed" && len(Conf.Security.QueryTokenSigningKey) == 0 { + log.Fatalf("host selection is set to `signed` but `querytokensigningkey` is not set") + } + + if Conf.Server.BasicAuthEnabled() && Conf.Server.Tls == "disable" { + log.Fatalf("basicauth=local and tls=disable are mutually exclusive") + } + + if Conf.Server.NtlmEnabled() && Conf.Server.KerberosEnabled() { + log.Fatalf("ntlm and kerberos authentication are not stackable") + } + + if !Conf.Caps.TokenAuth && Conf.Server.OpenIDEnabled() { + log.Fatalf("openid is configured but tokenauth disabled") + } + + if Conf.Server.KerberosEnabled() && Conf.Kerberos.Keytab == "" { + log.Fatalf("kerberos is configured but no keytab was specified") + } + + // prepend '//' if required for URL parsing + if !strings.Contains(Conf.Server.GatewayAddress, "//") { + Conf.Server.GatewayAddress = "//" + Conf.Server.GatewayAddress + } + + return Conf + +} diff --git a/cmd/rdpgw/protocol/gateway.go b/cmd/rdpgw/protocol/gateway.go index 51fae1a..a52dcc1 100644 --- a/cmd/rdpgw/protocol/gateway.go +++ b/cmd/rdpgw/protocol/gateway.go @@ -2,17 +2,13 @@ package protocol import ( "context" - "errors" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/identity" "github.com/bolkedebruin/rdpgw/cmd/rdpgw/transport" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/patrickmn/go-cache" "log" - "net" "net/http" - "reflect" - "syscall" "time" ) @@ -100,65 +96,6 @@ func (g *Gateway) HandleGatewayProtocol(w http.ResponseWriter, r *http.Request) } } -func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { - if g.SendBuf < 1 && g.ReceiveBuf < 1 { - return nil - } - - // conn == tls.Tunnel - ptr := reflect.ValueOf(conn) - val := reflect.Indirect(ptr) - - if val.Kind() != reflect.Struct { - return errors.New("didn't get a struct from conn") - } - - // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn - ptrConn := val.FieldByName("conn") - valConn := reflect.Indirect(ptrConn) - if !valConn.IsValid() { - return errors.New("cannot find conn field") - } - valConn = valConn.Elem().Elem() - - // net.FD - ptrNetFd := valConn.FieldByName("fd") - valNetFd := reflect.Indirect(ptrNetFd) - if !valNetFd.IsValid() { - return errors.New("cannot find fd field") - } - - // pfd member - ptrPfd := valNetFd.FieldByName("pfd") - valPfd := reflect.Indirect(ptrPfd) - if !valPfd.IsValid() { - return errors.New("cannot find pfd field") - } - - // finally the exported Sysfd - ptrSysFd := valPfd.FieldByName("Sysfd") - if !ptrSysFd.IsValid() { - return errors.New("cannot find Sysfd field") - } - fd := int(ptrSysFd.Int()) - - if g.ReceiveBuf > 0 { - err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) - if err != nil { - return wrapSyscallError("setsockopt", err) - } - } - - if g.SendBuf > 0 { - err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf) - if err != nil { - return wrapSyscallError("setsockopt", err) - } - } - - return nil -} - func (g *Gateway) handleWebsocketProtocol(ctx context.Context, c *websocket.Conn, t *Tunnel) { websocketConnections.Inc() defer websocketConnections.Dec() diff --git a/cmd/rdpgw/protocol/gateway_unix.go b/cmd/rdpgw/protocol/gateway_unix.go new file mode 100644 index 0000000..b8a5b44 --- /dev/null +++ b/cmd/rdpgw/protocol/gateway_unix.go @@ -0,0 +1,69 @@ +// +build !windows + +package protocol + +import ( + "errors" + "net" + "reflect" + "syscall" +) + +func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { + if g.SendBuf < 1 && g.ReceiveBuf < 1 { + return nil + } + + // conn == tls.Tunnel + ptr := reflect.ValueOf(conn) + val := reflect.Indirect(ptr) + + if val.Kind() != reflect.Struct { + return errors.New("didn't get a struct from conn") + } + + // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn + ptrConn := val.FieldByName("conn") + valConn := reflect.Indirect(ptrConn) + if !valConn.IsValid() { + return errors.New("cannot find conn field") + } + valConn = valConn.Elem().Elem() + + // net.FD + ptrNetFd := valConn.FieldByName("fd") + valNetFd := reflect.Indirect(ptrNetFd) + if !valNetFd.IsValid() { + return errors.New("cannot find fd field") + } + + // pfd member + ptrPfd := valNetFd.FieldByName("pfd") + valPfd := reflect.Indirect(ptrPfd) + if !valPfd.IsValid() { + return errors.New("cannot find pfd field") + } + + // finally the exported Sysfd + ptrSysFd := valPfd.FieldByName("Sysfd") + if !ptrSysFd.IsValid() { + return errors.New("cannot find Sysfd field") + } + fd := int(ptrSysFd.Int()) + + if g.ReceiveBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + if g.SendBuf > 0 { + err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + return nil +} diff --git a/cmd/rdpgw/protocol/gateway_windows.go b/cmd/rdpgw/protocol/gateway_windows.go new file mode 100644 index 0000000..eed61f7 --- /dev/null +++ b/cmd/rdpgw/protocol/gateway_windows.go @@ -0,0 +1,69 @@ +// +build windows + +package protocol + +import ( + "errors" + "net" + "reflect" + "syscall" +) + +func (g *Gateway) setSendReceiveBuffers(conn net.Conn) error { + if g.SendBuf < 1 && g.ReceiveBuf < 1 { + return nil + } + + // conn == tls.Tunnel + ptr := reflect.ValueOf(conn) + val := reflect.Indirect(ptr) + + if val.Kind() != reflect.Struct { + return errors.New("didn't get a struct from conn") + } + + // this gets net.Tunnel -> *net.TCPConn -> net.TCPConn + ptrConn := val.FieldByName("conn") + valConn := reflect.Indirect(ptrConn) + if !valConn.IsValid() { + return errors.New("cannot find conn field") + } + valConn = valConn.Elem().Elem() + + // net.FD + ptrNetFd := valConn.FieldByName("fd") + valNetFd := reflect.Indirect(ptrNetFd) + if !valNetFd.IsValid() { + return errors.New("cannot find fd field") + } + + // pfd member + ptrPfd := valNetFd.FieldByName("pfd") + valPfd := reflect.Indirect(ptrPfd) + if !valPfd.IsValid() { + return errors.New("cannot find pfd field") + } + + // finally the exported Sysfd + ptrSysFd := valPfd.FieldByName("Sysfd") + if !ptrSysFd.IsValid() { + return errors.New("cannot find Sysfd field") + } + fd := ptrSysFd.Uint() + + if g.ReceiveBuf > 0 { + err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUF, g.ReceiveBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + if g.SendBuf > 0 { + err := syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUF, g.SendBuf) + if err != nil { + return wrapSyscallError("setsockopt", err) + } + } + + return nil +} diff --git a/cmd/rdpgw/web/basic.go b/cmd/rdpgw/web/basic.go index b852634..58d2a63 100644 --- a/cmd/rdpgw/web/basic.go +++ b/cmd/rdpgw/web/basic.go @@ -12,10 +12,6 @@ import ( "time" ) -const ( - protocolGrpc = "unix" -) - type BasicAuthHandler struct { SocketAddress string Timeout int diff --git a/cmd/rdpgw/web/basic_unix.go b/cmd/rdpgw/web/basic_unix.go new file mode 100644 index 0000000..947f64e --- /dev/null +++ b/cmd/rdpgw/web/basic_unix.go @@ -0,0 +1,7 @@ +// +build !windows + +package web + +const ( + protocolGrpc = "unix" +) diff --git a/cmd/rdpgw/web/basic_windows.go b/cmd/rdpgw/web/basic_windows.go new file mode 100644 index 0000000..eab3474 --- /dev/null +++ b/cmd/rdpgw/web/basic_windows.go @@ -0,0 +1,7 @@ +// +build windows + +package web + +const ( + protocolGrpc = "tcp" +)