diff --git a/main.go b/main.go index 9c77a60..7e07fe3 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "crypto/tls" "flag" "fmt" "io/ioutil" @@ -32,6 +33,8 @@ var ( logLevel = flag.String("L", "info", fmt.Sprintf("Log level. One of %v", getLogLevels())) flagSlackDebug = flag.Bool("D", false, "Enable debug logging of the Slack API") flagPagination = flag.Int("P", 0, "Pagination value for API calls. If 0 or unspecified, use the recommended default (currently 200). Larger values can help on large Slack teams") + flagKey = flag.String("key", "", "TLS key for HTTPS server. Requires -cert") + flagCert = flag.String("cert", "", "TLS certificate for HTTPS server. Requires -key") ) var log = logger.GetLogger("main") @@ -81,6 +84,21 @@ func main() { log.Fatalf("Missing or invalid download directory: %s", *fileDownloadLocation) } } + doTLS := false + if *flagKey != "" && *flagCert != "" { + doTLS = true + } + var tlsConfig *tls.Config + if doTLS { + if *flagKey == "" || *flagCert == "" { + log.Fatalf("-key and -cert must be specified together") + } + cert, err := tls.LoadX509KeyPair(*flagCert, *flagKey) + if err != nil { + log.Fatalf("Failed to load TLS key/cert: %v", err) + } + tlsConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + } server := Server{ LocalAddr: &localAddr, Name: sName, @@ -89,6 +107,7 @@ func main() { FileProxyPrefix: *fileProxyPrefix, SlackDebug: *flagSlackDebug, Pagination: *flagPagination, + TLSConfig: tlsConfig, } if err := server.Start(); err != nil { log.Fatal(err) diff --git a/server.go b/server.go index c8c9c16..ce6d809 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "crypto/tls" "fmt" "io" "net" @@ -14,22 +15,27 @@ import ( type Server struct { Name string LocalAddr net.Addr - Listener *net.TCPListener + Listener net.Listener SlackAPIKey string SlackDebug bool ChunkSize int FileDownloadLocation string FileProxyPrefix string Pagination int + TLSConfig *tls.Config } // Start runs the IRC server func (s Server) Start() error { - listener, err := net.Listen("tcp", s.LocalAddr.String()) + var err error + if s.TLSConfig != nil { + s.Listener, err = tls.Listen("tcp", s.LocalAddr.String(), s.TLSConfig) + } else { + s.Listener, err = net.Listen("tcp", s.LocalAddr.String()) + } if err != nil { return err } - s.Listener = listener.(*net.TCPListener) defer s.Listener.Close() log.Infof("Listening on %v", s.LocalAddr) for {