Skip to content

Commit c22bf58

Browse files
authored
Fix VarInt/VarLong encoding; move tests to test/protocol/ (#2706)
1 parent 9016c02 commit c22bf58

File tree

7 files changed

+282
-177
lines changed

7 files changed

+282
-177
lines changed

kafka/protocol/types.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,17 @@ def repr(self, list_of_items):
226226

227227

228228
class UnsignedVarInt32(AbstractType):
229+
@classmethod
230+
def decode(cls, data):
231+
value = VarInt32.decode(data)
232+
return (value << 1) ^ (value >> 31)
233+
234+
@classmethod
235+
def encode(cls, value):
236+
return VarInt32.encode((value >> 1) ^ -(value & 1))
237+
238+
239+
class VarInt32(AbstractType):
229240
@classmethod
230241
def decode(cls, data):
231242
value, i = 0, 0
@@ -238,10 +249,12 @@ def decode(cls, data):
238249
if i > 28:
239250
raise ValueError('Invalid value {}'.format(value))
240251
value |= b << i
241-
return value
252+
return (value >> 1) ^ -(value & 1)
242253

243254
@classmethod
244255
def encode(cls, value):
256+
# bring it in line with the java binary repr
257+
value = (value << 1) ^ (value >> 31)
245258
value &= 0xffffffff
246259
ret = b''
247260
while (value & 0xffffff80) != 0:
@@ -252,25 +265,12 @@ def encode(cls, value):
252265
return ret
253266

254267

255-
class VarInt32(AbstractType):
256-
@classmethod
257-
def decode(cls, data):
258-
value = UnsignedVarInt32.decode(data)
259-
return (value >> 1) ^ -(value & 1)
260-
261-
@classmethod
262-
def encode(cls, value):
263-
# bring it in line with the java binary repr
264-
value &= 0xffffffff
265-
return UnsignedVarInt32.encode((value << 1) ^ (value >> 31))
266-
267-
268268
class VarInt64(AbstractType):
269269
@classmethod
270270
def decode(cls, data):
271271
value, i = 0, 0
272272
while True:
273-
b = data.read(1)
273+
b, = struct.unpack('B', data.read(1))
274274
if not (b & 0x80):
275275
break
276276
value |= (b & 0x7f) << i
@@ -283,14 +283,14 @@ def decode(cls, data):
283283
@classmethod
284284
def encode(cls, value):
285285
# bring it in line with the java binary repr
286+
value = (value << 1) ^ (value >> 63)
286287
value &= 0xffffffffffffffff
287-
v = (value << 1) ^ (value >> 63)
288288
ret = b''
289-
while (v & 0xffffffffffffff80) != 0:
289+
while (value & 0xffffffffffffff80) != 0:
290290
b = (value & 0x7f) | 0x80
291291
ret += struct.pack('B', b)
292-
v >>= 7
293-
ret += struct.pack('B', v)
292+
value >>= 7
293+
ret += struct.pack('B', value)
294294
return ret
295295

296296

test/protocol/test_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import struct
2+
3+
import pytest
4+
5+
from kafka.protocol.api import RequestHeader
6+
from kafka.protocol.fetch import FetchRequest
7+
from kafka.protocol.find_coordinator import FindCoordinatorRequest
8+
from kafka.protocol.metadata import MetadataRequest
9+
10+
11+
def test_encode_message_header():
12+
expect = b''.join([
13+
struct.pack('>h', 10), # API Key
14+
struct.pack('>h', 0), # API Version
15+
struct.pack('>i', 4), # Correlation Id
16+
struct.pack('>h', len('client3')), # Length of clientId
17+
b'client3', # ClientId
18+
])
19+
20+
req = FindCoordinatorRequest[0]('foo')
21+
header = RequestHeader(req, correlation_id=4, client_id='client3')
22+
assert header.encode() == expect
23+
24+
25+
def test_struct_unrecognized_kwargs():
26+
try:
27+
_mr = MetadataRequest[0](topicz='foo')
28+
assert False, 'Structs should not allow unrecognized kwargs'
29+
except ValueError:
30+
pass
31+
32+
33+
def test_struct_missing_kwargs():
34+
fr = FetchRequest[0](max_wait_time=100)
35+
assert fr.min_bytes is None

test/protocol/test_bit_field.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import io
2+
3+
import pytest
4+
5+
from kafka.protocol.types import BitField
6+
7+
8+
@pytest.mark.parametrize(('test_set',), [
9+
(set([0, 1, 5, 10, 31]),),
10+
(set(range(32)),),
11+
])
12+
def test_bit_field(test_set):
13+
assert BitField.decode(io.BytesIO(BitField.encode(test_set))) == test_set

test/protocol/test_compact.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import io
2+
import struct
3+
4+
import pytest
5+
6+
from kafka.protocol.types import CompactString, CompactArray, CompactBytes
7+
8+
9+
def test_compact_data_structs():
10+
cs = CompactString()
11+
encoded = cs.encode(None)
12+
assert encoded == struct.pack('B', 0)
13+
decoded = cs.decode(io.BytesIO(encoded))
14+
assert decoded is None
15+
assert b'\x01' == cs.encode('')
16+
assert '' == cs.decode(io.BytesIO(b'\x01'))
17+
encoded = cs.encode("foobarbaz")
18+
assert cs.decode(io.BytesIO(encoded)) == "foobarbaz"
19+
20+
arr = CompactArray(CompactString())
21+
assert arr.encode(None) == b'\x00'
22+
assert arr.decode(io.BytesIO(b'\x00')) is None
23+
enc = arr.encode([])
24+
assert enc == b'\x01'
25+
assert [] == arr.decode(io.BytesIO(enc))
26+
encoded = arr.encode(["foo", "bar", "baz", "quux"])
27+
assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"]
28+
29+
enc = CompactBytes.encode(None)
30+
assert enc == b'\x00'
31+
assert CompactBytes.decode(io.BytesIO(b'\x00')) is None
32+
enc = CompactBytes.encode(b'')
33+
assert enc == b'\x01'
34+
assert CompactBytes.decode(io.BytesIO(b'\x01')) == b''
35+
enc = CompactBytes.encode(b'foo')
36+
assert CompactBytes.decode(io.BytesIO(enc)) == b'foo'
37+
38+

test/protocol/test_fetch.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#pylint: skip-file
2+
import io
3+
import struct
4+
5+
import pytest
6+
7+
from kafka.protocol.fetch import FetchResponse
8+
from kafka.protocol.types import Int16, Int32, Int64, String
9+
10+
11+
def test_decode_fetch_response_partial():
12+
encoded = b''.join([
13+
Int32.encode(1), # Num Topics (Array)
14+
String('utf-8').encode('foobar'),
15+
Int32.encode(2), # Num Partitions (Array)
16+
Int32.encode(0), # Partition id
17+
Int16.encode(0), # Error Code
18+
Int64.encode(1234), # Highwater offset
19+
Int32.encode(52), # MessageSet size
20+
Int64.encode(0), # Msg Offset
21+
Int32.encode(18), # Msg Size
22+
struct.pack('>i', 1474775406), # CRC
23+
struct.pack('>bb', 0, 0), # Magic, flags
24+
struct.pack('>i', 2), # Length of key
25+
b'k1', # Key
26+
struct.pack('>i', 2), # Length of value
27+
b'v1', # Value
28+
29+
Int64.encode(1), # Msg Offset
30+
struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size)
31+
struct.pack('>i', -16383415), # CRC
32+
struct.pack('>bb', 0, 0), # Magic, flags
33+
struct.pack('>i', 2), # Length of key
34+
b'k2', # Key
35+
struct.pack('>i', 8), # Length of value
36+
b'ar', # Value (truncated)
37+
Int32.encode(1),
38+
Int16.encode(0),
39+
Int64.encode(2345),
40+
Int32.encode(52), # MessageSet size
41+
Int64.encode(0), # Msg Offset
42+
Int32.encode(18), # Msg Size
43+
struct.pack('>i', 1474775406), # CRC
44+
struct.pack('>bb', 0, 0), # Magic, flags
45+
struct.pack('>i', 2), # Length of key
46+
b'k1', # Key
47+
struct.pack('>i', 2), # Length of value
48+
b'v1', # Value
49+
50+
Int64.encode(1), # Msg Offset
51+
struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size)
52+
struct.pack('>i', -16383415), # CRC
53+
struct.pack('>bb', 0, 0), # Magic, flags
54+
struct.pack('>i', 2), # Length of key
55+
b'k2', # Key
56+
struct.pack('>i', 8), # Length of value
57+
b'ar', # Value (truncated)
58+
])
59+
resp = FetchResponse[0].decode(io.BytesIO(encoded))
60+
assert len(resp.topics) == 1
61+
topic, partitions = resp.topics[0]
62+
assert topic == 'foobar'
63+
assert len(partitions) == 2
64+
65+
#m1 = MessageSet.decode(
66+
# partitions[0][3], bytes_to_read=len(partitions[0][3]))
67+
#assert len(m1) == 2
68+
#assert m1[1] == (None, None, PartialMessage())

