diff --git a/internal/server/ssh.go b/internal/server/ssh.go index ad8620d..cc45301 100644 --- a/internal/server/ssh.go +++ b/internal/server/ssh.go @@ -123,7 +123,7 @@ func (s *SSHServer) Listen(ctx context.Context, m *movie.Movie) error { func (s *SSHServer) Handler(m *movie.Movie) bubbletea.Handler { return func(session ssh.Session) (tea.Model, []tea.ProgramOption) { - remoteIP := RemoteIP(session.RemoteAddr().String()) + remoteIP := RemoteIP(session.RemoteAddr()) logger := s.Log.With( "remoteIP", remoteIP, "user", session.User(), @@ -151,7 +151,7 @@ func (s *SSHServer) Handler(m *movie.Movie) bubbletea.Handler { func (s *SSHServer) TrackStream(handler ssh.Handler) ssh.Handler { return func(session ssh.Session) { - remoteIP := RemoteIP(session.RemoteAddr().String()) + remoteIP := RemoteIP(session.RemoteAddr()) id, err := serverInfo.StreamConnect("ssh", remoteIP) if err != nil { s.Log.Error("Failed to begin stream", diff --git a/internal/server/telnet.go b/internal/server/telnet.go index 3f36879..6e604eb 100644 --- a/internal/server/telnet.go +++ b/internal/server/telnet.go @@ -92,7 +92,7 @@ func (s *TelnetServer) Handler(ctx context.Context, conn net.Conn, m *movie.Movi _ = conn.Close() }(conn) - remoteIP := RemoteIP(conn.RemoteAddr().String()) + remoteIP := RemoteIP(conn.RemoteAddr()) logger := s.Log.With("remoteIP", remoteIP) id, err := serverInfo.StreamConnect("telnet", remoteIP) diff --git a/internal/server/util.go b/internal/server/util.go index ed6dbac..5ce9f62 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -4,10 +4,16 @@ import ( "net" ) -func RemoteIP(remoteIPPort string) string { - remoteIP, _, err := net.SplitHostPort(remoteIPPort) - if err != nil { - remoteIP = remoteIPPort +func RemoteIP(addr net.Addr) string { + switch addr := addr.(type) { + case *net.TCPAddr: + return addr.IP.String() + default: + ipPort := addr.String() + ip, _, err := net.SplitHostPort(ipPort) + if err != nil { + ip = ipPort + } + return ip } - return remoteIP } diff --git a/internal/server/util_test.go b/internal/server/util_test.go index b988c7d..43ec204 100644 --- a/internal/server/util_test.go +++ b/internal/server/util_test.go @@ -1,6 +1,7 @@ package server import ( + "net" "testing" "github.com/stretchr/testify/assert" @@ -8,19 +9,22 @@ import ( func TestRemoteIp(t *testing.T) { type args struct { - remoteIPPort string + n net.Addr } tests := []struct { name string args args want string }{ - {"127.0.0.1", args{"127.0.0.1"}, "127.0.0.1"}, - {"127.0.0.1:12345", args{"127.0.0.1:12345"}, "127.0.0.1"}, + { + "127.0.0.1:12345", + args{&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}}, + "127.0.0.1", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, RemoteIP(tt.args.remoteIPPort), "RemoteIP(%v)", tt.args.remoteIPPort) + assert.Equal(t, tt.want, RemoteIP(tt.args.n)) }) } }