Skip to content

Commit 8170b28

Browse files
committed
Add set_multi_cas for per-key CAS return from batched writes
set_multi's current return shape is a list of failed keys, which can't carry a per-key CAS value. It also uses the quiet setq/addq opcodes, which intentionally suppress successful responses -- so even if the shape allowed it, the wire protocol wouldn't return a CAS per key. Add a separate `set_multi_cas` method that uses the non-quiet set/add opcodes (one response per key) and returns `{str_key: int | None}` for every input key -- int on success, None on failure. The existing `{(key, cas): value}` input syntax from set_multi is preserved; the result dict is keyed by the string key regardless of which form was passed. For ReplicatingClient, set_multi_cas raises NotImplementedError when the client has more than one server, for the same reason as the single-key get_cas=True surfaces: each replica has its own CAS counter, so any per-key value we returned would be unsafe to feed back to cas(). Single-server ReplicatingClient and DistributedClient work normally.
1 parent 50a471a commit 8170b28

5 files changed

Lines changed: 179 additions & 0 deletions

File tree

bmemcached/client/distributed.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,36 @@ def set_multi(self, mappings, time=0, compress_level=-1):
9090

9191
return list(returns)
9292

93+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
94+
"""
95+
Set multiple keys with their values on server, returning the new CAS
96+
value for each successfully stored key.
97+
98+
:param mappings: A dict with keys/values. Keys may be (key, cas)
99+
tuples as in set_multi.
100+
:type mappings: dict
101+
:param time: Time in seconds that your key will expire.
102+
:type time: int
103+
:param compress_level: How much to compress.
104+
0 = no compression, 1 = fastest, 9 = slowest but best,
105+
-1 = default compression level.
106+
:type compress_level: int
107+
:return: A dict keyed by the string key of every input mapping. The
108+
value is the new CAS int on success or None on failure.
109+
:rtype: dict
110+
"""
111+
if not mappings:
112+
return {}
113+
result = {}
114+
server_mappings = defaultdict(dict)
115+
for key, value in mappings.items():
116+
str_key = key[0] if isinstance(key, tuple) else key
117+
server_key = self._get_server(str_key)
118+
server_mappings[server_key][key] = value
119+
for server, m in server_mappings.items():
120+
result.update(server.set_multi_cas(m, time, compress_level))
121+
return result
122+
93123
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
94124
"""
95125
Add a key/value to server ony if it does not exist.

bmemcached/client/mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False):
141141
def set_multi(self, mappings, time=0, compress_level=-1):
142142
raise NotImplementedError()
143143

144+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
145+
raise NotImplementedError()
146+
144147
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
145148
raise NotImplementedError()
146149

bmemcached/client/replicating.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,38 @@ def set_multi(self, mappings, time=0, compress_level=-1):
279279

280280
return list(returns)
281281

282+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
283+
"""
284+
Set multiple keys with their values on the server, returning the new
285+
CAS value for each successfully stored key.
286+
287+
Only supported when the client is configured with a single server;
288+
see the class docstring for why CAS and multi-server replication
289+
don't mix.
290+
291+
:param mappings: A dict with keys/values. Keys may be (key, cas)
292+
tuples as in set_multi.
293+
:type mappings: dict
294+
:param time: Time in seconds that your key will expire.
295+
:type time: int
296+
:param compress_level: How much to compress.
297+
0 = no compression, 1 = fastest, 9 = slowest but best,
298+
-1 = default compression level.
299+
:type compress_level: int
300+
:return: A dict keyed by the string key of every input mapping. The
301+
value is the new CAS int on success or None on failure.
302+
:rtype: dict
303+
:raises NotImplementedError: if more than one server is configured.
304+
"""
305+
if len(self._servers) > 1:
306+
raise NotImplementedError(
307+
"set_multi_cas is not supported on ReplicatingClient with "
308+
"more than one server."
309+
)
310+
if not mappings:
311+
return {}
312+
return self._servers[0].set_multi_cas(mappings, time, compress_level=compress_level)
313+
282314
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
283315
"""
284316
Add a key/value to server ony if it does not exist.

bmemcached/protocol.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,72 @@ def set_multi(self, mappings, time=100, compress_level=-1):
789789

790790
return failed
791791

