Skip to content

Commit 2ec16f4

Browse files
committed
Add set_multi_cas for per-key CAS return from batched writes
set_multi's current return shape is a list of failed keys, which can't carry a per-key CAS value. It also uses the quiet setq/addq opcodes, which intentionally suppress successful responses -- so even if the shape allowed it, the wire protocol wouldn't return a CAS per key. Add a separate `set_multi_cas` method that uses the non-quiet set/add opcodes (one response per key) and returns `{str_key: int | None}` for every input key -- int on success, None on failure. The existing `{(key, cas): value}` input syntax from set_multi is preserved; the result dict is keyed by the string key regardless of which form was passed. For ReplicatingClient, the returned CAS per key is the first non-None CAS from any replica, matching the single-key helpers.
1 parent 52c948b commit 2ec16f4

5 files changed

Lines changed: 179 additions & 0 deletions

File tree

bmemcached/client/distributed.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,36 @@ def set_multi(self, mappings, time=0, compress_level=-1):
9090

9191
return list(returns)
9292

93+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
94+
"""
95+
Set multiple keys with their values on server, returning the new CAS
96+
value for each successfully stored key.
97+
98+
:param mappings: A dict with keys/values. Keys may be (key, cas)
99+
tuples as in set_multi.
100+
:type mappings: dict
101+
:param time: Time in seconds that your key will expire.
102+
:type time: int
103+
:param compress_level: How much to compress.
104+
0 = no compression, 1 = fastest, 9 = slowest but best,
105+
-1 = default compression level.
106+
:type compress_level: int
107+
:return: A dict keyed by the string key of every input mapping. The
108+
value is the new CAS int on success or None on failure.
109+
:rtype: dict
110+
"""
111+
if not mappings:
112+
return {}
113+
result = {}
114+
server_mappings = defaultdict(dict)
115+
for key, value in mappings.items():
116+
str_key = key[0] if isinstance(key, tuple) else key
117+
server_key = self._get_server(str_key)
118+
server_mappings[server_key][key] = value
119+
for server, m in server_mappings.items():
120+
result.update(server.set_multi_cas(m, time, compress_level))
121+
return result
122+
93123
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
94124
"""
95125
Add a key/value to server ony if it does not exist.

bmemcached/client/mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def cas(self, key, value, cas, time=0, compress_level=-1, get_cas=False):
141141
def set_multi(self, mappings, time=0, compress_level=-1):
142142
raise NotImplementedError()
143143

144+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
145+
raise NotImplementedError()
146+
144147
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
145148
raise NotImplementedError()
146149

bmemcached/client/replicating.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,38 @@ def set_multi(self, mappings, time=0, compress_level=-1):
245245

246246
return list(returns)
247247

248+
def set_multi_cas(self, mappings, time=0, compress_level=-1):
249+
"""
250+
Set multiple keys with their values on the server, returning the new
251+
CAS value for each successfully stored key.
252+
253+
Only supported when the client is configured with a single server;
254+
see the class docstring for why CAS and multi-server replication
255+
don't mix.
256+
257+
:param mappings: A dict with keys/values. Keys may be (key, cas)
258+
tuples as in set_multi.
259+
:type mappings: dict
260+
:param time: Time in seconds that your key will expire.
261+
:type time: int
262+
:param compress_level: How much to compress.
263+
0 = no compression, 1 = fastest, 9 = slowest but best,
264+
-1 = default compression level.
265+
:type compress_level: int
266+
:return: A dict keyed by the string key of every input mapping. The
267+
value is the new CAS int on success or None on failure.
268+
:rtype: dict
269+
:raises NotImplementedError: if more than one server is configured.
270+
"""
271+
if len(self._servers) > 1:
272+
raise NotImplementedError(
273+
"set_multi_cas is not supported on ReplicatingClient with "
274+
"more than one server."
275+
)
276+
if not mappings:
277+
return {}
278+
return self._servers[0].set_multi_cas(mappings, time, compress_level=compress_level)
279+
248280
def add(self, key, value, time=0, compress_level=-1, get_cas=False):
249281
"""
250282
Add a key/value to server ony if it does not exist.

bmemcached/protocol.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,72 @@ def set_multi(self, mappings, time=100, compress_level=-1):
789789

790790
return failed
791791

