diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index eac440496..7cded3d85 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,8 +17,8 @@ jobs: - name: Checkout Redis uses: actions/checkout@v2 with: - repository: 'redis/redis' - ref: 'unstable' + repository: 'sjpotter/redis' + ref: 'ed754fe01934cebf2ed3343276ef564c92c4c74f' path: 'redis' - name: Build Redis run: cd redis && make -j 4 gcov @@ -54,8 +54,8 @@ jobs: - name: Checkout Redis uses: actions/checkout@v2 with: - repository: 'redis/redis' - ref: 'unstable' + repository: 'sjpotter/redis' + ref: 'ed754fe01934cebf2ed3343276ef564c92c4c74f' path: 'redis' - name: Build Redis run: cd redis && make -j 4 SANITIZER=address diff --git a/deps/common/redismodule.h b/deps/common/redismodule.h index 2f3bb8586..b0bb8fb37 100644 --- a/deps/common/redismodule.h +++ b/deps/common/redismodule.h @@ -41,7 +41,7 @@ typedef long long ustime_t; /* API versions. */ #define REDISMODULE_APIVER_1 1 -/* Version of the RedisModuleTypeMethods structure. Once the RedisModuleTypeMethods +/* Version of the RedisModuleTypeMethods structure. Once the RedisModuleTypeMethods * structure is changed, this version number needs to be changed synchronistically. */ #define REDISMODULE_TYPE_METHOD_VERSION 5 @@ -309,6 +309,14 @@ typedef uint64_t RedisModuleTimerID; * Use RedisModule_GetModuleOptionsAll instead. */ #define _REDISMODULE_OPTIONS_FLAGS_NEXT (1<<4) +/* RM_GetClientFlags */ +#define REDISMODULE_CLIENT_FLAG_DIRTY_CAS (1<<0) /* Watched keys modified. EXEC will fail. */ +#define REDISMODULE_CLIENT_FLAG_DIRTY_EXEC (1<<1) /* EXEC will fail for errors while queueing */ +/* Next client flag, must be updated when adding new flags above! +This flag should not be used directly by the module. + * Use RedisModule_GetClientFlagsAll instead. */ +#define _REDISMODULE_CLIENT_FLAGS_NEXT (1<<2) + /* Definitions for RedisModule_SetCommandInfo. */ typedef enum { @@ -587,7 +595,7 @@ static const RedisModuleEvent /* Deprecated since Redis 7.0, not used anymore. */ __attribute__ ((deprecated)) RedisModuleEvent_ReplBackup = { - REDISMODULE_EVENT_REPL_BACKUP, + REDISMODULE_EVENT_REPL_BACKUP, 1 }, RedisModuleEvent_ReplAsyncLoad = { @@ -880,6 +888,7 @@ typedef struct RedisModuleCommandFilter RedisModuleCommandFilter; typedef struct RedisModuleServerInfoData RedisModuleServerInfoData; typedef struct RedisModuleScanCursor RedisModuleScanCursor; typedef struct RedisModuleUser RedisModuleUser; +typedef struct RedisModuleClient RedisModuleClient; typedef struct RedisModuleKeyOptCtx RedisModuleKeyOptCtx; typedef struct RedisModuleRdbStream RedisModuleRdbStream; @@ -976,7 +985,7 @@ REDISMODULE_API int (*RedisModule_GetSelectedDb)(RedisModuleCtx *ctx) REDISMODUL REDISMODULE_API int (*RedisModule_SelectDb)(RedisModuleCtx *ctx, int newid) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_KeyExists)(RedisModuleCtx *ctx, RedisModuleString *keyname) REDISMODULE_ATTR; REDISMODULE_API RedisModuleKey * (*RedisModule_OpenKey)(RedisModuleCtx *ctx, RedisModuleString *keyname, int mode) REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetOpenKeyModesAll)() REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetOpenKeyModesAll)(void) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_CloseKey)(RedisModuleKey *kp) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_KeyType)(RedisModuleKey *kp) REDISMODULE_ATTR; REDISMODULE_API size_t (*RedisModule_ValueLength)(RedisModuleKey *kp) REDISMODULE_ATTR; @@ -1098,7 +1107,7 @@ REDISMODULE_API int (*RedisModule_SetClientNameById)(uint64_t id, RedisModuleStr REDISMODULE_API int (*RedisModule_PublishMessage)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_PublishMessageShard)(RedisModuleCtx *ctx, RedisModuleString *channel, RedisModuleString *message) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetContextFlags)(RedisModuleCtx *ctx) REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_AvoidReplicaTraffic)() REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AvoidReplicaTraffic)(void) REDISMODULE_ATTR; REDISMODULE_API void * (*RedisModule_PoolAlloc)(RedisModuleCtx *ctx, size_t bytes) REDISMODULE_ATTR; REDISMODULE_API RedisModuleType * (*RedisModule_CreateDataType)(RedisModuleCtx *ctx, const char *name, int encver, RedisModuleTypeMethods *typemethods) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_ModuleTypeSetValue)(RedisModuleKey *key, RedisModuleType *mt, void *value) REDISMODULE_ATTR; @@ -1201,17 +1210,17 @@ REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeys)(Redi REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClientOnKeysWithFlags)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms, RedisModuleString **keys, int numkeys, void *privdata, int flags) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_SignalKeyAsReady)(RedisModuleCtx *ctx, RedisModuleString *key) REDISMODULE_ATTR; REDISMODULE_API RedisModuleString * (*RedisModule_GetBlockedClientReadyKey)(RedisModuleCtx *ctx) REDISMODULE_ATTR; -REDISMODULE_API RedisModuleScanCursor * (*RedisModule_ScanCursorCreate)() REDISMODULE_ATTR; +REDISMODULE_API RedisModuleScanCursor * (*RedisModule_ScanCursorCreate)(void) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_ScanCursorRestart)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_ScanCursorDestroy)(RedisModuleScanCursor *cursor) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_Scan)(RedisModuleCtx *ctx, RedisModuleScanCursor *cursor, RedisModuleScanCB fn, void *privdata) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_ScanKey)(RedisModuleKey *key, RedisModuleScanCursor *cursor, RedisModuleScanKeyCB fn, void *privdata) REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetContextFlagsAll)() REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetModuleOptionsAll)() REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetKeyspaceNotificationFlagsAll)() REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetContextFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetModuleOptionsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetKeyspaceNotificationFlagsAll)(void) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_IsSubEventSupported)(RedisModuleEvent event, uint64_t subevent) REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetServerVersion)() REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetTypeMethodVersion)() REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetServerVersion)(void) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetTypeMethodVersion)(void) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_Yield)(RedisModuleCtx *ctx, int flags, const char *busy_reply) REDISMODULE_ATTR; REDISMODULE_API RedisModuleBlockedClient * (*RedisModule_BlockClient)(RedisModuleCtx *ctx, RedisModuleCmdFunc reply_callback, RedisModuleCmdFunc timeout_callback, void (*free_privdata)(RedisModuleCtx*,void*), long long timeout_ms) REDISMODULE_ATTR; REDISMODULE_API void * (*RedisModule_BlockClientGetPrivateData)(RedisModuleBlockedClient *blocked_client) REDISMODULE_ATTR; @@ -1234,7 +1243,7 @@ REDISMODULE_API void (*RedisModule_ThreadSafeContextUnlock)(RedisModuleCtx *ctx) REDISMODULE_API int (*RedisModule_SubscribeToKeyspaceEvents)(RedisModuleCtx *ctx, int types, RedisModuleNotificationFunc cb) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_AddPostNotificationJob)(RedisModuleCtx *ctx, RedisModulePostNotificationJobFunc callback, void *pd, void (*free_pd)(void*)) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_NotifyKeyspaceEvent)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key) REDISMODULE_ATTR; -REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)() REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)(void) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_BlockedClientDisconnected)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_RegisterClusterMessageReceiver)(RedisModuleCtx *ctx, uint8_t type, RedisModuleClusterMessageReceiver callback) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_SendClusterMessage)(RedisModuleCtx *ctx, const char *target_id, uint8_t type, const char *msg, uint32_t len) REDISMODULE_ATTR; @@ -1259,11 +1268,12 @@ REDISMODULE_API RedisModuleString * (*RedisModule_CommandFilterArgGet)(RedisModu REDISMODULE_API int (*RedisModule_CommandFilterArgInsert)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_CommandFilterArgReplace)(RedisModuleCommandFilterCtx *fctx, int pos, RedisModuleString *arg) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_CommandFilterArgDelete)(RedisModuleCommandFilterCtx *fctx, int pos) REDISMODULE_ATTR; +REDISMODULE_API unsigned long long (*RedisModule_CommandFilterGetClientId)(RedisModuleCommandFilterCtx *fctx) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_Fork)(RedisModuleForkDoneHandler cb, void *user_data) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_SendChildHeartbeat)(double progress) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_ExitFromChild)(int retcode) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_KillForkChild)(int child_pid) REDISMODULE_ATTR; -REDISMODULE_API float (*RedisModule_GetUsedMemoryRatio)() REDISMODULE_ATTR; +REDISMODULE_API float (*RedisModule_GetUsedMemoryRatio)(void) REDISMODULE_ATTR; REDISMODULE_API size_t (*RedisModule_MallocSize)(void* ptr) REDISMODULE_ATTR; REDISMODULE_API size_t (*RedisModule_MallocUsableSize)(void *ptr) REDISMODULE_ATTR; REDISMODULE_API size_t (*RedisModule_MallocSizeString)(RedisModuleString* str) REDISMODULE_ATTR; @@ -1309,6 +1319,12 @@ REDISMODULE_API RedisModuleRdbStream *(*RedisModule_RdbStreamCreateFromFile)(con REDISMODULE_API void (*RedisModule_RdbStreamFree)(RedisModuleRdbStream *stream) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_RdbLoad)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_RdbSave)(RedisModuleCtx *ctx, RedisModuleRdbStream *stream, int flags) REDISMODULE_ATTR; +REDISMODULE_API RedisModuleClient * (*RedisModule_CreateModuleClient)(RedisModuleCtx *ctx) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_FreeModuleClient)(RedisModuleCtx *ctx, RedisModuleClient *client) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetContextClient)(RedisModuleCtx *ctx, RedisModuleClient *client) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_GetClientFlags)(RedisModuleClient *client) REDISMODULE_ATTR; +REDISMODULE_API uint64_t (*RedisModule_GetClientFlagsAll)(void) REDISMODULE_ATTR; +REDISMODULE_API void (*RedisModule_SetClientUser)(RedisModuleClient *client, RedisModuleUser *user) REDISMODULE_ATTR; #define RedisModule_IsAOFClient(id) ((id) == UINT64_MAX) @@ -1619,6 +1635,7 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(CommandFilterArgInsert); REDISMODULE_GET_API(CommandFilterArgReplace); REDISMODULE_GET_API(CommandFilterArgDelete); + REDISMODULE_GET_API(CommandFilterGetClientId); REDISMODULE_GET_API(Fork); REDISMODULE_GET_API(SendChildHeartbeat); REDISMODULE_GET_API(ExitFromChild); @@ -1669,6 +1686,13 @@ static int RedisModule_Init(RedisModuleCtx *ctx, const char *name, int ver, int REDISMODULE_GET_API(RdbStreamFree); REDISMODULE_GET_API(RdbLoad); REDISMODULE_GET_API(RdbSave); + REDISMODULE_GET_API(CreateModuleClient); + REDISMODULE_GET_API(FreeModuleClient); + REDISMODULE_GET_API(SetContextClient); + REDISMODULE_GET_API(GetClientFlags); + REDISMODULE_GET_API(GetClientFlagsAll); + REDISMODULE_GET_API(SetClientUser); + if (RedisModule_IsModuleNameBusy && RedisModule_IsModuleNameBusy(name)) return REDISMODULE_ERR; RedisModule_SetModuleAttribs(ctx,name,ver,apiver); diff --git a/src/clientstate.c b/src/clientstate.c index 6cc3257eb..e35855bc6 100644 --- a/src/clientstate.c +++ b/src/clientstate.c @@ -20,6 +20,7 @@ ClientState *ClientStateGet(RedisRaftCtx *rr, RedisModuleCtx *ctx) void ClientStateAlloc(RedisRaftCtx *rr, unsigned long long client_id) { ClientState *clientState = RedisModule_Calloc(sizeof(ClientState), 1); + clientState->client_id = client_id; int ret = RedisModule_DictSetC(rr->client_state, &client_id, sizeof(client_id), clientState); RedisModule_Assert(ret == REDISMODULE_OK); } diff --git a/src/commands.c b/src/commands.c index aa6f03bc5..b71965f76 100644 --- a/src/commands.c +++ b/src/commands.c @@ -72,13 +72,13 @@ static const CommandSpec commands[] = { /* Admin commands - bypassed */ {"auth", CMD_SPEC_DONT_INTERCEPT }, - {"ping", CMD_SPEC_DONT_INTERCEPT }, + {"ping", CMD_SPEC_INTERCEPT_IN_MULTI }, {"hello", CMD_SPEC_DONT_INTERCEPT }, {"module", CMD_SPEC_DONT_INTERCEPT }, - {"config", CMD_SPEC_DONT_INTERCEPT }, + {"config", CMD_SPEC_INTERCEPT_IN_MULTI }, {"monitor", CMD_SPEC_DONT_INTERCEPT }, {"command", CMD_SPEC_DONT_INTERCEPT }, - {"shutdown", CMD_SPEC_DONT_INTERCEPT }, + {"shutdown", CMD_SPEC_INTERCEPT_IN_MULTI }, {"quit", CMD_SPEC_DONT_INTERCEPT }, {"slowlog", CMD_SPEC_DONT_INTERCEPT }, {"acl", CMD_SPEC_DONT_INTERCEPT }, diff --git a/src/multi.c b/src/multi.c index 551cc61d7..8f04303ed 100644 --- a/src/multi.c +++ b/src/multi.c @@ -75,7 +75,12 @@ bool MultiHandleCommand(RedisRaftCtx *rr, if (multiState->error) { MultiStateReset(multiState); - RedisModule_ReplyWithError(ctx, "EXECABORT Transaction discarded because of previous errors."); + if (clientState->watched) { + RaftReq *req = RaftReqInit(ctx, RR_END_SESSION); + appendEndClientSession(rr, req, clientState->client_id, SESSION_END_EXECABORT); + } else { + RedisModule_ReplyWithError(ctx, EXECABORT_ERR); + } return true; } @@ -96,7 +101,12 @@ bool MultiHandleCommand(RedisRaftCtx *rr, } MultiStateReset(multiState); - RedisModule_ReplyWithSimpleString(ctx, "OK"); + if (clientState->watched) { + RaftReq *req = RaftReqInit(ctx, RR_END_SESSION); + appendEndClientSession(rr, req, clientState->client_id, SESSION_END_DISCARD); + } else { + RedisModule_ReplyWithSimpleString(ctx, "OK"); + } return true; } @@ -120,10 +130,12 @@ bool MultiHandleCommand(RedisRaftCtx *rr, return true; } - if (cmd_flags & CMD_SPEC_DONT_INTERCEPT) { - RedisModule_ReplyWithError(ctx, "ERR not supported by RedisRaft inside MULTI/EXEC"); - multiState->error = true; - return true; + if (cmd_flags & CMD_SPEC_INTERCEPT_IN_MULTI) { + if (cmd_len == 8 && !strncasecmp(cmd_str, "SHUTDOWN", 8)) { + RedisModule_ReplyWithError(ctx, "ERR Command not allowed inside a transaction"); + multiState->error = true; + return true; + } } if (RedisModule_GetUsedMemoryRatio() > 1.0) { diff --git a/src/raft.c b/src/raft.c index 6b057f226..b13d61316 100644 --- a/src/raft.c +++ b/src/raft.c @@ -314,6 +314,8 @@ static void *getClientSession(RedisRaftCtx *rr, RaftRedisCommandArray *cmds, boo client_session = RedisModule_Alloc(sizeof(ClientSession)); client_session->client_id = id; client_session->local = local; + client_session->dirty = false; + client_session->client = RedisModule_CreateModuleClient(rr->ctx); RedisModule_DictSetC(rr->client_session_dict, &id, sizeof(id), client_session); } } @@ -321,6 +323,21 @@ static void *getClientSession(RedisRaftCtx *rr, RaftRedisCommandArray *cmds, boo return client_session; } +static void freeClientSession(RedisRaftCtx *rr, ClientSession *client_session) +{ + RedisModule_FreeModuleClient(rr->ctx, client_session->client); + RedisModule_Free(client_session); +} + +static void endClientSession(RedisRaftCtx *rr, unsigned long long id) +{ + ClientSession *client_session = NULL; + RedisModule_DictDelC(rr->client_session_dict, &id, sizeof(id), &client_session); + if (client_session) { + freeClientSession(rr, client_session); + } +} + RedisModuleUser *RaftGetACLUser(RedisModuleCtx *ctx, RedisRaftCtx *rr, RaftRedisCommandArray *cmds) { int nokey; @@ -359,6 +376,16 @@ void handleUnblock(RedisModuleCtx *ctx, RedisModuleCallReply *reply, void *priva freeBlockedCommand(bc); } +static bool isClientSessionDirty(ClientSession *client_session) +{ + uint64_t flags = RedisModule_GetClientFlags(client_session->client); + if (client_session->dirty || flags & REDISMODULE_CLIENT_FLAG_DIRTY_CAS) { + return true; + } + + return false; +} + /* Execute all commands in a specified RaftRedisCommandArray. * * If reply_ctx is non-NULL, replies are delivered to it. @@ -376,6 +403,7 @@ RedisModuleCallReply *RaftExecuteCommandArray(RedisRaftCtx *rr, RedisModuleCallReply *reply = NULL; RedisModuleUser *user = NULL; RedisModuleCtx *ctx = req ? req->ctx : rr->ctx; + bool is_multi_session = false; if (cmds->acl) { user = RaftGetACLUser(rr->ctx, rr, cmds); @@ -392,7 +420,6 @@ RedisModuleCallReply *RaftExecuteCommandArray(RedisRaftCtx *rr, } ClientSession *client_session = getClientSession(rr, cmds, req != NULL); - (void) client_session; /* unused for now */ if (cmds->cmd_flags & CMD_SPEC_BLOCKING) { replaceBlockingTimeout(cmds); @@ -412,6 +439,16 @@ RedisModuleCallReply *RaftExecuteCommandArray(RedisRaftCtx *rr, * (although no harm is done). */ if (i == 0 && cmdlen == 5 && !strncasecmp(cmd, "MULTI", 5)) { + if (client_session) { + if (isClientSessionDirty(client_session)) { + if (req) { + RedisModule_ReplyWithNull(req->ctx); + } + endClientSession(rr, client_session->client_id); + return NULL; + } + is_multi_session = true; + } if (req) { RedisModule_ReplyWithArray(req->ctx, cmds->len - 1); } @@ -442,9 +479,19 @@ RedisModuleCallReply *RaftExecuteCommandArray(RedisRaftCtx *rr, } enterRedisModuleCall(); - RedisModule_SetContextUser(ctx, user); + if (client_session) { + RedisModule_SetClientUser(client_session->client, user); + RedisModule_SetContextClient(ctx, client_session->client); + } else { + RedisModule_SetContextUser(ctx, user); + } reply = RedisModule_Call(ctx, cmd, resp_call_fmt, &c->argv[1], c->argc - 1); - RedisModule_SetContextUser(ctx, NULL); + if (client_session) { + RedisModule_SetClientUser(client_session->client, NULL); + RedisModule_SetContextClient(ctx, NULL); + } else { + RedisModule_SetContextUser(ctx, NULL); + } exitRedisModuleCall(); rr->entered_eval = old_entered_eval; @@ -465,6 +512,10 @@ RedisModuleCallReply *RaftExecuteCommandArray(RedisRaftCtx *rr, } } + if (is_multi_session) { + endClientSession(rr, client_session->client_id); + } + /* if blocking (this won't be NULL), return it to the caller, to setup callback / saving state */ return reply; } @@ -591,20 +642,6 @@ static void unlockDeleteKeys(RedisRaftCtx *rr, raft_entry_t *entry, RaftReq *req } } -static void freeClientSession(void *client_session) -{ - RedisModule_Free(client_session); -} - -static void endClientSession(RedisRaftCtx *rr, unsigned long long id) -{ - void *client_session = NULL; - RedisModule_DictDelC(rr->client_session_dict, &id, sizeof(id), &client_session); - if (client_session) { - freeClientSession(client_session); - } -} - static void handleEndClientSession(RedisRaftCtx *rr, raft_entry_t *entry, RaftReq *req) { RedisModule_Assert(entry->type == RAFT_LOGTYPE_END_SESSION); @@ -613,7 +650,11 @@ static void handleEndClientSession(RedisRaftCtx *rr, raft_entry_t *entry, RaftRe endClientSession(rr, id); if (req) { - RedisModule_ReplyWithSimpleString(req->ctx, "OK"); + if (strncmp(entry->data, SESSION_END_EXECABORT, entry->data_len) == 0) { + RedisModule_ReplyWithError(req->ctx, EXECABORT_ERR); + } else { + RedisModule_ReplyWithSimpleString(req->ctx, "OK"); + } RaftReqFree(req); } } @@ -626,7 +667,7 @@ void clearClientSessions(RedisRaftCtx *rr) if (client_session->local) { RedisModule_DeauthenticateAndCloseClient(rr->ctx, client_session->client_id); } - freeClientSession(client_session); + freeClientSession(rr, client_session); } RedisModule_DictIteratorStop(iter); RedisModule_FreeDict(rr->ctx, rr->client_session_dict); diff --git a/src/redisraft.c b/src/redisraft.c index 35af658d0..e1dfcabda 100644 --- a/src/redisraft.c +++ b/src/redisraft.c @@ -587,7 +587,7 @@ static void handleClientCommand(RedisRaftCtx *rr, RedisModuleCtx *ctx, RaftRedis RedisModule_ReplyWithError(ctx, "ERR RedisRaft should only handle CLIENT UNBLOCK commands"); } -static void appendEndClientSession(RedisRaftCtx *rr, RaftReq *req, unsigned long long id, char *reason) +void appendEndClientSession(RedisRaftCtx *rr, RaftReq *req, unsigned long long id, char *reason) { raft_entry_t *entry = raft_entry_new(strlen(reason) + 1); entry->type = RAFT_LOGTYPE_END_SESSION; @@ -677,7 +677,7 @@ static bool handleUnwatch(RedisRaftCtx *rr, RedisModuleCtx *ctx, RaftRedisComman if (cmd_len == 7 && strncasecmp(cmd, "UNWATCH", 7) == 0) { RaftReq *req = RaftReqInit(ctx, RR_END_SESSION); unsigned long long id = RedisModule_GetClientId(ctx); - appendEndClientSession(rr, req, id, "UNWATCH"); + appendEndClientSession(rr, req, id, SESSION_END_UNWATCH); return true; } @@ -1725,6 +1725,15 @@ static void interceptRedisCommands(RedisModuleCommandFilterCtx *filter) if (flags != -1 && (flags & CMD_SPEC_DONT_INTERCEPT)) return; + if (flags != -1 && (flags & CMD_SPEC_INTERCEPT_IN_MULTI)) { + unsigned long long id = RedisModule_CommandFilterGetClientId(filter); + ClientState *clientState = ClientStateGetById(rr, id); + + if (clientState == NULL || !clientState->multi_state.active) { + return; + } + } + size_t len; const char *str = RedisModule_StringPtrLen(cmd, &len); diff --git a/src/redisraft.h b/src/redisraft.h index 69e0bf64e..2d7f34fd1 100644 --- a/src/redisraft.h +++ b/src/redisraft.h @@ -676,16 +676,17 @@ typedef struct { unsigned int flags; /* Command flags, see CMD_SPEC_* */ } CommandSpec; -#define CMD_SPEC_READONLY (1 << 1) /* Command is a read-only command */ -#define CMD_SPEC_WRITE (1 << 2) /* Command is a (potentially) write command */ -#define CMD_SPEC_UNSUPPORTED (1 << 3) /* Command is not supported, should be rejected */ -#define CMD_SPEC_DONT_INTERCEPT (1 << 4) /* Command should not be intercepted to RAFT */ -#define CMD_SPEC_SORT_REPLY (1 << 5) /* Command output should be sorted within a lua script */ -#define CMD_SPEC_RANDOM (1 << 6) /* Commands that are always random */ -#define CMD_SPEC_SCRIPTS (1 << 7) /* Commands that have script/function flags */ -#define CMD_SPEC_BLOCKING (1 << 8) /* Blocking command */ -#define CMD_SPEC_MULTI (1 << 9) /* a MULTI */ -#define CMD_SPEC_SUBCOMMAND (1 << 10) /* a command with subcommand specs */ +#define CMD_SPEC_READONLY (1 << 1) /* Command is a read-only command */ +#define CMD_SPEC_WRITE (1 << 2) /* Command is a (potentially) write command */ +#define CMD_SPEC_UNSUPPORTED (1 << 3) /* Command is not supported, should be rejected */ +#define CMD_SPEC_DONT_INTERCEPT (1 << 4) /* Command should not be intercepted to RAFT */ +#define CMD_SPEC_SORT_REPLY (1 << 5) /* Command output should be sorted within a lua script */ +#define CMD_SPEC_RANDOM (1 << 6) /* Commands that are always random */ +#define CMD_SPEC_SCRIPTS (1 << 7) /* Commands that have script/function flags */ +#define CMD_SPEC_BLOCKING (1 << 8) /* Blocking command */ +#define CMD_SPEC_MULTI (1 << 9) /* a MULTI */ +#define CMD_SPEC_SUBCOMMAND (1 << 10) /* a command with subcommand specs */ +#define CMD_SPEC_INTERCEPT_IN_MULTI (1 << 11) /* only intecept this command within a MULTI */ /* Command filtering re-entrancy counter handling. * @@ -738,6 +739,7 @@ typedef struct MultiState { } MultiState; typedef struct ClientState { + unsigned long long client_id; MultiState multi_state; bool asking; /* we record "watched" at append time, for 2 reasons @@ -774,9 +776,18 @@ typedef struct ClientState { typedef struct ClientSession { raft_session_t client_id; + RedisModuleClient *client; bool local; + bool dirty; } ClientSession; +#define SESSION_END_DISCONNECT "DISCONNECT" +#define SESSION_END_UNWATCH "UNWATCH" +#define SESSION_END_DISCARD "DISCARD" +#define SESSION_END_EXECABORT "EXECABORT" + +#define EXECABORT_ERR "EXECABORT Transaction discarded because of previous errors." + /* common.c */ void joinLinkIdleCallback(Connection *conn); void joinLinkFreeCallback(void *privdata); @@ -830,6 +841,7 @@ RRStatus RaftRedisDeserializeTimeout(const void *buf, size_t buf_size, raft_inde /* redisraft.c */ RRStatus RedisRaftCtxInit(RedisRaftCtx *rr, RedisModuleCtx *ctx); void RedisRaftCtxClear(RedisRaftCtx *rr); +void appendEndClientSession(RedisRaftCtx *rr, RaftReq *req, unsigned long long id, char *reason); /* raft.c */ void RaftReqFree(RaftReq *req); diff --git a/src/snapshot.c b/src/snapshot.c index 3be9b5196..05b3dbe5e 100644 --- a/src/snapshot.c +++ b/src/snapshot.c @@ -587,7 +587,9 @@ static void clientSessionRDBLoad(RedisModuleIO *rdb) ClientSession *client_session = RedisModule_Alloc(sizeof(ClientSession)); unsigned long long id = RedisModule_LoadUnsigned(rdb); client_session->client_id = id; + client_session->dirty = RedisModule_LoadUnsigned(rdb); client_session->local = false; + client_session->client = RedisModule_CreateModuleClient(rr->ctx); RedisModule_DictSetC(rr->client_session_dict, &id, sizeof(id), client_session); } } @@ -707,6 +709,12 @@ static void clientSessionRDBSave(RedisModuleIO *rdb) ClientSession *client_session; while (RedisModule_DictNextC(iter, NULL, (void **) &client_session) != NULL) { RedisModule_SaveUnsigned(rdb, client_session->client_id); + uint64_t flags = RedisModule_GetClientFlags(client_session->client); + if (flags & REDISMODULE_CLIENT_FLAG_DIRTY_CAS) { + RedisModule_SaveUnsigned(rdb, 1); + } else { + RedisModule_SaveUnsigned(rdb, 0); + } } RedisModule_DictIteratorStop(iter); } diff --git a/tests/integration/test_blocking.py b/tests/integration/test_blocking.py index 09d274d5d..d2d92cbe8 100644 --- a/tests/integration/test_blocking.py +++ b/tests/integration/test_blocking.py @@ -414,3 +414,29 @@ def test_blocking_with_timeout_after_unblock(cluster): val = cluster.node(i).raft_debug_exec("lrange", "x", 0, -1) assert type(val) == list assert len(val) == 0 + + +def test_blocking_with_watch(cluster): + cluster.create(3) + + c1 = cluster.leader_node().client.connection_pool.get_connection('c1') + c1.send_command('watch', 'x') + assert c1.read_response() == b'OK' + c1.send_command('blpop', 'x', 0) + c2 = cluster.leader_node().client.connection_pool.get_connection('c2') + c2.send_command('watch', 'x') + assert c2.read_response() == b'OK' + c2.send_command('blpop', 'x', 0) + + cluster.leader_node().execute("lpush", "x", 1) + cluster.leader_node().execute("lpush", "x", 2) + cluster.leader_node().execute("lpush", "x", 3) + + cluster.wait_for_unanimity() + + assert c1.read_response() == [b'x', b'1'] + assert c2.read_response() == [b'x', b'2'] + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("lrange", "x", 0, -1) + assert val == [b'3'] diff --git a/tests/integration/test_multi.py b/tests/integration/test_multi.py index 07d6cf47a..c23702f9f 100644 --- a/tests/integration/test_multi.py +++ b/tests/integration/test_multi.py @@ -3,7 +3,7 @@ Licensed under your choice of the Redis Source Available License 2.0 (RSALv2) or the Server Side Public License v1 (SSPLv1). """ - +import pytest from pytest import raises from redis.exceptions import ExecAbortError, ResponseError @@ -240,3 +240,277 @@ def test_watch_within_multi(cluster): conn.execute('watch', 'x') assert conn.execute('get', 'key1') == b'QUEUED' assert conn.execute('exec') == [b'1', b'1'] + + +def test_multi_watch_without_modification(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + assert conn.execute('exec') == [b'1', b'1', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'1' + + +def test_multi_watch_with_modification(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + # "dirty" the key that was watched, should kill transaction + assert cluster.execute('set', 'key1', 2) + assert conn.execute('exec') is None + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val is None + + +def test_multi_watch_cleared_after_exec(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + assert conn.execute('exec') == [b'1', b'1', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'1' + + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 2) == b'QUEUED' + # "dirty" the key that was formerly watched, + # but previous exec should have cleared it + assert cluster.execute('set', 'key1', 2) + assert conn.execute('exec') == [b'2', b'2', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'2' + + +def test_multi_watch_cleared_after_discard(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + assert conn.execute('discard') == b'OK' + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val is None + + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 2) == b'QUEUED' + # "dirty" the key that was formerly watched, + # but previous exec should have cleared it + assert cluster.execute('set', 'key1', 2) + assert conn.execute('exec') == [b'2', b'2', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'2' + + +def test_multi_watch_cleared_after_execabort(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + with raises(ResponseError, match=".*unknown command 'notexistcmd'.*"): + conn.execute('notexistcmd') + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + with raises(ResponseError, match="Transaction discarded.*"): + conn.execute('exec') + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val is None + + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 2) == b'QUEUED' + # "dirty" the key that was formerly watched, + # but previous exec should have cleared it + assert cluster.execute('set', 'key1', 2) + assert conn.execute('exec') == [b'2', b'2', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'2' + + +def test_multi_watch_cleared_after_unwatch(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('unwatch') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('get', 'key1') == b'QUEUED' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + assert cluster.execute('set', 'key1', 2) + assert conn.execute('exec') == [b'2', b'2', b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'1' + + +def test_multi_watch_with_restart_clean(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + + cluster.node(2).restart() + cluster.node(2).wait_for_election() + cluster.node(3).restart() + cluster.node(3).wait_for_election() + cluster.wait_for_unanimity() + + assert conn.execute('exec') == [b'OK'] + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val == b'1' + + +def test_multi_watch_with_restart_dirty(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + assert cluster.execute('set', 'key1', 2) + + cluster.node(2).restart() + cluster.node(2).wait_for_election() + cluster.node(3).restart() + cluster.node(3).wait_for_election() + cluster.wait_for_unanimity() + + assert conn.execute('exec') is None + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val is None + + +def test_multi_watch_with_dirty_after_restart(cluster): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('set', 'key2', 1) == b'QUEUED' + + cluster.node(2).restart() + cluster.node(2).wait_for_election() + cluster.node(3).restart() + cluster.node(3).wait_for_election() + cluster.wait_for_unanimity() + + assert cluster.execute('set', 'key1', 2) + + assert conn.execute('exec') is None + + cluster.wait_for_unanimity() + + for i in range(1, 3): + val = cluster.node(i).raft_debug_exec("get", "key2") + assert val is None + + +@pytest.mark.parametrize("with_watch", [False, True]) +def test_multi_with_blocking_commands(cluster, with_watch): + cluster.create(3) + node = cluster.leader_node() + node.execute('set', 'key1', 1) + + conn = RawConnection(cluster.node(1).client) + + if with_watch: + assert conn.execute('watch', 'key1') == b'OK' + assert conn.execute('multi') == b'OK' + assert conn.execute('blpop', 'key2', 0) == b'QUEUED' + assert conn.execute('exec') == [None] diff --git a/tests/redis-suite/skip.txt b/tests/redis-suite/skip.txt index aafcecab8..e62313428 100644 --- a/tests/redis-suite/skip.txt +++ b/tests/redis-suite/skip.txt @@ -1,5 +1,6 @@ --- doesn't work, as command appears as "raft", not "blpop" +-- doesn't work, as command appears as "raft", not the command Blocking command accounted only once in commandstats after timeout +command stats for MULTI -- Streams not supported -- See: https://github.com/RedisLabs/redisraft/issues/59 @@ -22,6 +23,7 @@ Timedout script link is still usable after Lua returns /function kill /script kill /test wrong subcommand +/.*script timeout.* -- RAFT command prefix shows up in SLOWLOG. SLOWLOG - Rewritten commands are logged as their original command @@ -32,6 +34,7 @@ MONITOR can log executed commands MONITOR can log commands issued by the scripting engine MONITOR can log commands issued by functions MONITOR correctly handles multi-exec cases +MONITOR log blocked command only once -- TODO: check what's wrong UNLINK can reclaim memory in background @@ -39,18 +42,11 @@ UNLINK can reclaim memory in background -- ACL test fails because we prepend "raft" string to the command Script ACL check --- WATCH (multi/exec) not supported -/.*MULTI.* -/.*EXEC.* -/.*WATCH.* -SMOVE only notify dstset when the addition is successful -FLUSHALL is able to touch the watched keys -FLUSHDB is able to touch the watched keys -client evicted due to watched key list -FLUSHALL does not touch non affected keys -FLUSHDB does not touch non affected keys -SWAPDB is able to touch the watched keys that exist -SWAPDB is able to touch the watched keys that do not exist +-- MULTI/EXEC is currently read-write in RedisRaft +EXEC with only read commands should not be rejected when OOM + +-- SAVE is in general unsupported in RedisRaft so can ignore it in multi as well +MULTI with SAVE -- After fixing this: https://github.com/RedisLabs/redisraft/issues/367 -- We don't need to skip this test as it doesn't actually configure a replica.