792+
def set_multi_cas(self, mappings, time=100, compress_level=-1):
793+
"""
794+
Set multiple keys with their values on server and return the new CAS
795+
value for each successfully stored key.
796+
797+
If a key is a (key, cas) tuple, insert as if cas(key, value, cas) had
798+
been called. A cas of 0 means add-if-not-exists.
799+
800+
Unlike set_multi, this uses the non-quiet set/add opcodes so that the
801+
server responds to every request; this costs one response per key but
802+
is what allows per-key CAS values to be returned.
803+
804+
:param mappings: A dict with keys/values
805+
:type mappings: dict
806+
:param time: Time in seconds that your key will expire.
807+
:type time: int
808+
:param compress_level: How much to compress.
809+
0 = no compression, 1 = fastest, 9 = slowest but best,
810+
-1 = default compression level.
811+
:type compress_level: int
812+
:return: A dict keyed by the string key of every input mapping. The
813+
value is the new CAS int on success or None on failure.
814+
:rtype: dict
815+
"""
816+
mappings = list(mappings.items())
817+
msg = bytearray()
818+
result = {}
819+
820+
for opaque, (key, value) in enumerate(mappings):
821+
if isinstance(key, tuple):
822+
str_key, cas = key
823+
else:
824+
str_key, cas = key, None
825+
result[str_key] = None
826+
827+
if cas == 0:
828+
command = 'add'
829+
else:
830+
command = 'set'
831+
832+
keybytes = str_to_bytes(str_key)
833+
flags, value = self.serialize(value, compress_level=compress_level)
834+
msg += struct.pack(self.HEADER_STRUCT +
835+
self.COMMANDS[command]['struct'] % (len(keybytes), len(value)),
836+
self.MAGIC['request'],
837+
self.COMMANDS[command]['command'],
838+
len(keybytes),
839+
8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0,
840+
flags, time, keybytes, value)
841+
842+
self._send(msg)
843+
844+
# Non-quiet set/add return exactly one response per request, so we can
845+
# read a fixed count rather than relying on a trailing noop sentinel.
846+
for _ in range(len(mappings)):
847+
(magic, opcode, keylen, extlen, datatype, status, bodylen, opaque,
848+
cas, extra_content) = self._get_response()
849+
if status == self.STATUS['server_disconnected']:
850+
return result
851+
if status == self.STATUS['success']:
852+
key, value = mappings[opaque]
853+
str_key = key[0] if isinstance(key, tuple) else key
854+
result[str_key] = cas
855+
856+
return result
857+
792858
def _incr_decr(self, command, key, value, default, time):
793859
"""
794860
Function which increments and decrements.

test/test_simple_functions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def tearDown(self):
2828
def reset(self):
2929
self.client.delete('test_key')
3030
self.client.delete('test_key2')
31+
self.client.delete('fresh_key')
3132

3233
def testSet(self):
3334
self.assertTrue(self.client.set('test_key', 'test'))
@@ -121,6 +122,51 @@ def testMultiCas(self):
121122
}), [])
122123
self.assertEqual(self.client.get('test_key'), 'value4')
123124

125+
def testSetMultiCas(self):
126+
# All-success plain keys: every input gets a non-None CAS, and each
127+
# returned CAS matches what gets() reports afterwards.
128+
result = self.client.set_multi_cas({
129+
'test_key': 'value1',
130+
'test_key2': 'value2',
131+
})
132+
self.assertEqual(set(result.keys()), {'test_key', 'test_key2'})
133+
self.assertTrue(result['test_key'] is not None)
134+
self.assertTrue(result['test_key2'] is not None)
135+
_, cas1 = self.client.gets('test_key')
136+
_, cas2 = self.client.gets('test_key2')
137+
self.assertEqual(result['test_key'], cas1)
138+
self.assertEqual(result['test_key2'], cas2)
139+
140+
# CAS failure: add-if-not-exists when the key already exists returns
141+
# None for that key; unrelated keys still succeed.
142+
result = self.client.set_multi_cas({
143+
('test_key', 0): 'shouldnt_store',
144+
'fresh_key': 'fresh',
145+
})
146+
self.assertTrue(result['test_key'] is None)
147+
self.assertTrue(result['fresh_key'] is not None)
148+
self.assertEqual(self.client.get('test_key'), 'value1')
149+
self.client.delete('fresh_key')
150+
151+
# Stale-CAS failure: capture cas, mutate out of band, then set_multi_cas
152+
# with the stale cas must fail and leave the out-of-band value intact.
153+
_, stale_cas = self.client.gets('test_key')
154+
self.client.set('test_key', 'other')
155+
result = self.client.set_multi_cas({
156+
('test_key', stale_cas): 'should_fail',
157+
})
158+
self.assertTrue(result['test_key'] is None)
159+
self.assertEqual(self.client.get('test_key'), 'other')
160+
161+
# Returned CAS is usable directly in cas() without a gets() round-trip.
162+
self.client.delete('test_key')
163+
result = self.client.set_multi_cas({'test_key': 'v'})
164+
self.assertTrue(self.client.cas('test_key', 'v2', result['test_key']))
165+
self.assertEqual(self.client.get('test_key'), 'v2')
166+
167+
def testSetMultiCasEmpty(self):
168+
self.assertEqual(self.client.set_multi_cas({}), {})
169+
124170
def testGetMultiCas(self):
125171
self.client.set('test_key', 'value1')
126172
self.client.set('test_key2', 'value2')
@@ -374,6 +420,8 @@ def testGetCasMultiReplicaRaises(self):
374420
client.replace('test_key', 'v', get_cas=True)
375421
with self.assertRaises(NotImplementedError):
376422
client.cas('test_key', 'v', None, get_cas=True)
423+
with self.assertRaises(NotImplementedError):
424+
client.set_multi_cas({'test_key': 'v'})
377425

378426
# get_cas=False (default) still works fine on multi-replica.
379427
self.assertTrue(client.set('test_key', 'v'))

0 commit comments

Comments
 (0)