Skip to content

Commit

Permalink
Add SessionPolicyCallback (#80)
Browse files Browse the repository at this point in the history
* Add SessionPolicyCallback

Closes #7

* Update docs related to the embedded sync.Locker in the Context

* Fix mutex in context
  • Loading branch information
belak authored and progrium committed Feb 23, 2019
1 parent 4b72c66 commit e5ece14
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 22 deletions.
8 changes: 6 additions & 2 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/hex"
"net"
"sync"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -59,9 +60,11 @@ var (
// Context is a package specific context interface. It exposes connection
// metadata and allows new values to be easily written to it. It's used in
// authentication handlers and callbacks, and its underlying context.Context is
// exposed on Session in the session Handler.
// exposed on Session in the session Handler. A connection-scoped lock is also
// embedded in the context to make it easier to limit operations per-connection.
type Context interface {
context.Context
sync.Locker

// User returns the username used when establishing the SSH connection.
User() string
Expand Down Expand Up @@ -90,11 +93,12 @@ type Context interface {

type sshContext struct {
context.Context
*sync.Mutex
}

func newContext(srv *Server) (*sshContext, context.CancelFunc) {
innerCtx, cancel := context.WithCancel(context.Background())
ctx := &sshContext{innerCtx}
ctx := &sshContext{innerCtx, &sync.Mutex{}}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
Expand Down
1 change: 1 addition & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type Server struct {
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil
ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil
DefaultServerConfigCallback DefaultServerConfigCallback // callback for configuring detailed SSH options
SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions

IdleTimeout time.Duration // connection timeout when no activity, none if empty
MaxTimeout time.Duration // absolute connection timeout, none if empty
Expand Down
50 changes: 31 additions & 19 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,30 +84,32 @@ func sessionHandler(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChanne
return
}
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
sessReqCb: srv.SessionRequestCallback,
ctx: ctx,
}
sess.handleRequests(reqs)
}

type session struct {
sync.Mutex
gossh.Channel
conn *gossh.ServerConn
handler Handler
handled bool
exited bool
pty *Pty
winch chan Window
env []string
ptyCb PtyCallback
cmd []string
ctx Context
sigCh chan<- Signal
sigBuf []Signal
conn *gossh.ServerConn
handler Handler
handled bool
exited bool
pty *Pty
winch chan Window
env []string
ptyCb PtyCallback
sessReqCb SessionRequestCallback
cmd []string
ctx Context
sigCh chan<- Signal
sigBuf []Signal
}

func (sess *session) Write(p []byte) (n int, err error) {
Expand Down Expand Up @@ -209,12 +211,22 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
req.Reply(false, nil)
continue
}
sess.handled = true
req.Reply(true, nil)

var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
sess.cmd, _ = shlex.Split(payload.Value, true)

// If there's a session policy callback, we need to confirm before
// accepting the session.
if sess.sessReqCb != nil && !sess.sessReqCb(sess, req.Type) {
sess.cmd = nil
req.Reply(false, nil)
continue
}

sess.handled = true
req.Reply(true, nil)

go func() {
sess.handler(sess)
sess.Exit(0)
Expand Down
6 changes: 5 additions & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package ssh

import (
"crypto/subtle"
gossh "golang.org/x/crypto/ssh"
"net"

gossh "golang.org/x/crypto/ssh"
)

type Signal string
Expand Down Expand Up @@ -46,6 +47,9 @@ type KeyboardInteractiveHandler func(ctx Context, challenger gossh.KeyboardInter
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(ctx Context, pty Pty) bool

// SessionRequestCallback is a callback for allowing or denying SSH sessions.
type SessionRequestCallback func(sess Session, requestType string) bool

// ConnCallback is a hook for new connections before handling.
// It allows wrapping for timeouts and limiting by returning
// the net.Conn that will be used as the underlying connection.
Expand Down

0 comments on commit e5ece14

Please sign in to comment.