Skip to content

Commit de0f2ab

Browse files
authored
Add a flag to returns the CAS calue from ADD/SET/REPLACE/CAS and multi-set operations (#261)
* Warn about the CAS-plus-replication hazard on ReplicatingClient ReplicatingClient's CAS-touching methods have always had a silent correctness problem when used against more than one replica: each server maintains its own CAS counter, so a CAS value cannot match on more than one replica. any(returns) then reports success as long as one server accepted the write, but the other replicas silently rejected it, leaving them divergent. The get-side methods are an equally potent footgun -- gets(), get(get_cas=True), and get_multi(get_cas=True) return a CAS from whichever replica happened to respond first, even though that value cannot be safely passed back to cas() on a multi-replica client. Add docstring warnings on the class and on each affected method, plus a runtime UserWarning fired by a small _warn_multi_replica_cas helper on the five at-risk surfaces (cas, set_multi with tuple-CAS keys, gets, get(get_cas=True), get_multi(get_cas=True)) when the client has more than one server. The warning is emitted via the warnings module rather than the existing logger, since it's an API-misuse signal (deduped by default, surfaced without logging configuration) rather than an operational event. For backwards compatibility, the behavior itself is left unchanged -- callers in single-replica deployments are unaffected. * Return CAS from single-key mutators via `get_cas` kwarg add(), set(), replace(), and cas() all produce an item with a new CAS value on success, and the memcached binary protocol already returns it in the response header -- the client was simply discarding it. Callers who want to chain a CAS-guarded update after a write had to follow up with a separate gets() round-trip, which is both slower and racy (another writer could slip in between). Add an optional `get_cas=False` kwarg matching the existing convention on get()/get_multi(). When True, these methods now return a tuple of `(success, cas)` instead of a plain bool; `cas` is the new CAS on success, or None on failure. For ReplicatingClient, get_cas=True is rejected with NotImplementedError when the client is configured with more than one server, since each replica has its own CAS counter and any value we returned would be unsafe to feed back to cas(). Single-server ReplicatingClient (and DistributedClient, which routes each key to exactly one server) work normally. * Add set_multi_cas for per-key CAS return from batched writes set_multi's current return shape is a list of failed keys, which can't carry a per-key CAS value. It also uses the quiet setq/addq opcodes, which intentionally suppress successful responses -- so even if the shape allowed it, the wire protocol wouldn't return a CAS per key. Add a separate `set_multi_cas` method that uses the non-quiet set/add opcodes (one response per key) and returns `{str_key: int | None}` for every input key -- int on success, None on failure. The existing `{(key, cas): value}` input syntax from set_multi is preserved; the result dict is keyed by the string key regardless of which form was passed. For ReplicatingClient, set_multi_cas raises NotImplementedError when the client has more than one server, for the same reason as the single-key get_cas=True surfaces: each replica has its own CAS counter, so any per-key value we returned would be unsafe to feed back to cas(). Single-server ReplicatingClient and DistributedClient work normally.
1 parent b40f7fd commit de0f2ab

5 files changed

Lines changed: 571 additions & 59 deletions

File tree

bmemcached/client/distributed.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def delete_multi(self, keys):
3939
servers[server_key].append(key)
4040
return all([server.delete_multi(keys_) for server, keys_ in servers.items()])
4141

42-
def set(self, key, value, time=0, compress_level=-1):
42+
def set(self, key, value, time=0, compress_level=-1, get_cas=False):
4343
"""
4444
Set a value for a key on server.
4545
@@ -53,11 +53,15 @@ def set(self, key, value, time=0, compress_level=-1):
5353
0 = no compression, 1 = fastest, 9 = slowest but best,
5454
-1 = default compression level.
5555
:type compress_level: int
56-
:return: True in case of success and False in case of failure
57-
:rtype: bool
56+
:param get_cas: If true, return (success, cas) where cas is the new
57+
CAS value on success and None on failure.
58+
:type get_cas: bool
59+
:return: True in case of success and False in case of failure, or a
60+
(success, cas) tuple if get_cas=True.
61+
:rtype: bool or tuple
5862
"""
5963
server = self._get_server(key)
60-
return server.set(key, value, time, compress_level)
64+
return server.set(key, value, time, compress_level, get_cas=get_cas)
6165

6266
def set_multi(self, mappings, time=0, compress_level=-1):
6367
"""
@@ -86,7 +90,37 @@ def set_multi(self, mappings, time=0, compress_level=-1):
8690

8791
return list(returns)
8892

89-
def add(self, key, value, time=0, compress_level=-1):
93+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
94+
"""
95+
Set multiple keys with their values on server, returning the new CAS
96+
value for each successfully stored key.
97+
98+
:param mappings: A dict with keys/values. Keys may be (key, cas)
99+
tuples as in set_multi.
100+
:type mappings: dict
101+
:param time: Time in seconds that your key will expire.
102+
:type time: int
103+
:param compress_level: How much to compress.
104+
0 = no compression, 1 = fastest, 9 = slowest but best,
105+
-1 = default compression level.
106+
:type compress_level: int
107+
:return: A dict keyed by the string key of every input mapping. The
108+
value is the new CAS int on success or None on failure.
109+
:rtype: dict
110+
"""
111+
if not mappings:
112+
return {}
113+
result = {}
114+
server_mappings = defaultdict(dict)
115+
for key, value in mappings.items():
116+
str_key = key[0] if isinstance(key, tuple) else key
117+
server_key = self._get_server(str_key)
118+
server_mappings[server_key][key] = value
119+
for server, m in server_mappings.items():
120+
result.update(server.set_multi_cas(m, time, compress_level))
121+
return result
122+
123+
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
90124
"""
91125
Add a key/value to server ony if it does not exist.
92126
@@ -100,13 +134,17 @@ def add(self, key, value, time=0, compress_level=-1):
100134
0 = no compression, 1 = fastest, 9 = slowest but best,
101135
-1 = default compression level.
102136
:type compress_level: int
103-
:return: True if key is added False if key already exists
104-
:rtype: bool
137+
:param get_cas: If true, return (success, cas) where cas is the new
138+
CAS value on success and None on failure.
139+
:type get_cas: bool
140+
:return: True if key is added False if key already exists, or a
141+
(success, cas) tuple if get_cas=True.
142+
:rtype: bool or tuple
105143
"""
106144
server = self._get_server(key)
107-
return server.add(key, value, time, compress_level)
145+
return server.add(key, value, time, compress_level, get_cas=get_cas)
108146

109-
def replace(self, key, value, time=0, compress_level=-1):
147+
def replace(self, key, value, time=0, compress_level=-1, get_cas=False):
110148
"""
111149
Replace a key/value to server ony if it does exist.
112150
@@ -120,11 +158,15 @@ def replace(self, key, value, time=0, compress_level=-1):
120158
0 = no compression, 1 = fastest, 9 = slowest but best,
121159
-1 = default compression level.
122160
:type compress_level: int
123-
:return: True if key is replace False if key does not exists
124-
:rtype: bool
161+
:param get_cas: If true, return (success, cas) where cas is the new
162+
CAS value on success and None on failure.
163+
:type get_cas: bool
164+
:return: True if key is replace False if key does not exists, or a
165+
(success, cas) tuple if get_cas=True.
166+
:rtype: bool or tuple
125167
"""
126168
server = self._get_server(key)
127-
return server.replace(key, value, time, compress_level)
169+
return server.replace(key, value, time, compress_level, get_cas=get_cas)
128170

129171
def get(self, key, default=None, get_cas=False):
130172
"""
@@ -182,7 +224,7 @@ def gets(self, key):
182224
server = self._get_server(key)
183225
return server.get(key)
184226

185-
def cas(self, key, value, cas, time=0, compress_level=-1):
227+
def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False):
186228
"""
187229
Set a value for a key on server if its CAS value matches cas.
188230
@@ -198,11 +240,15 @@ def cas(self, key, value, cas, time=0, compress_level=-1):
198240
0 = no compression, 1 = fastest, 9 = slowest but best,
199241
-1 = default compression level.
200242
:type compress_level: int
201-
:return: True in case of success and False in case of failure
202-
:rtype: bool
243+
:param get_cas: If true, return (success, new_cas) where new_cas is
244+
the item's new CAS after the operation, or None on failure.
245+
:type get_cas: bool
246+
:return: True in case of success and False in case of failure, or a
247+
(success, new_cas) tuple if get_cas=True.
248+
:rtype: bool or tuple
203249
"""
204250
server = self._get_server(key)
205-
return server.cas(key, value, cas, time, compress_level)
251+
return server.cas(key, value, cas, time, compress_level, get_cas=get_cas)
206252

207253
def incr(self, key, value, default=0, time=1000000):
208254
"""

bmemcached/client/mixin.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,22 @@ def gets(self, key):
132132
def get_multi(self, keys, get_cas=False):
133133
raise NotImplementedError()
134134

135-
def set(self, key, value, time=0, compress_level=-1):
135+
def set(self, key, value, time=0, compress_level=-1, get_cas=False):
136136
raise NotImplementedError()
137137

138-
def cas(self, key, value, cas, time=0, compress_level=-1):
138+
def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False):
139139
raise NotImplementedError()
140140

141141
def set_multi(self, mappings, time=0, compress_level=-1):
142142
raise NotImplementedError()
143143

144-
def add(self, key, value, time=0, compress_level=-1):
144+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
145145
raise NotImplementedError()
146146

147-
def replace(self, key, value, time=0, compress_level=-1):
147+
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
148+
raise NotImplementedError()
149+
150+
def replace(self, key, value, time=0, compress_level=-1, get_cas=False):
148151
raise NotImplementedError()
149152

150153
def delete(self, key, cas=0): # type: (six.string_types, int) -> bool

0 commit comments

Comments
 (0)