Skip to content

Commit 64354e3

Browse files
committed
mockssh: expose default command handler for reuse, remove RemoteDir and RemoteEnv
1 parent e3e250a commit 64354e3

File tree

2 files changed

+133
-23
lines changed

2 files changed

+133
-23
lines changed

pkg/mockssh/server.go

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"io"
1313
"net"
1414
"net/http"
15-
"os"
1615
"os/exec"
1716
"sync"
1817
"testing"
@@ -31,9 +30,9 @@ type Server struct {
3130
CertAuthorityKeys []ssh.PublicKey
3231
CertChecker ssh.CertChecker
3332

34-
// RemoteEnv, RemoteDir and CommandHandler are optional configuration.
35-
RemoteEnv []string
36-
RemoteDir string
33+
// An optional CommandHandler, which responds to commands sent over SSH.
34+
// NewServer will give this a default using ExecHandler, which can also
35+
// be reused from custom handlers.
3736
CommandHandler CommandHandler
3837

3938
// listener and port are set after Start.
@@ -47,7 +46,7 @@ type CommandIO struct {
4746
StdErr io.Writer
4847
}
4948

50-
type CommandHandler func(conn ssh.ConnMetadata, command string, io CommandIO) int
49+
type CommandHandler func(conn ssh.ConnMetadata, command string, commandIO CommandIO) int
5150

5251
// NewServer creates and starts a local SSH server for a test.
5352
// It must be stopped with the Server.Stop method.
@@ -65,9 +64,8 @@ func NewServer(t *testing.T, authorityEndpoint string) (*Server, error) {
6564
}
6665

6766
s := &Server{t: t, hostKey: hk}
68-
s.CommandHandler = s.defaultCommandHandler
67+
s.CommandHandler = ExecHandler("", nil)
6968
s.CertChecker = s.defaultCertChecker()
70-
s.RemoteDir = t.TempDir()
7169
s.CertAuthorityKeys = keys
7270

7371
if err := s.start(); err != nil {
@@ -89,6 +87,10 @@ func (s *Server) HostKeyConfig() string {
8987
)
9088
}
9189

90+
func (s *Server) HostKey() ssh.PublicKey {
91+
return s.hostKey.PublicKey()
92+
}
93+
9294
func (s *Server) start() error {
9395
t := s.t
9496

@@ -148,22 +150,25 @@ func (s *Server) Stop() error {
148150
return nil
149151
}
150152

151-
func (s *Server) defaultCommandHandler(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
152-
c := exec.Command("bash", "-c", command)
153-
c.Stdout = commandIO.StdOut
154-
c.Stderr = commandIO.StdErr
155-
c.Stdin = commandIO.StdIn
156-
c.Dir = s.RemoteDir
157-
c.Env = append(os.Environ(), s.RemoteEnv...)
158-
if err := c.Run(); err != nil {
159-
exitErr := &exec.ExitError{}
160-
if errors.As(err, &exitErr) {
161-
return exitErr.ExitCode()
153+
// ExecHandler returns a CommandHandler to execute a command in the given environment.
154+
func ExecHandler(workingDir string, env []string) CommandHandler {
155+
return func(_ ssh.ConnMetadata, command string, commandIO CommandIO) int {
156+
c := exec.Command("bash", "-c", command)
157+
c.Stdout = commandIO.StdOut
158+
c.Stderr = commandIO.StdErr
159+
c.Stdin = commandIO.StdIn
160+
c.Dir = workingDir
161+
c.Env = env
162+
if err := c.Run(); err != nil {
163+
exitErr := &exec.ExitError{}
164+
if errors.As(err, &exitErr) {
165+
return exitErr.ExitCode()
166+
}
167+
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
168+
return 1
162169
}
163-
_, _ = fmt.Fprintf(commandIO.StdErr, "Failed to execute command: %v", err)
164-
return 1
170+
return 0
165171
}
166-
return 0
167172
}
168173

169174
func (s *Server) defaultCertChecker() ssh.CertChecker {
@@ -253,9 +258,9 @@ func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChann
253258
for {
254259
select {
255260
case s := <-exitWithStatus:
256-
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status int }{s}))
261+
_, err = channel.SendRequest("exit-status", false, ssh.Marshal(struct{ Status uint32 }{uint32(s)})) //nolint: gosec
257262
if err != nil {
258-
t.Errorf("Failed to send exit status: %v", err)
263+
t.Fatalf("Failed to send exit status: %v", err)
259264
}
260265
goto closeChannel
261266
case <-timer.C:

pkg/mockssh/server_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
package mockssh_test
2+
3+
import (
4+
"bytes"
5+
"crypto/ed25519"
6+
"crypto/rand"
7+
"encoding/json"
8+
"fmt"
9+
"net"
10+
"net/http"
11+
"strings"
12+
"testing"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
"golang.org/x/crypto/ssh"
17+
18+
"github.com/platformsh/cli/pkg/mockapi"
19+
"github.com/platformsh/cli/pkg/mockssh"
20+
)
21+
22+
func TestServer(t *testing.T) {
23+
authServer := mockapi.NewAuthServer(t)
24+
defer authServer.Close()
25+
26+
sshServer, err := mockssh.NewServer(t, authServer.URL+"/ssh/authority")
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
t.Cleanup(func() {
31+
_ = sshServer.Stop()
32+
})
33+
34+
tempDir := t.TempDir()
35+
sshServer.CommandHandler = mockssh.ExecHandler(tempDir, []string{})
36+
37+
cert := getTestSSHAuth(t, authServer.URL)
38+
39+
// Create the SSH client configuration
40+
address := fmt.Sprintf("127.0.0.1:%d", sshServer.Port())
41+
config := &ssh.ClientConfig{
42+
User: "test",
43+
Auth: []ssh.AuthMethod{ssh.PublicKeys(cert)},
44+
HostKeyCallback: func(_ string, remote net.Addr, key ssh.PublicKey) error {
45+
if remote.String() != address {
46+
return fmt.Errorf("unexpected address: %s", remote.String())
47+
}
48+
if bytes.Equal(sshServer.HostKey().Marshal(), key.Marshal()) {
49+
return nil
50+
}
51+
return fmt.Errorf("host key mismatch")
52+
},
53+
}
54+
55+
client, err := ssh.Dial("tcp", address, config)
56+
require.NoError(t, err)
57+
defer client.Close()
58+
59+
session, err := client.NewSession()
60+
require.NoError(t, err)
61+
defer session.Close()
62+
63+
stdOutBuffer := &bytes.Buffer{}
64+
session.Stdout = stdOutBuffer
65+
66+
require.NoError(t, session.Run("pwd"))
67+
assert.Equal(t, tempDir, strings.TrimRight(stdOutBuffer.String(), "\n"))
68+
69+
session2, err := client.NewSession()
70+
require.NoError(t, err)
71+
defer session2.Close()
72+
err = session2.Run("false")
73+
assert.Error(t, err)
74+
var exitErr *ssh.ExitError
75+
assert.ErrorAs(t, err, &exitErr)
76+
assert.Equal(t, 1, exitErr.ExitStatus())
77+
}
78+
79+
func getTestSSHAuth(t *testing.T, authServerURL string) ssh.Signer {
80+
t.Helper()
81+
82+
// Generate a keypair
83+
_, priv, err := ed25519.GenerateKey(rand.Reader)
84+
require.NoError(t, err)
85+
s, err := ssh.NewSignerFromKey(priv)
86+
require.NoError(t, err)
87+
88+
b, err := json.Marshal(struct{ Key string }{string(ssh.MarshalAuthorizedKey(s.PublicKey()))})
89+
require.NoError(t, err)
90+
resp, err := http.DefaultClient.Post(authServerURL+"/ssh", "application/json", bytes.NewReader(b))
91+
require.NoError(t, err)
92+
defer resp.Body.Close()
93+
94+
var rs struct{ Certificate string }
95+
require.NoError(t, json.NewDecoder(resp.Body).Decode(&rs))
96+
97+
parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(rs.Certificate)) //nolint: dogsled
98+
require.NoError(t, err)
99+
100+
cert, _ := parsed.(*ssh.Certificate)
101+
certSigner, err := ssh.NewCertSigner(cert, s)
102+
require.NoError(t, err)
103+
104+
return certSigner
105+
}

0 commit comments

Comments
 (0)