Skip to content

Commit 2b42373

Browse files
committed
Add basic backend testing
1 parent 92a082c commit 2b42373

File tree

6 files changed

+174
-23
lines changed

6 files changed

+174
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ tags
66
*.log
77
*.conf
88
*.egg-info
9+
.coverage

tasq/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tasq.worker.jobqueue import JobQueue
55
from tasq.worker.executor import ProcessQueueExecutor
66
from tasq.remote.client import Client
7-
from tasq.remote.backend import RedisStoreBackend
7+
from tasq.remote.backend import RedisStoreBackend, RedisBackend
88
from tasq.remote.runner import Runners
99
from tasq.worker.actors import ClientWorker
1010
from tasq.actors.routers import RoundRobinRouter, actor_pool
@@ -27,7 +27,7 @@
2727
}
2828

2929

30-
def queue(url="zmq://localhost:9000", store=None, signkey=None):
30+
def queue(backend="zmq://localhost:9000", store=None, signkey=None):
3131
"""
3232
Create a TasqQueue instance.
3333
The formats accepted for the backends are:
@@ -53,17 +53,20 @@ def queue(url="zmq://localhost:9000", store=None, signkey=None):
5353
:param signkey: A string representing a shared key, sign data with a shared
5454
key
5555
"""
56-
url_parsed = urlparse(url)
57-
scheme = url_parsed.scheme or "zmq"
58-
assert scheme in _backends, f"Unsupported {scheme} as backend"
59-
backend = _backends[scheme].from_url(url, signkey)
60-
client = Client(backend)
56+
if isinstance(backend, str):
57+
url_parsed = urlparse(backend)
58+
scheme = url_parsed.scheme or "zmq"
59+
assert scheme in _backends, f"Unsupported {scheme} as backend"
60+
_backend = _backends[scheme].from_url(backend, signkey)
61+
client = Client(_backend)
62+
elif isinstance(backend, RedisBackend):
63+
client = Client(BackendConnection(backend))
6164
if store:
6265
urlstore = urlparse(store)
6366
assert urlstore.scheme in {
6467
"redis"
6568
}, f"Unknown {urlstore.scheme} as store"
66-
db = int(urlstore.path.split("/")[-1]) if url.query else 0
69+
db = int(urlstore.path.split("/")[-1]) if urlstore.query else 0
6770
store = RedisStoreBackend(urlstore.hostname, urlstore.port, db)
6871
return TasqQueue(client, store)
6972

tasq/remote/backend.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,11 @@ class RedisQueue:
138138

139139
log = get_logger(__name__)
140140

141-
def __init__(self, name, host, port, db, namespace="queue"):
141+
def __init__(self, name, redis_driver, namespace="queue"):
142142
"""The default connection parameters are: host='localhost',
143143
port=6379, db=0
144144
"""
145-
self._db = redis.StrictRedis(host=host, port=port, db=db)
145+
self._db = redis_driver
146146

147147
try:
148148
_ = self._db.dbsize()
@@ -152,6 +152,9 @@ def __init__(self, name, host, port, db, namespace="queue"):
152152
self._queue_name = f"{namespace}:{name}"
153153
self._work_queue_name = f"{namespace}:{name}:work"
154154

155+
def __repr__(self):
156+
return f"{self._db}:({self._queue_name}, {self._work_queue_name})"
157+
155158
def qsize(self):
156159
"""Return the approximate size of the queue."""
157160
return self._db.llen(self._queue_name)
@@ -245,22 +248,16 @@ def list_working_items(self):
245248
def close(self):
246249
self._db.connection_pool.disconnect()
247250

248-
def __init__(self, host, port, db, name, namespace="queue"):
251+
def __init__(self, redis_factory, name, namespace="queue"):
249252

250-
self._rq = self.RedisQueue(name, host, port, db, namespace)
253+
self._rq = self.RedisQueue(name, redis_factory(), namespace)
251254
self._rq_res = self.RedisQueue(
252-
f"{name}:result", host, port, db, namespace
255+
f"{name}:result", redis_factory(), namespace
253256
)
254-
self._host = host
255-
self._port = port
256-
self._db = db
257257
self._namespace = f"{namespace}:{name}"
258258

259259
def __repr__(self):
260-
return (
261-
f"RedisBackend(redis://{self._host}:{self._port}/{self._db}"
262-
f"?name={self._namespace}"
263-
)
260+
return f"RedisBackend(redis://{self._rq}"
264261

265262
def put_job(self, serialized_job):
266263
self._rq.put(serialized_job)

tasq/remote/connection.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77

88
import zmq
99
from urllib.parse import urlparse
10+
11+
try:
12+
import redis
13+
except ImportError:
14+
print("You need to install redis python driver to use redis backend")
15+
1016
from .backend import RedisBackend, RabbitMQBackend
1117
from .sockets import CloudPickleContext, BackendSocket
1218
from ..exception import BackendCommunicationErrorException
@@ -159,11 +165,12 @@ def from_url(cls, url, signkey=None):
159165
conn_args = {
160166
"host": u.hostname or "localhost",
161167
"port": u.port or 6379,
162-
"name": name,
163168
}
164169
if scheme == "redis":
165170
conn_args["db"] = int(extraparams.get("db", 0))
166-
backend = RedisBackend(**conn_args)
171+
backend = RedisBackend(
172+
lambda: redis.StrictRedis(**conn_args), name=name
173+
)
167174
else:
168175
conn_args["role"] = extraparams.get("role", "sender")
169176
backend = RabbitMQBackend(**conn_args)
@@ -174,7 +181,8 @@ def connect_redis_backend(
174181
host, port, db, name, namespace="queue", signkey=None
175182
):
176183
return BackendConnection(
177-
RedisBackend(host, port, db, name, namespace), signkey=signkey,
184+
RedisBackend(lambda: redis.StrictRedis(host, port, db), name, namespace),
185+
signkey=signkey,
178186
)
179187

