diff --git a/connect.go b/connect.go index d55b211..cbe5ee0 100644 --- a/connect.go +++ b/connect.go @@ -11,6 +11,7 @@ import ( "net" "os" "os/signal" + "sync" "syscall" "time" @@ -53,6 +54,9 @@ type Connect struct { // Set the TTY to be used as the input and output for the Session/Cmd. PtyRelayTty *os.File + // StdoutMutex is a mutex for use Stdout. + StdoutMutex *sync.Mutex + // CheckKnownHosts if true, check knownhosts. // Ignored if HostKeyCallback is set. // Set it before CraeteClient. @@ -139,7 +143,7 @@ func (c *Connect) CreateClient(host, port, user string, authMethods []ssh.AuthMe // append default files c.KnownHostsFiles = append(c.KnownHostsFiles, "~/.ssh/known_hosts") } - config.HostKeyCallback = c.verifyAndAppendNew + config.HostKeyCallback = c.VerifyAndAppendNew } else { config.HostKeyCallback = ssh.InsecureIgnoreHostKey() } diff --git a/knownhosts.go b/knownhosts.go index cf7fff6..7973ce1 100644 --- a/knownhosts.go +++ b/knownhosts.go @@ -10,6 +10,7 @@ import ( "net" "os" "os/signal" + "slices" "strings" "syscall" "text/template" @@ -36,19 +37,19 @@ type OverwriteInventory struct { // If is no problem, error returns Nil. // // 【参考】: https://github.com/tatsushid/minssh/blob/57eae8c5bcf5d94639891f3267f05251f05face4/pkg/minssh/minssh.go#L190-L237 -func (c *Connect) verifyAndAppendNew(hostname string, remote net.Addr, key ssh.PublicKey) (err error) { +func (c *Connect) VerifyAndAppendNew(hostname string, remote net.Addr, key ssh.PublicKey) (err error) { // set TextAskWriteKnownHosts default text if len(c.TextAskWriteKnownHosts) == 0 { c.TextAskWriteKnownHosts += "The authenticity of host '{{.Address}} ({{.RemoteAddr}})' can't be established.\n" c.TextAskWriteKnownHosts += "RSA key fingerprint is {{.Fingerprint}}\n" - c.TextAskWriteKnownHosts += "Are you sure you want to continue connecting (yes/no)?" + c.TextAskWriteKnownHosts += "Are you sure you want to continue connecting ((yes|y)/(no|n))? " } // set TextAskOverwriteKnownHosts default text if len(c.TextAskOverwriteKnownHosts) == 0 { c.TextAskOverwriteKnownHosts += "The authenticity of host '{{.Address}} ({{.RemoteAddr}})' can't be established.\n" c.TextAskOverwriteKnownHosts += "Old key: {{.OldKeyText}}\n" - c.TextAskOverwriteKnownHosts += "Are you sure you want to overwrite {{.Fingerprint}}, continue connecting (yes/no)?" + c.TextAskOverwriteKnownHosts += "Are you sure you want to overwrite {{.Fingerprint}}, continue connecting ((yes|y)/(no|n))? " } // check count KnownHostsFiles @@ -79,6 +80,12 @@ func (c *Connect) verifyAndAppendNew(hostname string, remote net.Addr, key ssh.P filepath := knownHostsFiles[0] var line int + // check mutex + if c.StdoutMutex != nil { + c.StdoutMutex.Lock() + defer c.StdoutMutex.Unlock() + } + // check error keyErr, ok := err.(*knownhosts.KeyError) if !ok || len(keyErr.Want) > 0 { @@ -107,7 +114,7 @@ func (c *Connect) verifyAndAppendNew(hostname string, remote net.Addr, key ssh.P err = writeKnownHostsKey(filepath, line, hostname, remote, key) - return nil + return err } // askAddingUnknownHostKey @@ -150,9 +157,9 @@ func askAddingUnknownHostKey(text string, address string, remote net.Addr, key s return false, fmt.Errorf("failed to read answer: %s", err) } answer = string(strings.ToLower(strings.TrimSpace(answer))) - if answer == "yes" { + if slices.Contains([]string{"yes", "y"}, answer) { return true, nil - } else if answer == "no" { + } else if slices.Contains([]string{"no", "n"}, answer) { return false, nil } fmt.Print("Please type 'yes' or 'no': ") @@ -204,7 +211,7 @@ func askOverwriteKnownHostKey(text string, address string, remote net.Addr, key } else if answer == "no" { return false, nil } - fmt.Print("Please type 'yes' or 'no': ") + fmt.Print("Please type 'yes|y' or 'no|n': ") } }