792+
def set_multi_cas(self, mappings, time=100, compress_level=-1):
793+
"""
794+
Set multiple keys with their values on server and return the new CAS
795+
value for each successfully stored key.
796+
797+
If a key is a (key, cas) tuple, insert as if cas(key, value, cas) had
798+
been called. A cas of 0 means add-if-not-exists.
799+
800+
Unlike set_multi, this uses the non-quiet set/add opcodes so that the
801+
server responds to every request; this costs one response per key but
802+
is what allows per-key CAS values to be returned.
803+
804+
:param mappings: A dict with keys/values
805+
:type mappings: dict
806+
:param time: Time in seconds that your key will expire.
807+
:type time: int
808+
:param compress_level: How much to compress.
809+
0 = no compression, 1 = fastest, 9 = slowest but best,
810+
-1 = default compression level.
811+
:type compress_level: int
812+
:return: A dict keyed by the string key of every input mapping. The
813+
value is the new CAS int on success or None on failure.
814+
:rtype: dict
815+
"""
816+
mappings = list(mappings.items())
817+
msg = bytearray()
818+
result = {}
819+
820+
for opaque, (key, value) in enumerate(mappings):
821+
if isinstance(key, tuple):
822+
str_key, cas = key
823+
else:
824+
str_key, cas = key, None
825+
result[str_key] = None
826+
827+
if cas == 0:
828+
command = 'add'
829+
else:
830+
command = 'set'
831+
832+
keybytes = str_to_bytes(str_key)
833+
flags, value = self.serialize(value, compress_level=compress_level)
834+
msg += struct.pack(self.HEADER_STRUCT +
835+
self.COMMANDS[command]['struct'] % (len(keybytes), len(value)),
836+
self.MAGIC['request'],
837+
self.COMMANDS[command]['command'],
838+
len(keybytes),
839+
8, 0, 0, len(keybytes) + len(value) + 8, opaque, cas or 0,
840+
flags, time, keybytes, value)
841+
842+
self._send(msg)
843+
844+
# Non-quiet set/add return exactly one response per request, so we can
845+
# read a fixed count rather than relying on a trailing noop sentinel.
846+
for _ in range(len(mappings)):
847+
(magic, opcode, keylen, extlen, datatype, status, bodylen, opaque,
848+
cas, extra_content) = self._get_response()
849+
if status == self.STATUS['server_disconnected']:
850+
return result
851+
if status == self.STATUS['success']:
852+
key, value = mappings[opaque]
853+
str_key = key[0] if isinstance(key, tuple) else key
854+
result[str_key] = cas
855+
856+
return result
857+
792858
def _incr_decr(self, command, key, value, default, time):
793859
"""
794860
Function which increments and decrements.

test/test_simple_functions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def tearDown(self):
2727
def reset(self):
2828
self.client.delete('test_key')
2929
self.client.delete('test_key2')
30+
self.client.delete('fresh_key')
3031

3132
def testSet(self):
3233
self.assertTrue(self.client.set('test_key', 'test'))
@@ -120,6 +121,51 @@ def testMultiCas(self):
120121
}), [])
121122
self.assertEqual(self.client.get('test_key'), 'value4')
122123

124+
def testSetMultiCas(self):
125+
# All-success plain keys: every input gets a non-None CAS, and each
126+
# returned CAS matches what gets() reports afterwards.
127+
result = self.client.set_multi_cas({
128+
'test_key': 'value1',
129+
'test_key2': 'value2',
130+
})
131+
self.assertEqual(set(result.keys()), {'test_key', 'test_key2'})
132+
self.assertTrue(result['test_key'] is not None)
133+
self.assertTrue(result['test_key2'] is not None)
134+
_, cas1 = self.client.gets('test_key')
135+
_, cas2 = self.client.gets('test_key2')
136+
self.assertEqual(result['test_key'], cas1)
137+
self.assertEqual(result['test_key2'], cas2)
138+
139+
# CAS failure: add-if-not-exists when the key already exists returns
140+
# None for that key; unrelated keys still succeed.
141+
result = self.client.set_multi_cas({
142+
('test_key', 0): 'shouldnt_store',
143+
'fresh_key': 'fresh',
144+
})
145+
self.assertTrue(result['test_key'] is None)
146+
self.assertTrue(result['fresh_key'] is not None)
147+
self.assertEqual(self.client.get('test_key'), 'value1')
148+
self.client.delete('fresh_key')
149+
150+
# Stale-CAS failure: capture cas, mutate out of band, then set_multi_cas
151+
# with the stale cas must fail and leave the out-of-band value intact.
152+
_, stale_cas = self.client.gets('test_key')
153+
self.client.set('test_key', 'other')
154+
result = self.client.set_multi_cas({
155+
('test_key', stale_cas): 'should_fail',
156+
})
157+
self.assertTrue(result['test_key'] is None)
158+
self.assertEqual(self.client.get('test_key'), 'other')
159+
160+
# Returned CAS is usable directly in cas() without a gets() round-trip.
161+
self.client.delete('test_key')
162+
result = self.client.set_multi_cas({'test_key': 'v'})
163+
self.assertTrue(self.client.cas('test_key', 'v2', result['test_key']))
164+
self.assertEqual(self.client.get('test_key'), 'v2')
165+
166+
def testSetMultiCasEmpty(self):
167+
self.assertEqual(self.client.set_multi_cas({}), {})
168+
123169
def testGetMultiCas(self):
124170
self.client.set('test_key', 'value1')
125171
self.client.set('test_key2', 'value2')
@@ -331,6 +377,8 @@ def testGetCasMultiReplicaRaises(self):
331377
client.replace('test_key', 'v', get_cas=True)
332378
with self.assertRaises(NotImplementedError):
333379
client.cas('test_key', 'v', None, get_cas=True)
380+
with self.assertRaises(NotImplementedError):
381+
client.set_multi_cas({'test_key': 'v'})
334382

335383
# get_cas=False (default) still works fine on multi-replica.
336384
self.assertTrue(client.set('test_key', 'v'))

0 commit comments

Comments
 (0)