Skip to content

Commit

Permalink
Reconnect unhealthy tunnels (#112)
Browse files Browse the repository at this point in the history
* Reconnect unhealthy tunnels

* Add health_check_interval and health_check_max_retries to client config

* Use Tick
  • Loading branch information
amalshaji authored Nov 15, 2024
1 parent c3188da commit b44d0d4
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 32 deletions.
18 changes: 10 additions & 8 deletions tunnel/internal/client/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ func (c *Client) Start(ctx context.Context, services ...string) error {
continue
}
clientConfigs = append(clientConfigs, config.ClientConfig{
ServerUrl: c.config.ServerUrl,
SshUrl: c.config.SshUrl,
TunnelUrl: c.config.TunnelUrl,
SecretKey: c.config.SecretKey,
Tunnel: tunnel,
UseLocalHost: c.config.UseLocalHost,
Debug: c.config.Debug,
EnableRequestLogging: c.config.EnableRequestLogging,
ServerUrl: c.config.ServerUrl,
SshUrl: c.config.SshUrl,
TunnelUrl: c.config.TunnelUrl,
SecretKey: c.config.SecretKey,
Tunnel: tunnel,
UseLocalHost: c.config.UseLocalHost,
Debug: c.config.Debug,
EnableRequestLogging: c.config.EnableRequestLogging,
HealthCheckInterval: c.config.HealthCheckInterval,
HealthCheckMaxRetries: c.config.HealthCheckMaxRetries,
})
}

Expand Down
46 changes: 29 additions & 17 deletions tunnel/internal/client/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ func (t *Tunnel) GetLocalAddr() string {
}

type Config struct {
ServerUrl string `yaml:"server_url"`
SshUrl string `yaml:"ssh_url"`
TunnelUrl string `yaml:"tunnel_url"`
SecretKey string `yaml:"secret_key"`
Tunnels []Tunnel `yaml:"tunnels"`
UseLocalHost bool `yaml:"use_localhost"`
Debug bool `yaml:"debug"`
UseVite bool `yaml:"use_vite"`
EnableRequestLogging bool `yaml:"enable_request_logging"`
ServerUrl string `yaml:"server_url"`
SshUrl string `yaml:"ssh_url"`
TunnelUrl string `yaml:"tunnel_url"`
SecretKey string `yaml:"secret_key"`
Tunnels []Tunnel `yaml:"tunnels"`
UseLocalHost bool `yaml:"use_localhost"`
Debug bool `yaml:"debug"`
UseVite bool `yaml:"use_vite"`
EnableRequestLogging bool `yaml:"enable_request_logging"`
HealthCheckInterval int `yaml:"health_check_interval"`
HealthCheckMaxRetries int `yaml:"health_check_max_retries"`
}

