diff --git a/src/net/fd_posix.go b/src/net/fd_posix.go index 7f3aeff580c95d..5c88b50cdae49c 100644 --- a/src/net/fd_posix.go +++ b/src/net/fd_posix.go @@ -29,6 +29,7 @@ type netFD struct { // number of bytes transferred. readHook func(int) writeHook func(int) + closeHook func() } func (fd *netFD) setAddr(laddr, raddr Addr) { @@ -39,6 +40,9 @@ func (fd *netFD) setAddr(laddr, raddr Addr) { func (fd *netFD) Close() error { runtime.SetFinalizer(fd, nil) + if fd.closeHook != nil { + fd.closeHook() + } return fd.pfd.Close() } @@ -49,10 +53,16 @@ func (fd *netFD) shutdown(how int) error { } func (fd *netFD) closeRead() error { + if fd.closeHook != nil { + fd.closeHook() + } return fd.shutdown(syscall.SHUT_RD) } func (fd *netFD) closeWrite() error { + if fd.closeHook != nil { + fd.closeHook() + } return fd.shutdown(syscall.SHUT_WR) } @@ -94,7 +104,7 @@ func (fd *netFD) readFromInet6(p []byte, from *syscall.SockaddrInet6) (n int, er func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int, sa syscall.Sockaddr, err error) { n, oobn, retflags, sa, err = fd.pfd.ReadMsg(p, oob, flags) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, sa, wrapSyscallError(readMsgSyscallName, err) @@ -103,7 +113,7 @@ func (fd *netFD) readMsg(p []byte, oob []byte, flags int) (n, oobn, retflags int func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet4) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet4(p, oob, flags, sa) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) @@ -112,7 +122,7 @@ func (fd *netFD) readMsgInet4(p []byte, oob []byte, flags int, sa *syscall.Socka func (fd *netFD) readMsgInet6(p []byte, oob []byte, flags int, sa *syscall.SockaddrInet6) (n, oobn, retflags int, err error) { n, oobn, retflags, err = fd.pfd.ReadMsgInet6(p, oob, flags, sa) if fd.readHook != nil && err == nil { - fd.readHook(n + oobn) + fd.readHook(n) } runtime.KeepAlive(fd) return n, oobn, retflags, wrapSyscallError(readMsgSyscallName, err) @@ -157,7 +167,7 @@ func (fd *netFD) writeToInet6(p []byte, sa *syscall.SockaddrInet6) (n int, err e func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsg(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) @@ -166,7 +176,7 @@ func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oob func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet4(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) @@ -175,7 +185,7 @@ func (fd *netFD) writeMsgInet4(p []byte, oob []byte, sa *syscall.SockaddrInet4) func (fd *netFD) writeMsgInet6(p []byte, oob []byte, sa *syscall.SockaddrInet6) (n int, oobn int, err error) { n, oobn, err = fd.pfd.WriteMsgInet6(p, oob, sa) if fd.writeHook != nil && err == nil { - fd.writeHook(n + oobn) + fd.writeHook(n) } runtime.KeepAlive(fd) return n, oobn, wrapSyscallError(writeMsgSyscallName, err) diff --git a/src/net/sock_posix.go b/src/net/sock_posix.go index 5c5340c1b756ff..759abe872b15cc 100644 --- a/src/net/sock_posix.go +++ b/src/net/sock_posix.go @@ -31,6 +31,21 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only if trace := ContextSockTrace(ctx); trace != nil { fd.readHook = trace.DidRead fd.writeHook = trace.DidWrite + if (trace.DidCreateTCPConn != nil || trace.WillCloseTCPConn != nil) && len(net) >= 3 && net[0:3] == "tcp" { + // Ignore newRawConn errors (they're not possible in the current + // implementation, but even if they were, we don't want to + // affect socket operations for a trace hook invocation). + if c, err := newRawConn(fd); err == nil { + if trace.DidCreateTCPConn != nil { + trace.DidCreateTCPConn(c) + } + if trace.WillCloseTCPConn != nil { + fd.closeHook = func() { + trace.WillCloseTCPConn(c) + } + } + } + } } // This function makes a network file descriptor for the diff --git a/src/net/socktrace.go b/src/net/socktrace.go index 57af2be8b75e3e..b02a8d12484d4e 100644 --- a/src/net/socktrace.go +++ b/src/net/socktrace.go @@ -6,12 +6,16 @@ package net import ( "context" + "syscall" ) // SockTrace is a set of hooks to run at various operations on a network socket. // Any particular hook may be nil. Functions may be called concurrently from // different goroutines. type SockTrace struct { + // DidOpenTCPConn is called when a TCP socket was created. The + // underlying raw network connection that was created is provided. + DidCreateTCPConn func(c syscall.RawConn) // DidRead is called after a successful read from the socket, where n bytes // were read. DidRead func(n int) @@ -22,6 +26,9 @@ type SockTrace struct { // subsequent call to WithSockTrace. The provided trace is the new trace // that will be used. WillOverwrite func(trace *SockTrace) + // WillCloseTCPConn is called when a TCP socket is about to be closed. The + // underlying raw network connection that is being closed is provided. + WillCloseTCPConn func(c syscall.RawConn) } // WithSockTrace returns a new context based on the provided parent