Skip to content

Commit

Permalink
support tls connection
Browse files Browse the repository at this point in the history
  • Loading branch information
fatedier committed Mar 18, 2019
1 parent f3a8ff0 commit f41d00b
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 5 deletions.
3 changes: 3 additions & 0 deletions client/recv.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"crypto/tls"
"fmt"
"net"
"os"
Expand All @@ -27,6 +28,7 @@ func (svc *Service) recvFile(id string, filePath string) error {
if err != nil {
return err
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
defer conn.Close()

msg.WriteMsg(conn, &msg.ReceiveFile{
Expand Down Expand Up @@ -123,6 +125,7 @@ func newRecvStream(recv *receiver.Receiver, id string, addr string, debugMode bo
log(debugMode, "[%s] %v", addr, err)
return
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

msg.WriteMsg(conn, &msg.NewReceiveFileStream{
ID: id,
Expand Down
3 changes: 3 additions & 0 deletions client/send.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"crypto/tls"
"fmt"
"net"
"os"
Expand All @@ -20,6 +21,7 @@ func (svc *Service) sendFile(id string, filePath string) error {
if err != nil {
return err
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})
defer conn.Close()

f, err := os.Open(filePath)
Expand Down Expand Up @@ -108,6 +110,7 @@ func newSendStream(s *sender.Sender, id string, addr string, debugMode bool) {
log(debugMode, "[%s] %v", addr, err)
return
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

msg.WriteMsg(conn, &msg.NewSendFileStream{
ID: id,
Expand Down
31 changes: 31 additions & 0 deletions server/service.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package server

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"net"
"time"

Expand All @@ -28,6 +34,8 @@ type Service struct {
l net.Listener
workerGroup *WorkerGroup
matchController *MatchController

tlsConfig *tls.Config
}

func NewService(options Options) (*Service, error) {
Expand All @@ -51,6 +59,7 @@ func NewService(options Options) (*Service, error) {
l: l,
workerGroup: NewWorkerGroup(),
matchController: NewMatchController(),
tlsConfig: generateTLSConfig(),
}, nil
}

Expand All @@ -69,6 +78,7 @@ func (svc *Service) Run() error {
if err != nil {
return err
}
conn = tls.Server(conn, svc.tlsConfig)

go svc.handleConn(conn)
}
Expand Down Expand Up @@ -175,3 +185,24 @@ func (svc *Service) handleRecvFile(conn net.Conn, m *msg.ReceiveFile) error {
})
return nil
}

// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
}
4 changes: 3 additions & 1 deletion server/worker.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"crypto/tls"
"errors"
"fmt"
"net"
Expand Down Expand Up @@ -52,6 +53,8 @@ func (w *Worker) DetectPublicAddr() error {
log.Warn("dial worker public address error: %v", err)
return ErrPublicAddr
}
detectConn = tls.Client(detectConn, &tls.Config{InsecureSkipVerify: true})
defer detectConn.Close()

msg.WriteMsg(detectConn, &msg.Ping{})

Expand All @@ -64,7 +67,6 @@ func (w *Worker) DetectPublicAddr() error {
if _, ok := m.(*msg.Pong); !ok {
return ErrPublicAddr
}
detectConn.Close()

w.publicAddr = detectAddr
return nil
Expand Down
3 changes: 3 additions & 0 deletions worker/register.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package worker

import (
"crypto/tls"
"fmt"
"net"
"time"
Expand Down Expand Up @@ -72,9 +73,11 @@ func (r *Register) RunKeepAlive(conn net.Conn) error {
time.Sleep(10 * time.Second)
continue
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

err = r.Register(conn)
if err != nil {
conn.Close()
time.Sleep(10 * time.Second)
continue
}
Expand Down
39 changes: 35 additions & 4 deletions worker/service.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
package worker

import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"net"
"strconv"
"time"
Expand Down Expand Up @@ -31,8 +37,9 @@ type Service struct {
serverAddr string
advicePublicIP string

l net.Listener
matchCtl *MatchController
l net.Listener
matchCtl *MatchController
tlsConfig *tls.Config
}

func NewService(options Options) (*Service, error) {
Expand All @@ -56,8 +63,9 @@ func NewService(options Options) (*Service, error) {
serverAddr: options.ServerAddr,
advicePublicIP: options.AdvicePublicIP,

l: l,
matchCtl: NewMatchController(),
l: l,
matchCtl: NewMatchController(),
tlsConfig: generateTLSConfig(),
}, nil
}

Expand All @@ -69,6 +77,7 @@ func (svc *Service) Run() error {
if err != nil {
return err
}
conn = tls.Client(conn, &tls.Config{InsecureSkipVerify: true})

_, portStr, err := net.SplitHostPort(svc.l.Addr().String())
if err != nil {
Expand Down Expand Up @@ -96,6 +105,7 @@ func (svc *Service) worker() error {
if err != nil {
return err
}
conn = tls.Server(conn, svc.tlsConfig)
go svc.handleConn(conn)
}
}
Expand Down Expand Up @@ -142,3 +152,24 @@ func (svc *Service) handleConn(conn net.Conn) {
return
}
}

// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{Certificates: []tls.Certificate{tlsCert}}
}

0 comments on commit f41d00b

Please sign in to comment.