func (c *Config) SetDefaults() {
Expand All @@ -69,6 +71,14 @@ func (c *Config) SetDefaults() {
c.TunnelUrl = c.ServerUrl
}

if c.HealthCheckInterval == 0 {
c.HealthCheckInterval = 3
}

if c.HealthCheckMaxRetries == 0 {
c.HealthCheckMaxRetries = 10
}

for i := range c.Tunnels {
c.Tunnels[i].SetDefaults()
}
Expand All @@ -84,14 +94,16 @@ func (c Config) GetAdminAddress() string {
}

type ClientConfig struct {
ServerUrl string
SshUrl string
TunnelUrl string
SecretKey string
Tunnel Tunnel
UseLocalHost bool
Debug bool
EnableRequestLogging bool
ServerUrl string
SshUrl string
TunnelUrl string
SecretKey string
Tunnel Tunnel
UseLocalHost bool
Debug bool
EnableRequestLogging bool
HealthCheckInterval int
HealthCheckMaxRetries int
}

func (c *ClientConfig) GetHttpTunnelAddr() string {
Expand Down
127 changes: 120 additions & 7 deletions tunnel/internal/client/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type SshClient struct {
listener net.Listener
log *slog.Logger
db *db.Db
client *ssh.Client
}

func New(config config.ClientConfig, db *db.Db) *SshClient {
Expand All @@ -42,6 +43,7 @@ func New(config config.ClientConfig, db *db.Db) *SshClient {
listener: nil,
log: slog.New(slog.NewTextHandler(os.Stdout, nil)),
db: db,
client: nil,
}
}

Expand Down Expand Up @@ -98,15 +100,13 @@ func (s *SshClient) startListenerForClient() error {
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

sshClient, err := ssh.Dial("tcp", s.config.SshUrl, sshConfig)

s.client, err = ssh.Dial("tcp", s.config.SshUrl, sshConfig)
if err != nil {
if s.config.Debug {
s.log.Error("failed to connect to ssh server", "error", err)
}
return err
}
defer sshClient.Close()

localEndpoint := s.config.Tunnel.GetLocalAddr() // Local address to forward to

Expand All @@ -123,7 +123,7 @@ func (s *SshClient) startListenerForClient() error {

// try to connect to 10 random ports
for _, port := range randomPorts {
s.listener, err = sshClient.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port))
s.listener, err = s.client.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port))
remotePort = port
if err == nil {
break
Expand Down Expand Up @@ -361,12 +361,125 @@ func (s *SshClient) Shutdown(ctx context.Context) error {
return nil
}

func (s *SshClient) Start(_ context.Context) {
func (s *SshClient) StartHealthCheck(ctx context.Context) {
ticker := time.Tick(time.Duration(s.config.HealthCheckInterval) * time.Second)
retryAttempts := 0

var err error

for range ticker {
retryAttempts++
if retryAttempts > s.config.HealthCheckMaxRetries {
fmt.Printf(color.Red("Failed to reconnect to tunnel after %d attempts\n"), retryAttempts)
os.Exit(1)
}

err = s.HealthCheck()
if err == nil {
retryAttempts = 0
continue
}

if s.config.Debug {
s.log.Error("health check failed", "error", err)
}

fmt.Printf(color.Yellow("Tunnel %s is not healthy 🪫, attempting to reconnect\n"), s.config.GetTunnelAddr())

err = s.Reconnect()
if err != nil {
if s.config.Debug {
s.log.Error("failed to reconnect to ssh tunnel", "error", err, "attempts", retryAttempts)
}
} else {
retryAttempts = 0
}

}
}

func (s *SshClient) Start(ctx context.Context) {
fmt.Printf("🌍 Starting tunnel connection for :%d\n", s.config.Tunnel.Port)

if err := s.startListenerForClient(); err != nil {
fmt.Println()
errChan := make(chan error, 1)

go func() {
if err := s.startListenerForClient(); err != nil {
errChan <- err
}
}()

// Wait for either an error or successful connection
select {
case err := <-errChan:
fmt.Println(color.Red(err))
os.Exit(1)
case <-time.After(5 * time.Second):
// If no error after 2 seconds, assume connection is successful
// Start the health check routine
s.StartHealthCheck(ctx)
}
}

func (s *SshClient) Reconnect() error {
if s.client != nil {
if err := s.client.Close(); err != nil {
if s.config.Debug {
s.log.Error("failed to close client", "error", err)
}
}
s.client = nil
}

if s.listener != nil {
if err := s.listener.Close(); err != nil {
if s.config.Debug {
s.log.Error("failed to close listener", "error", err)
}
}
s.listener = nil
}

// Channel to receive errors from the goroutine
errChan := make(chan error, 1)

// Start the listener in a goroutine
go func() {
if err := s.startListenerForClient(); err != nil {
errChan <- err
}
}()

// Wait for either an error or successful connection
select {
case err := <-errChan:
return err
case <-time.After(5 * time.Second):
return nil
}
}

func (s *SshClient) HealthCheck() error {
// Make HTTP request to tunnel address with special header
client := resty.New().
SetTimeout(5 * time.Second)

resp, err := client.R().
SetHeader("X-Portr-Ping-Request", "true").
Get(s.config.GetTunnelAddr())

if err != nil {
if s.config.Debug {
s.log.Error("health check failed, attempting to reconnect", "error", err)
}
return err
}

portrError := resp.Header().Get("X-Portr-Error")
portrErrorReason := resp.Header().Get("X-Portr-Error-Reason")

if portrError == "true" && portrErrorReason == "unregistered-subdomain" {
return fmt.Errorf("unhealthy tunnel")
}
return nil
}

0 comments on commit b44d0d4

Please sign in to comment.