@@ -12,7 +12,6 @@ import (
12
12
"io"
13
13
"net"
14
14
"net/http"
15
- "os"
16
15
"os/exec"
17
16
"sync"
18
17
"testing"
@@ -31,9 +30,9 @@ type Server struct {
31
30
CertAuthorityKeys []ssh.PublicKey
32
31
CertChecker ssh.CertChecker
33
32
34
- // RemoteEnv, RemoteDir and CommandHandler are optional configuration .
35
- RemoteEnv [] string
36
- RemoteDir string
33
+ // An optional CommandHandler, which responds to commands sent over SSH .
34
+ // NewServer will give this a default using ExecHandler, which can also
35
+ // be reused from custom handlers.
37
36
CommandHandler CommandHandler
38
37
39
38
// listener and port are set after Start.
@@ -47,7 +46,7 @@ type CommandIO struct {
47
46
StdErr io.Writer
48
47
}
49
48
50
- type CommandHandler func (conn ssh.ConnMetadata , command string , io CommandIO ) int
49
+ type CommandHandler func (conn ssh.ConnMetadata , command string , commandIO CommandIO ) int
51
50
52
51
// NewServer creates and starts a local SSH server for a test.
53
52
// It must be stopped with the Server.Stop method.
@@ -65,9 +64,8 @@ func NewServer(t *testing.T, authorityEndpoint string) (*Server, error) {
65
64
}
66
65
67
66
s := & Server {t : t , hostKey : hk }
68
- s .CommandHandler = s . defaultCommandHandler
67
+ s .CommandHandler = ExecHandler ( "" , nil )
69
68
s .CertChecker = s .defaultCertChecker ()
70
- s .RemoteDir = t .TempDir ()
71
69
s .CertAuthorityKeys = keys
72
70
73
71
if err := s .start (); err != nil {
@@ -89,6 +87,10 @@ func (s *Server) HostKeyConfig() string {
89
87
)
90
88
}
91
89
90
+ func (s * Server ) HostKey () ssh.PublicKey {
91
+ return s .hostKey .PublicKey ()
92
+ }
93
+
92
94
func (s * Server ) start () error {
93
95
t := s .t
94
96
@@ -148,22 +150,25 @@ func (s *Server) Stop() error {
148
150
return nil
149
151
}
150
152
151
- func (s * Server ) defaultCommandHandler (_ ssh.ConnMetadata , command string , commandIO CommandIO ) int {
152
- c := exec .Command ("bash" , "-c" , command )
153
- c .Stdout = commandIO .StdOut
154
- c .Stderr = commandIO .StdErr
155
- c .Stdin = commandIO .StdIn
156
- c .Dir = s .RemoteDir
157
- c .Env = append (os .Environ (), s .RemoteEnv ... )
158
- if err := c .Run (); err != nil {
159
- exitErr := & exec.ExitError {}
160
- if errors .As (err , & exitErr ) {
161
- return exitErr .ExitCode ()
153
+ // ExecHandler returns a CommandHandler to execute a command in the given environment.
154
+ func ExecHandler (workingDir string , env []string ) CommandHandler {
155
+ return func (_ ssh.ConnMetadata , command string , commandIO CommandIO ) int {
156
+ c := exec .Command ("bash" , "-c" , command )
157
+ c .Stdout = commandIO .StdOut
158
+ c .Stderr = commandIO .StdErr
159
+ c .Stdin = commandIO .StdIn
160
+ c .Dir = workingDir
161
+ c .Env = env
162
+ if err := c .Run (); err != nil {
163
+ exitErr := & exec.ExitError {}
164
+ if errors .As (err , & exitErr ) {
165
+ return exitErr .ExitCode ()
166
+ }
167
+ _ , _ = fmt .Fprintf (commandIO .StdErr , "Failed to execute command: %v" , err )
168
+ return 1
162
169
}
163
- _ , _ = fmt .Fprintf (commandIO .StdErr , "Failed to execute command: %v" , err )
164
- return 1
170
+ return 0
165
171
}
166
- return 0
167
172
}
168
173
169
174
func (s * Server ) defaultCertChecker () ssh.CertChecker {
@@ -253,9 +258,9 @@ func (s *Server) handleChannels(conn ssh.ConnMetadata, chans <-chan ssh.NewChann
253
258
for {
254
259
select {
255
260
case s := <- exitWithStatus :
256
- _ , err = channel .SendRequest ("exit-status" , false , ssh .Marshal (struct { Status int }{s }))
261
+ _ , err = channel .SendRequest ("exit-status" , false , ssh .Marshal (struct { Status uint32 }{uint32 ( s ) })) //nolint: gosec
257
262
if err != nil {
258
- t .Errorf ("Failed to send exit status: %v" , err )
263
+ t .Fatalf ("Failed to send exit status: %v" , err )
259
264
}
260
265
goto closeChannel
261
266
case <- timer .C :
0 commit comments