diff --git a/tunnel/internal/client/client/client.go b/tunnel/internal/client/client/client.go index d30690e..c820ea9 100644 --- a/tunnel/internal/client/client/client.go +++ b/tunnel/internal/client/client/client.go @@ -36,14 +36,16 @@ func (c *Client) Start(ctx context.Context, services ...string) error { continue } clientConfigs = append(clientConfigs, config.ClientConfig{ - ServerUrl: c.config.ServerUrl, - SshUrl: c.config.SshUrl, - TunnelUrl: c.config.TunnelUrl, - SecretKey: c.config.SecretKey, - Tunnel: tunnel, - UseLocalHost: c.config.UseLocalHost, - Debug: c.config.Debug, - EnableRequestLogging: c.config.EnableRequestLogging, + ServerUrl: c.config.ServerUrl, + SshUrl: c.config.SshUrl, + TunnelUrl: c.config.TunnelUrl, + SecretKey: c.config.SecretKey, + Tunnel: tunnel, + UseLocalHost: c.config.UseLocalHost, + Debug: c.config.Debug, + EnableRequestLogging: c.config.EnableRequestLogging, + HealthCheckInterval: c.config.HealthCheckInterval, + HealthCheckMaxRetries: c.config.HealthCheckMaxRetries, }) } diff --git a/tunnel/internal/client/config/config.go b/tunnel/internal/client/config/config.go index dc32d5d..694be3c 100644 --- a/tunnel/internal/client/config/config.go +++ b/tunnel/internal/client/config/config.go @@ -45,15 +45,17 @@ func (t *Tunnel) GetLocalAddr() string { } type Config struct { - ServerUrl string `yaml:"server_url"` - SshUrl string `yaml:"ssh_url"` - TunnelUrl string `yaml:"tunnel_url"` - SecretKey string `yaml:"secret_key"` - Tunnels []Tunnel `yaml:"tunnels"` - UseLocalHost bool `yaml:"use_localhost"` - Debug bool `yaml:"debug"` - UseVite bool `yaml:"use_vite"` - EnableRequestLogging bool `yaml:"enable_request_logging"` + ServerUrl string `yaml:"server_url"` + SshUrl string `yaml:"ssh_url"` + TunnelUrl string `yaml:"tunnel_url"` + SecretKey string `yaml:"secret_key"` + Tunnels []Tunnel `yaml:"tunnels"` + UseLocalHost bool `yaml:"use_localhost"` + Debug bool `yaml:"debug"` + UseVite bool `yaml:"use_vite"` + EnableRequestLogging bool `yaml:"enable_request_logging"` + HealthCheckInterval int `yaml:"health_check_interval"` + HealthCheckMaxRetries int `yaml:"health_check_max_retries"` } func (c *Config) SetDefaults() { @@ -69,6 +71,14 @@ func (c *Config) SetDefaults() { c.TunnelUrl = c.ServerUrl } + if c.HealthCheckInterval == 0 { + c.HealthCheckInterval = 3 + } + + if c.HealthCheckMaxRetries == 0 { + c.HealthCheckMaxRetries = 10 + } + for i := range c.Tunnels { c.Tunnels[i].SetDefaults() } @@ -84,14 +94,16 @@ func (c Config) GetAdminAddress() string { } type ClientConfig struct { - ServerUrl string - SshUrl string - TunnelUrl string - SecretKey string - Tunnel Tunnel - UseLocalHost bool - Debug bool - EnableRequestLogging bool + ServerUrl string + SshUrl string + TunnelUrl string + SecretKey string + Tunnel Tunnel + UseLocalHost bool + Debug bool + EnableRequestLogging bool + HealthCheckInterval int + HealthCheckMaxRetries int } func (c *ClientConfig) GetHttpTunnelAddr() string { diff --git a/tunnel/internal/client/ssh/ssh.go b/tunnel/internal/client/ssh/ssh.go index c9f5d39..9320588 100644 --- a/tunnel/internal/client/ssh/ssh.go +++ b/tunnel/internal/client/ssh/ssh.go @@ -34,6 +34,7 @@ type SshClient struct { listener net.Listener log *slog.Logger db *db.Db + client *ssh.Client } func New(config config.ClientConfig, db *db.Db) *SshClient { @@ -42,6 +43,7 @@ func New(config config.ClientConfig, db *db.Db) *SshClient { listener: nil, log: slog.New(slog.NewTextHandler(os.Stdout, nil)), db: db, + client: nil, } } @@ -98,15 +100,13 @@ func (s *SshClient) startListenerForClient() error { HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - sshClient, err := ssh.Dial("tcp", s.config.SshUrl, sshConfig) - + s.client, err = ssh.Dial("tcp", s.config.SshUrl, sshConfig) if err != nil { if s.config.Debug { s.log.Error("failed to connect to ssh server", "error", err) } return err } - defer sshClient.Close() localEndpoint := s.config.Tunnel.GetLocalAddr() // Local address to forward to @@ -123,7 +123,7 @@ func (s *SshClient) startListenerForClient() error { // try to connect to 10 random ports for _, port := range randomPorts { - s.listener, err = sshClient.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port)) + s.listener, err = s.client.Listen("tcp", "0.0.0.0:"+fmt.Sprint(port)) remotePort = port if err == nil { break @@ -361,12 +361,125 @@ func (s *SshClient) Shutdown(ctx context.Context) error { return nil } -func (s *SshClient) Start(_ context.Context) { +func (s *SshClient) StartHealthCheck(ctx context.Context) { + ticker := time.Tick(time.Duration(s.config.HealthCheckInterval) * time.Second) + retryAttempts := 0 + + var err error + + for range ticker { + retryAttempts++ + if retryAttempts > s.config.HealthCheckMaxRetries { + fmt.Printf(color.Red("Failed to reconnect to tunnel after %d attempts\n"), retryAttempts) + os.Exit(1) + } + + err = s.HealthCheck() + if err == nil { + retryAttempts = 0 + continue + } + + if s.config.Debug { + s.log.Error("health check failed", "error", err) + } + + fmt.Printf(color.Yellow("Tunnel %s is not healthy 🪫, attempting to reconnect\n"), s.config.GetTunnelAddr()) + + err = s.Reconnect() + if err != nil { + if s.config.Debug { + s.log.Error("failed to reconnect to ssh tunnel", "error", err, "attempts", retryAttempts) + } + } else { + retryAttempts = 0 + } + + } +} + +func (s *SshClient) Start(ctx context.Context) { fmt.Printf("🌍 Starting tunnel connection for :%d\n", s.config.Tunnel.Port) - if err := s.startListenerForClient(); err != nil { - fmt.Println() + errChan := make(chan error, 1) + + go func() { + if err := s.startListenerForClient(); err != nil { + errChan <- err + } + }() + + // Wait for either an error or successful connection + select { + case err := <-errChan: fmt.Println(color.Red(err)) os.Exit(1) + case <-time.After(5 * time.Second): + // If no error after 2 seconds, assume connection is successful + // Start the health check routine + s.StartHealthCheck(ctx) } } + +func (s *SshClient) Reconnect() error { + if s.client != nil { + if err := s.client.Close(); err != nil { + if s.config.Debug { + s.log.Error("failed to close client", "error", err) + } + } + s.client = nil + } + + if s.listener != nil { + if err := s.listener.Close(); err != nil { + if s.config.Debug { + s.log.Error("failed to close listener", "error", err) + } + } + s.listener = nil + } + + // Channel to receive errors from the goroutine + errChan := make(chan error, 1) + + // Start the listener in a goroutine + go func() { + if err := s.startListenerForClient(); err != nil { + errChan <- err + } + }() + + // Wait for either an error or successful connection + select { + case err := <-errChan: + return err + case <-time.After(5 * time.Second): + return nil + } +} + +func (s *SshClient) HealthCheck() error { + // Make HTTP request to tunnel address with special header + client := resty.New(). + SetTimeout(5 * time.Second) + + resp, err := client.R(). + SetHeader("X-Portr-Ping-Request", "true"). + Get(s.config.GetTunnelAddr()) + + if err != nil { + if s.config.Debug { + s.log.Error("health check failed, attempting to reconnect", "error", err) + } + return err + } + + portrError := resp.Header().Get("X-Portr-Error") + portrErrorReason := resp.Header().Get("X-Portr-Error-Reason") + + if portrError == "true" && portrErrorReason == "unregistered-subdomain" { + return fmt.Errorf("unhealthy tunnel") + } + return nil +}