diff --git a/client/recv.go b/client/recv.go index cd66ac4..655c961 100644 --- a/client/recv.go +++ b/client/recv.go @@ -1,6 +1,7 @@ package client import ( + "crypto/tls" "fmt" "net" "os" @@ -27,6 +28,7 @@ func (svc *Service) recvFile(id string, filePath string) error { if err != nil { return err } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) defer conn.Close() msg.WriteMsg(conn, &msg.ReceiveFile{ @@ -123,6 +125,7 @@ func newRecvStream(recv *receiver.Receiver, id string, addr string, debugMode bo log(debugMode, "[%s] %v", addr, err) return } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) msg.WriteMsg(conn, &msg.NewReceiveFileStream{ ID: id, diff --git a/client/send.go b/client/send.go index 99ad9ad..33a1e5c 100644 --- a/client/send.go +++ b/client/send.go @@ -1,6 +1,7 @@ package client import ( + "crypto/tls" "fmt" "net" "os" @@ -20,6 +21,7 @@ func (svc *Service) sendFile(id string, filePath string) error { if err != nil { return err } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) defer conn.Close() f, err := os.Open(filePath) @@ -108,6 +110,7 @@ func newSendStream(s *sender.Sender, id string, addr string, debugMode bool) { log(debugMode, "[%s] %v", addr, err) return } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) msg.WriteMsg(conn, &msg.NewSendFileStream{ ID: id, diff --git a/server/service.go b/server/service.go index eb898e8..cd6aca5 100644 --- a/server/service.go +++ b/server/service.go @@ -1,7 +1,13 @@ package server import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" + "math/big" "net" "time" @@ -28,6 +34,8 @@ type Service struct { l net.Listener workerGroup *WorkerGroup matchController *MatchController + + tlsConfig *tls.Config } func NewService(options Options) (*Service, error) { @@ -51,6 +59,7 @@ func NewService(options Options) (*Service, error) { l: l, workerGroup: NewWorkerGroup(), matchController: NewMatchController(), + tlsConfig: generateTLSConfig(), }, nil } @@ -69,6 +78,7 @@ func (svc *Service) Run() error { if err != nil { return err } + conn = tls.Server(conn, svc.tlsConfig) go svc.handleConn(conn) } @@ -175,3 +185,24 @@ func (svc *Service) handleRecvFile(conn net.Conn, m *msg.ReceiveFile) error { }) return nil } + +// Setup a bare-bones TLS config for the server +func generateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{tlsCert}} +} diff --git a/server/worker.go b/server/worker.go index 683cdfc..ded9f69 100644 --- a/server/worker.go +++ b/server/worker.go @@ -1,6 +1,7 @@ package server import ( + "crypto/tls" "errors" "fmt" "net" @@ -52,6 +53,8 @@ func (w *Worker) DetectPublicAddr() error { log.Warn("dial worker public address error: %v", err) return ErrPublicAddr } + detectConn = tls.Client(detectConn, &tls.Config{InsecureSkipVerify: true}) + defer detectConn.Close() msg.WriteMsg(detectConn, &msg.Ping{}) @@ -64,7 +67,6 @@ func (w *Worker) DetectPublicAddr() error { if _, ok := m.(*msg.Pong); !ok { return ErrPublicAddr } - detectConn.Close() w.publicAddr = detectAddr return nil diff --git a/worker/register.go b/worker/register.go index b4bf217..a97090c 100644 --- a/worker/register.go +++ b/worker/register.go @@ -1,6 +1,7 @@ package worker import ( + "crypto/tls" "fmt" "net" "time" @@ -72,9 +73,11 @@ func (r *Register) RunKeepAlive(conn net.Conn) error { time.Sleep(10 * time.Second) continue } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) err = r.Register(conn) if err != nil { + conn.Close() time.Sleep(10 * time.Second) continue } diff --git a/worker/service.go b/worker/service.go index d4f4340..90e57f2 100644 --- a/worker/service.go +++ b/worker/service.go @@ -1,7 +1,13 @@ package worker import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" "fmt" + "math/big" "net" "strconv" "time" @@ -31,8 +37,9 @@ type Service struct { serverAddr string advicePublicIP string - l net.Listener - matchCtl *MatchController + l net.Listener + matchCtl *MatchController + tlsConfig *tls.Config } func NewService(options Options) (*Service, error) { @@ -56,8 +63,9 @@ func NewService(options Options) (*Service, error) { serverAddr: options.ServerAddr, advicePublicIP: options.AdvicePublicIP, - l: l, - matchCtl: NewMatchController(), + l: l, + matchCtl: NewMatchController(), + tlsConfig: generateTLSConfig(), }, nil } @@ -69,6 +77,7 @@ func (svc *Service) Run() error { if err != nil { return err } + conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true}) _, portStr, err := net.SplitHostPort(svc.l.Addr().String()) if err != nil { @@ -96,6 +105,7 @@ func (svc *Service) worker() error { if err != nil { return err } + conn = tls.Server(conn, svc.tlsConfig) go svc.handleConn(conn) } } @@ -142,3 +152,24 @@ func (svc *Service) handleConn(conn net.Conn) { return } } + +// Setup a bare-bones TLS config for the server +func generateTLSConfig() *tls.Config { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + template := x509.Certificate{SerialNumber: big.NewInt(1)} + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + panic(err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + panic(err) + } + return &tls.Config{Certificates: []tls.Certificate{tlsCert}} +}