Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconnect unhealthy tunnels #112

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions tunnel/internal/client/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ func (c *Client) Start(ctx context.Context, services ...string) error {
continue
}
clientConfigs = append(clientConfigs, config.ClientConfig{
ServerUrl: c.config.ServerUrl,
SshUrl: c.config.SshUrl,
TunnelUrl: c.config.TunnelUrl,
SecretKey: c.config.SecretKey,
Tunnel: tunnel,
UseLocalHost: c.config.UseLocalHost,
Debug: c.config.Debug,
EnableRequestLogging: c.config.EnableRequestLogging,
ServerUrl: c.config.ServerUrl,
SshUrl: c.config.SshUrl,
TunnelUrl: c.config.TunnelUrl,
SecretKey: c.config.SecretKey,
Tunnel: tunnel,
UseLocalHost: c.config.UseLocalHost,
Debug: c.config.Debug,
EnableRequestLogging: c.config.EnableRequestLogging,
HealthCheckInterval: c.config.HealthCheckInterval,
HealthCheckMaxRetries: c.config.HealthCheckMaxRetries,
})
}

Expand Down
46 changes: 29 additions & 17 deletions tunnel/internal/client/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ func (t *Tunnel) GetLocalAddr() string {
}

type Config struct {
ServerUrl string `yaml:"server_url"`
SshUrl string `yaml:"ssh_url"`
TunnelUrl string `yaml:"tunnel_url"`
SecretKey string `yaml:"secret_key"`
Tunnels []Tunnel `yaml:"tunnels"`
UseLocalHost bool `yaml:"use_localhost"`
Debug bool `yaml:"debug"`
UseVite bool `yaml:"use_vite"`
EnableRequestLogging bool `yaml:"enable_request_logging"`
ServerUrl string `yaml:"server_url"`
SshUrl string `yaml:"ssh_url"`
TunnelUrl string `yaml:"tunnel_url"`
SecretKey string `yaml:"secret_key"`
Tunnels []Tunnel `yaml:"tunnels"`
UseLocalHost bool `yaml:"use_localhost"`
Debug bool `yaml:"debug"`
UseVite bool `yaml:"use_vite"`
EnableRequestLogging bool `yaml:"enable_request_logging"`
HealthCheckInterval int `yaml:"health_check_interval"`
HealthCheckMaxRetries int `yaml:"health_check_max_retries"`
}