180188

tasq/remote/runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def start(self):
4444
for incoming tasks and run the asyncio loop forever
4545
"""
4646
self._run = True
47+
self._log.debug("Listening on %s", self._backend)
4748
self.run()
4849

4950
def run(self):

tests/backend_test.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import time
2+
import asyncio
3+
import unittest
4+
import collections
5+
from unittest.mock import patch
6+
from tasq.remote.backend import ZMQBackend, RedisBackend, RabbitMQBackend
7+
8+
9+
class FakeSocket:
10+
def __init__(self):
11+
self.bind_url = u""
12+
self.data_sent = None
13+
14+
async def send_data(self, data, flags, signkey):
15+
self.data_sent = (data, flags, signkey)
16+
17+
async def recv_data(self, unpickle, flags, signkey):
18+
return self.data_sent
19+
20+
def bind(self, url):
21+
self.bind_url = url
22+
23+
def close(self):
24+
pass
25+
26+
27+
class FakeRedisPool:
28+
def disconnect(self):
29+
pass
30+
31+
32+
class FakeRedisClient:
33+
def __init__(self):
34+
self.queues = collections.defaultdict(list)
35+
self.connection_pool = FakeRedisPool()
36+
37+
def dbsize(self):
38+
return 0
39+
40+
def llen(self, queue_name):
41+
return len(self.queues[queue_name])
42+
43+
def lpush(self, queue_name, item):
44+
self.queues[queue_name].append(item)
45+
46+
def brpop(self, queue_name, timeout=0):
47+
if timeout and not self.queues[queue_name]:
48+
time.sleep(timeout)
49+
return None
50+
return (
51+
None,
52+
self.queues[queue_name],
53+
)
54+
55+
def rpop(self, queue_name):
56+
return self.brpop(queue_name)
57+
58+
def brpoplpush(self, queue_name, queue_name_alt, timeout=0):
59+
if timeout and not self.queues[queue_name]:
60+
time.sleep(timeout)
61+
return None
62+
item = self.queues[queue_name]
63+
self.queues[queue_name_alt].append(item)
64+
return item
65+
66+
def rpoplpush(self, queue_name, queue_name_alt):
67+
return self.brpoplpush(queue_name, queue_name_alt)
68+
69+
def lrange(self, queue_name, start, end):
70+
return self.queues[queue_name][start:end]
71+
72+
73+
class TestZMQBackend(unittest.TestCase):
74+
def test_init_zmqbackend(self):
75+
with patch(
76+
"tasq.remote.backend.AsyncCloudPickleContext.socket"
77+
) as mock:
78+
mock.side_effect = lambda _: FakeSocket()
79+
backend = ZMQBackend("localhost", 10000, 10001)
80+
backend.bind()
81+
self.assertEqual(
82+
backend._pull_socket.bind_url, "tcp://localhost:10001"
83+
)
84+
self.assertEqual(
85+
backend._push_socket.bind_url, "tcp://localhost:10000"
86+
)
87+
backend.stop()
88+
89+
def test_send_zmqbackend(self):
90+
fake_socket = FakeSocket()
91+
with patch(
92+
"tasq.remote.backend.AsyncCloudPickleContext.socket"
93+
) as mock:
94+
mock.return_value = fake_socket
95+
backend = ZMQBackend("localhost", 10000, 10001)
96+
backend.bind()
97+
asyncio.run(backend.send("hello"))
98+
self.assertEqual(fake_socket.data_sent, ("hello", 0, None))
99+
100+
def test_recv_zmqbackend(self):
101+
fake_socket = FakeSocket()
102+
with patch(
103+
"tasq.remote.backend.AsyncCloudPickleContext.socket"
104+
) as mock:
105+
mock.return_value = fake_socket
106+
backend = ZMQBackend("localhost", 10000, 10001)
107+
backend.bind()
108+
asyncio.run(backend.send("hello"))
109+
self.assertEqual(fake_socket.data_sent, ("hello", 0, None))
110+
payload = asyncio.run(backend.recv())
111+
self.assertEqual(payload, ("hello", 0, None))
112+
113+
114+
class TestRedisBackend(unittest.TestCase):
115+
def test_redis_put_job(self):
116+
backend = RedisBackend(lambda: FakeRedisClient(), name="test-queue")
117+
backend.put_job("test-job")
118+
self.assertEqual(backend.get_next_job(), ["test-job"])
119+
backend.close()
120+
121+
def test_redis_put_result(self):
122+
backend = RedisBackend(lambda: FakeRedisClient(), name="test-queue")
123+
backend.put_result("test-job-result")
124+
self.assertEqual(backend.get_available_result(), ["test-job-result"])
125+
backend.close()
126+
127+
def test_redis_get_next_job(self):
128+
backend = RedisBackend(lambda: FakeRedisClient(), name="test-queue")
129+
backend.put_job("test-job")
130+
self.assertEqual(backend.get_next_job(), ["test-job"])
131+
backend.close()
132+
133+
def test_redis_get_pending_jobs(self):
134+
backend = RedisBackend(lambda: FakeRedisClient(), name="test-queue")
135+
backend.put_job("test-job")
136+
self.assertEqual(backend.get_pending_jobs(), ([], []))
137+
backend.close()
138+
139+
140+
class TestRabbitMQBackend(unittest.TestCase):
141+
pass

0 commit comments

Comments
 (0)