Skip to content

Commit

Permalink
Merge pull request #5 from anexia/pr/add-auth-support
Browse files Browse the repository at this point in the history
Add redis authentication support
  • Loading branch information
beachmachine authored Nov 7, 2022
2 parents 678d86b + f9ecc60 commit 6bbbeb5
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 14 deletions.
49 changes: 48 additions & 1 deletion connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,22 @@ type Connection struct {

protocol string
endpoint string
authUser string
authPassword string
connectTimeout time.Duration
readTimeout time.Duration
writeTimeout time.Duration
}

// Initializes a new connection, of the given protocol and endpoint, with the given connection timeout
// ex: "unix", "/tmp/myAwesomeSocket", 50*time.Millisecond
func NewConnection(Protocol, Endpoint string, ConnectTimeout, ReadTimeout, WriteTimeout time.Duration) *Connection {
func NewConnection(Protocol, Endpoint string, ConnectTimeout, ReadTimeout, WriteTimeout time.Duration,
authUser string, authPassword string) *Connection {
c := &Connection{}
c.protocol = Protocol
c.endpoint = Endpoint
c.authUser = authUser
c.authPassword = authPassword
c.connectTimeout = ConnectTimeout
c.readTimeout = ReadTimeout
c.writeTimeout = WriteTimeout
Expand Down Expand Up @@ -100,6 +105,10 @@ func (c *Connection) ReconnectIfNecessary() (err error) {
c.Writer = writer.NewFlexibleWriter(netReadWriter)
c.Reader = bufio.NewReader(netReadWriter)

if err = c.authenticate(); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -132,6 +141,44 @@ func (this *Connection) SelectDatabase(DatabaseId int) (err error) {
return
}

// Tries to authenticate the connection
// If an error is returned, or if an invalid response is returned from the AUTH command, then this will return an error
func (this *Connection) authenticate() (err error) {
if this.connection == nil {
log.Error("authenticate: Using an invalid connection")
return errors.New("authenticating against an invalid connection")
}

if this.authPassword == "" {
return
}

authCommand := fmt.Sprintf("AUTH %s", this.authPassword)

if this.authUser != "" {
authCommand = fmt.Sprintf("AUTH %s %s", this.authUser, this.authPassword)
}

err = protocol.WriteLine([]byte(authCommand), this.Writer, true)
if err != nil {
log.Error("authenticate: Error received from protocol.FlushLine: %s", err)
return
}

line, isPrefix, err := this.Reader.ReadLine()
if err != nil || isPrefix || !bytes.Equal(line, protocol.OK_RESPONSE) {
if err == nil {
err = errors.New("unknown ReadLine error")
}

log.Error("authenticate: Error while attempting to authenticate. Err:%q Response:%q isPrefix:%t",
err, line, isPrefix)
this.Disconnect()
return errors.New("invalid authentication response")
}
return
}

// Checks if the current connection is up or not
// If we do not get a response, or if we do not get a PONG reply, or if there is any error, returns false
func (myConnection *Connection) CheckConnection() bool {
Expand Down
11 changes: 10 additions & 1 deletion connection/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ type ConnectionPool struct {
Protocol string
//The endpoint to connect to
Endpoint string
//User to use for authentication against the upstream redis server(s).
AuthUser string
//Password to use for authentication against the upstream redis server(s).
AuthPassword string
//And overridable connect timeout. Defaults to EXTERN_CONNECT_TIMEOUT
ConnectTimeout time.Duration
//An overridable read timeout. Defaults to EXTERN_READ_TIMEOUT
Expand All @@ -70,10 +74,13 @@ type ConnectionPool struct {
// Initialize a new connection pool, for the given protocol/endpoint, with a given pool capacity
// ex: "unix", "/tmp/myAwesomeSocket", 5
func NewConnectionPool(Protocol, Endpoint string, poolCapacity int, connectTimeout time.Duration,
readTimeout time.Duration, writeTimeout time.Duration) (newConnectionPool *ConnectionPool) {
readTimeout time.Duration, writeTimeout time.Duration, authUser string,
authPassword string) (newConnectionPool *ConnectionPool) {
newConnectionPool = &ConnectionPool{}
newConnectionPool.Protocol = Protocol
newConnectionPool.Endpoint = Endpoint
newConnectionPool.AuthUser = authUser
newConnectionPool.AuthPassword = authPassword
newConnectionPool.connectionPool = make(chan *Connection, poolCapacity)
newConnectionPool.ConnectTimeout = connectTimeout
newConnectionPool.ReadTimeout = readTimeout
Expand Down Expand Up @@ -117,6 +124,8 @@ func (cp *ConnectionPool) CreateConnection() *Connection {
cp.ConnectTimeout,
cp.ReadTimeout,
cp.WriteTimeout,
cp.AuthUser,
cp.AuthPassword,
)
}

Expand Down
4 changes: 2 additions & 2 deletions connection/connection_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestRecycleConnection(test *testing.T) {

//Setting the channel at size 2 makes this more interesting
timeout := 500 * time.Millisecond
connectionPool := NewConnectionPool("unix", testSocket, 2, timeout, timeout, timeout)
connectionPool := NewConnectionPool("unix", testSocket, 2, timeout, timeout, timeout, "", "")

connection, err := connectionPool.GetConnection()
if err != nil {
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestCheckConnectionState(test *testing.T) {

// Create the pool, have a size of zero so that no connections are made except for diagnostics
timeout := 10 * time.Millisecond
connectionPool := NewConnectionPool("unix", testSocket, 0, timeout, timeout, timeout)
connectionPool := NewConnectionPool("unix", testSocket, 0, timeout, timeout, timeout, "", "")

// get and release which will actually create the connection
connectionPool.getDiagnosticConnection()
Expand Down
16 changes: 8 additions & 8 deletions connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func verifySelectDatabaseSuccess(test *testing.T, database int) {
test.Fatal("Failed to listen on test socket ", testSocket)
}
defer listenSock.Close()
testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
testConnection.ReconnectIfNecessary()

//read buffer does't matter
Expand Down Expand Up @@ -77,7 +77,7 @@ func verifySelectDatabaseError(test *testing.T, database int) {
defer func() {
listenSock.Close()
}()
testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
testConnection.ReconnectIfNecessary()
//read buffer does't matter
readBuf := bufio.NewReader(bytes.NewBufferString("+NOPE\r\n"))
Expand Down Expand Up @@ -106,7 +106,7 @@ func verifySelectDatabaseTimeout(test *testing.T, database int) {
}
defer listenSock.Close()

testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
testConnection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
if err := testConnection.ReconnectIfNecessary(); err != nil {
test.Fatalf("Could not connect to testSocket %s: %s", testSocket, err)
}
Expand Down Expand Up @@ -149,13 +149,13 @@ func TestNewUnixConnection(test *testing.T) {
}
defer listenSock.Close()

connection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
connection := NewConnection("unix", testSocket, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
connection.ReconnectIfNecessary()
if connection == nil || connection.connection == nil {
test.Fatal("Connection initialization returned nil, binding to unix endpoint failed")
}

connection = NewConnection("unix", "/tmp/thisdoesnotexist", 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
connection = NewConnection("unix", "/tmp/thisdoesnotexist", 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
connection.ReconnectIfNecessary()
if connection != nil && connection.connection != nil {
test.Fatal("Connection initialization success, binding to fake unix endpoint succeeded????")
Expand All @@ -170,14 +170,14 @@ func TestNewTcpConnection(test *testing.T) {
}
defer listenSock.Close()

connection := NewConnection("tcp", testEndpoint, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
connection := NewConnection("tcp", testEndpoint, 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
connection.ReconnectIfNecessary()
if connection == nil || connection.connection == nil {
test.Fatal("Connection initialization returned nil, binding to tcp endpoint failed")
}

//reserved sock should have nothing on it
connection = NewConnection("tcp", "localhost:49151", 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond)
connection = NewConnection("tcp", "localhost:49151", 10*time.Millisecond, 10*time.Millisecond, 10*time.Millisecond, "", "")
connection.ReconnectIfNecessary()
if connection != nil && connection.connection != nil {
test.Fatal("Connection initialization success, binding to fake tcp endpoint succeeded????")
Expand All @@ -194,7 +194,7 @@ func TestCheckConnection(test *testing.T) {
listenSock.Close()
}()

connection := NewConnection("unix", testSocket, 100*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond)
connection := NewConnection("unix", testSocket, 100*time.Millisecond, 100*time.Millisecond, 100*time.Millisecond, "", "")
connection.ReconnectIfNecessary()
if connection == nil {
test.Fatal("Connection initialization returned nil, binding to unix endpoint failed")
Expand Down
2 changes: 2 additions & 0 deletions doc/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ for the configuration json is as follows:
"poolSize": int,
"tcpConnections": [string, string, ...],
"unixConnections": [string, string, ...],
"authUser": string,
"authPassword": string,
"localTimeout": int,
"localReadTimeout": int,
Expand Down
2 changes: 2 additions & 0 deletions main/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type PoolConfig struct {
PoolSize int `json:"poolSize"`
TcpConnections []string `json:"tcpConnections"`
UnixConnections []string `json:"unixConnections"`
AuthUser string `json:"authUser"`
AuthPassword string `json:"authPassword"`
LocalTimeout int64 `json:"localTimeout"`
LocalReadTimeout int64 `json:"localReadTimeout"`
LocalWriteTimeout int64 `json:"localWriteTimeout"`
Expand Down
3 changes: 3 additions & 0 deletions main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ func createInstances(configs []PoolConfig) (rmuxInstances []*rmux.RedisMultiplex
log.Info("Setting remote redis write timeout to: %s", duration)
}

rmuxInstance.AuthUser = config.AuthUser
rmuxInstance.AuthPassword = config.AuthPassword

if len(config.TcpConnections) > 0 {
for _, tcpConnection := range config.TcpConnections {
log.Info("Adding tcp (destination) connection: %s", tcpConnection)
Expand Down
9 changes: 7 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ type RedisMultiplexer struct {
PoolSize int
//The primary connection key to use. If we're not operating on a key-based operation, it will go here
PrimaryConnectionPool *connection.ConnectionPool
//And overridable connect timeout. Defaults to EXTERN_CONNECT_TIMEOUT
//User to use for authentication against the upstream redis server(s).
AuthUser string
//Password to use for authentication against the upstream redis server(s).
AuthPassword string
//An overridable connect timeout. Defaults to EXTERN_CONNECT_TIMEOUT
EndpointConnectTimeout time.Duration
//An overridable read timeout. Defaults to EXTERN_READ_TIMEOUT
EndpointReadTimeout time.Duration
Expand Down Expand Up @@ -138,7 +142,8 @@ func NewRedisMultiplexer(listenProtocol, listenEndpoint string, poolSize int) (n
// Adds a connection to the redis multiplexer, for the given protocol and endpoint
func (this *RedisMultiplexer) AddConnection(remoteProtocol, remoteEndpoint string) {
connectionCluster := connection.NewConnectionPool(remoteProtocol, remoteEndpoint, this.PoolSize,
this.EndpointConnectTimeout, this.EndpointReadTimeout, this.EndpointWriteTimeout)
this.EndpointConnectTimeout, this.EndpointReadTimeout, this.EndpointWriteTimeout, this.AuthUser,
this.AuthPassword)
this.ConnectionCluster = append(this.ConnectionCluster, connectionCluster)
if len(this.ConnectionCluster) == 1 {
this.PrimaryConnectionPool = connectionCluster
Expand Down

0 comments on commit 6bbbeb5

Please sign in to comment.