func (c *Config) SetDefaults() {
Expand All @@ -69,6 +71,14 @@ func (c *Config) SetDefaults() {
c.TunnelUrl = c.ServerUrl
}

if c.HealthCheckInterval == 0 {
c.HealthCheckInterval = 3
}

if c.HealthCheckMaxRetries == 0 {
c.HealthCheckMaxRetries = 10
}

for i := range c.Tunnels {
c.Tunnels[i].SetDefaults()
}
Expand All @@ -84,14 +94,16 @@ func (c Config) GetAdminAddress() string {
}

type ClientConfig struct {
ServerUrl string
SshUrl string
TunnelUrl string
SecretKey string
Tunnel Tunnel
UseLocalHost bool
Debug bool
EnableRequestLogging bool
ServerUrl string
SshUrl string
TunnelUrl string
SecretKey string
Tunnel Tunnel
UseLocalHost bool
Debug bool
EnableRequestLogging bool
HealthCheckInterval int
HealthCheckMaxRetries int
}

func (c *ClientConfig) GetHttpTunnelAddr() string {
Expand Down
127 changes: 120 additions & 7 deletions tunnel/internal/client/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type SshClient struct {
listener net.Listener
log *slog.Logger
db *db.Db
client *ssh.Client
}

func New(config config.ClientConfig, db *db.Db) *SshClient {
Expand All @@ -42,6 +43,7 @@ func New(config config.ClientConfig, db *db.Db) *SshClient {
listener: nil,
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
db: db,
client: nil,
}
}

Expand Down Expand Up @@ -98,15 +100,13 @@ func (s *SshClient) startListenerForClient() error {
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

sshClient, err := ssh.Dial("tcp", s.config.SshUrl, sshConfig)

s.client, err = ssh.Dial("tcp", s.config.SshUrl, sshConfig)
if err != nil {
if s.config.Debug {
s.log.Error("failed to connect to ssh server", "error", err)
}
return err
}
defer sshClient.Close()

localEndpoint := s.config.Tunnel.GetLocalAddr() // Local address to forward to

Expand All @@ -123,7 +123,7 @@ func (s *SshClient) startListenerForClient() error {

// try to connect to 10 random ports
for _, port := range randomPorts {
s.listener, err = sshClient.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port))
s.listener, err = s.client.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port))
remotePort = port
if err == nil {
break
Expand Down Expand Up @@ -361,12 +361,125 @@ func (s *SshClient) Shutdown(ctx context.Context) error {
return nil
}

func (s *SshClient) Start(_ context.Context) {
func (s *SshClient) StartHealthCheck(ctx context.Context) {
ticker := time.Tick(time.Duration(s.config.HealthCheckInterval) * time.Second)
retryAttempts := 0

var err error

for range ticker {
retryAttempts++
if retryAttempts > s.config.HealthCheckMaxRetries {
fmt.Printf(color.Red("Failed to reconnect to tunnel after %d attempts\n"), retryAttempts)
os.Exit(1)
}

err = s.HealthCheck()
if err == nil {
retryAttempts = 0
continue
}

if s.config.Debug {
s.log.Error("health check failed", "error", err)
}

fmt.Printf(color.Yellow("Tunnel %s is not healthy 🪫, attempting to reconnect\n"), s.config.GetTunnelAddr())

err = s.Reconnect()
if err != nil {
if s.config.Debug {
s.log.Error("failed to reconnect to ssh tunnel", "error", err, "attempts", retryAttempts)
}
} else {
retryAttempts = 0
}

}
}

func (s *SshClient) Start(ctx context.Context) {
fmt.Printf("🌍 Starting tunnel connection for :%d\n", s.config.Tunnel.Port)

if err := s.startListenerForClient(); err != nil {
fmt.Println()
errChan := make(chan error, 1)

go func() {
if err := s.startListenerForClient(); err != nil {
errChan <- err
}
}()

// Wait for either an error or successful connection
select {
case err := <-errChan:
fmt.Println(color.Red(err))
os.Exit(1)
case <-time.After(5 * time.Second):
// If no error after 2 seconds, assume connection is successful
// Start the health check routine
s.StartHealthCheck(ctx)
}
}

func (s *SshClient) Reconnect() error {
if s.client != nil {
if err := s.client.Close(); err != nil {
if s.config.Debug {
s.log.Error("failed to close client", "error", err)
}
}
s.client = nil
}

if s.listener != nil {
if err := s.listener.Close(); err != nil {
if s.config.Debug {
s.log.Error("failed to close listener", "error", err)
}
}
s.listener = nil
}

// Channel to receive errors from the goroutine
errChan := make(chan error, 1)

// Start the listener in a goroutine
go func() {
if err := s.startListenerForClient(); err != nil {
errChan <- err
}
}()

// Wait for either an error or successful connection
select {
case err := <-errChan:
return err
case <-time.After(5 * time.Second):
return nil
}
}

func (s *SshClient) HealthCheck() error {
// Make HTTP request to tunnel address with special header
client := resty.New().
SetTimeout(5 * time.Second)

resp, err := client.R().
SetHeader("X-Portr-Ping-Request", "true").
Get(s.config.GetTunnelAddr())

if err != nil {
if s.config.Debug {
s.log.Error("health check failed, attempting to reconnect", "error", err)
}
return err
}

portrError := resp.Header().Get("X-Portr-Error")
portrErrorReason := resp.Header().Get("X-Portr-Error-Reason")

if portrError == "true" && portrErrorReason == "unregistered-subdomain" {
return fmt.Errorf("unhealthy tunnel")
}
return nil
}