test/protocol/test_varint.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import io
2+
import struct
3+
4+
import pytest
5+
6+
from kafka.protocol.types import UnsignedVarInt32, VarInt32, VarInt64
7+
8+
9+
@pytest.mark.parametrize(('value','expected_encoded'), [
10+
(0, [0x00]),
11+
(-1, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
12+
(1, [0x01]),
13+
(63, [0x3F]),
14+
(-64, [0xC0, 0xFF, 0xFF, 0xFF, 0x0F]),
15+
(64, [0x40]),
16+
(8191, [0xFF, 0x3F]),
17+
(-8192, [0x80, 0xC0, 0xFF, 0xFF, 0x0F]),
18+
(8192, [0x80, 0x40]),
19+
(-8193, [0xFF, 0xBF, 0xFF, 0xFF, 0x0F]),
20+
(1048575, [0xFF, 0xFF, 0x3F]),
21+
(1048576, [0x80, 0x80, 0x40]),
22+
(2147483647, [0xFF, 0xFF, 0xFF, 0xFF, 0x07]),
23+
(-2147483648, [0x80, 0x80, 0x80, 0x80, 0x08]),
24+
])
25+
def test_unsigned_varint_serde(value, expected_encoded):
26+
value &= 0xffffffff
27+
encoded = UnsignedVarInt32.encode(value)
28+
assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded)
29+
assert value == UnsignedVarInt32.decode(io.BytesIO(encoded))
30+
31+
32+
@pytest.mark.parametrize(('value','expected_encoded'), [
33+
(0, [0x00]),
34+
(-1, [0x01]),
35+
(1, [0x02]),
36+
(63, [0x7E]),
37+
(-64, [0x7F]),
38+
(64, [0x80, 0x01]),
39+
(-65, [0x81, 0x01]),
40+
(8191, [0xFE, 0x7F]),
41+
(-8192, [0xFF, 0x7F]),
42+
(8192, [0x80, 0x80, 0x01]),
43+
(-8193, [0x81, 0x80, 0x01]),
44+
(1048575, [0xFE, 0xFF, 0x7F]),
45+
(-1048576, [0xFF, 0xFF, 0x7F]),
46+
(1048576, [0x80, 0x80, 0x80, 0x01]),
47+
(-1048577, [0x81, 0x80, 0x80, 0x01]),
48+
(134217727, [0xFE, 0xFF, 0xFF, 0x7F]),
49+
(-134217728, [0xFF, 0xFF, 0xFF, 0x7F]),
50+
(134217728, [0x80, 0x80, 0x80, 0x80, 0x01]),
51+
(-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]),
52+
(2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]),
53+
(-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
54+
])
55+
def test_signed_varint_serde(value, expected_encoded):
56+
encoded = VarInt32.encode(value)
57+
assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded)
58+
assert value == VarInt32.decode(io.BytesIO(encoded))
59+
60+
61+
@pytest.mark.parametrize(('value','expected_encoded'), [
62+
(0, [0x00]),
63+
(-1, [0x01]),
64+
(1, [0x02]),
65+
(63, [0x7E]),
66+
(-64, [0x7F]),
67+
(64, [0x80, 0x01]),
68+
(-65, [0x81, 0x01]),
69+
(8191, [0xFE, 0x7F]),
70+
(-8192, [0xFF, 0x7F]),
71+
(8192, [0x80, 0x80, 0x01]),
72+
(-8193, [0x81, 0x80, 0x01]),
73+
(1048575, [0xFE, 0xFF, 0x7F]),
74+
(-1048576, [0xFF, 0xFF, 0x7F]),
75+
(1048576, [0x80, 0x80, 0x80, 0x01]),
76+
(-1048577, [0x81, 0x80, 0x80, 0x01]),
77+
(134217727, [0xFE, 0xFF, 0xFF, 0x7F]),
78+
(-134217728, [0xFF, 0xFF, 0xFF, 0x7F]),
79+
(134217728, [0x80, 0x80, 0x80, 0x80, 0x01]),
80+
(-134217729, [0x81, 0x80, 0x80, 0x80, 0x01]),
81+
(2147483647, [0xFE, 0xFF, 0xFF, 0xFF, 0x0F]),
82+
(-2147483648, [0xFF, 0xFF, 0xFF, 0xFF, 0x0F]),
83+
(17179869183, [0xFE, 0xFF, 0xFF, 0xFF, 0x7F]),
84+
(-17179869184, [0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
85+
(17179869184, [0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
86+
(-17179869185, [0x81, 0x80, 0x80, 0x80, 0x80, 0x01]),
87+
(2199023255551, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
88+
(-2199023255552, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
89+
(2199023255552, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
90+
(-2199023255553, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
91+
(281474976710655, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
92+
(-281474976710656, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
93+
(281474976710656, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
94+
(-281474976710657, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 1]),
95+
(36028797018963967, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
96+
(-36028797018963968, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
97+
(36028797018963968, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
98+
(-36028797018963969, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
99+
(4611686018427387903, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
100+
(-4611686018427387904, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
101+
(4611686018427387904, [0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
102+
(-4611686018427387905, [0x81, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
103+
(9223372036854775807, [0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]),
104+
(-9223372036854775808, [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01]),
105+
])
106+
def test_signed_varlong_serde(value, expected_encoded):
107+
encoded = VarInt64.encode(value)
108+
assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded)
109+
assert value == VarInt64.decode(io.BytesIO(encoded))

0 commit comments

Comments
 (0)