Skip to content

Commit

Permalink
fix: Correct redis usage and GeneticCNN implementation (#55)
Browse files Browse the repository at this point in the history
- Add missing BatchNormalization layer in GeneticCNN
- Fix bug that slowed down RedisController with multiple workers
- Adapt RedisWorker and RedisController to changes in #52
- Adapt test cases to RedisWorker/RedisController fixes
- Match isort and black configurations
  • Loading branch information
gmontamat authored Sep 25, 2024
1 parent d06e3e9 commit 7e1b64b
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 28 deletions.
6 changes: 3 additions & 3 deletions examples/sample_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# docker run -d --rm --name gentun-redis -p 6379:6379 redis
worker = RedisWorker("test", Dummy, host="localhost", port=6379)

x_train = []
y_train = []
x_train, y_train = [], []
x_test, y_test = [], []

# Start worker process
worker.run(x_train, y_train)
worker.run(x_train, y_train, x_test, y_test)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ line-length = 120
fast = true

[tool.isort]
profile = "black"
line_length = 120

[tool.coverage.run]
Expand Down
2 changes: 1 addition & 1 deletion src/gentun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@

from .config import setup_logging

__version__ = "0.0.2"
__version__ = "0.0.3"

setup_logging()
15 changes: 14 additions & 1 deletion src/gentun/models/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Activation, Add, Conv2D, Dense, Dropout, Flatten, Input, MaxPool2D
from tensorflow.keras.layers import (
Activation,
Add,
BatchNormalization,
Conv2D,
Dense,
Dropout,
Flatten,
Input,
MaxPool2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model
Expand Down Expand Up @@ -124,6 +134,7 @@ def build_dag(x: Any, nodes: int, connections: str, kernels: int):
else:
tmp = add_vars[0]
tmp = Conv2D(kernels, kernel_size=(3, 3), strides=(1, 1), padding="same")(tmp)
tmp = BatchNormalization()(tmp)
tmp = Activation("relu")(tmp)
all_vars[i] = tmp
if not outs:
Expand All @@ -150,13 +161,15 @@ def build_model(
for layer, kernels in enumerate(kernels_per_layer):
# Default input node
x = Conv2D(kernels, kernel_size=kernel_sizes[layer], strides=(1, 1), padding="same")(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)
# Decode internal connections
# If at least one bit is 1, then we need to construct the Directed Acyclic Graph
if not all(not bool(int(bit)) for bit in connections[layer]):
x = self.build_dag(x, nodes[layer], connections[layer], kernels)
# Output node
x = Conv2D(kernels, kernel_size=(3, 3), strides=(1, 1), padding="same")(x)
x = BatchNormalization()(x)
x = Activation("relu")(x)
x = MaxPool2D(pool_size=pool_sizes[layer], strides=(2, 2))(x)
x = Flatten()(x)
Expand Down
34 changes: 21 additions & 13 deletions src/gentun/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from gentun.services import RedisWorker
worker = RedisWorker("{name}", {handler}, host="{host}", port={port})
x_train, y_train = ... # get data
worker.run(x_train, y_train)
x_train, y_train, x_test, y_test = ... # get data
worker.run(x_train, y_train, x_test, y_test)
```
"""

Expand Down Expand Up @@ -79,19 +79,22 @@ def send_job(self, handler: Type[Handler], **kwargs) -> str:
"handler": handler.__name__,
"kwargs": kwargs,
}
self.client.rpush(self.job_queue, json.dumps(job))
self.client.lpush(self.job_queue, json.dumps(job))
return job_id

def wait_for_result(self, job_id) -> float:
"""Retrieve fitness from the results queue."""
start_time = time.time()
while time.time() - start_time < self.timeout:
result = self.client.lpop(self.results_queue)
if result:
result = json.loads(result)
data = self.client.rpop(self.results_queue)
if data:
result = json.loads(data)
if result["name"] == self.name and result["id"] == job_id:
return result["fitness"]
time.sleep(1)
# Leave data back in queue
self.client.lpush(self.results_queue, data)
else:
time.sleep(1)
raise TimeoutError(f"Could not get job with id {job_id}")


Expand All @@ -116,25 +119,30 @@ def __init__(
self.results_queue = results_queue
self.timeout = timeout

def process_job(self, x_train: Any, y_train: Any, **kwargs) -> float:
def process_job(self, x_train: Any, y_train: Any, x_test: Any, y_test: Any, **kwargs) -> float:
"""Call model handler, return fitness."""
return self.handler(**kwargs).evaluate(x_train, y_train)
return self.handler(**kwargs)(x_train, y_train, x_test, y_test)

def run(self, x_train: Any, y_train: Any):
def run(self, x_train: Any, y_train: Any, x_test: Any = None, y_test: Any = None):
"""Read jobs from queue, call handler, and return fitness."""
logging.info("Worker started (Ctrl+C to stop), waiting for jobs...")
try:
while True:
job_data = self.client.lpop(self.job_queue)
job_data = self.client.rpop(self.job_queue)
if job_data:
data = json.loads(job_data)
if data["name"] == self.name and data["handler"] == self.handler.__name__:
logging.info("Working on job %s", data["id"])
fitness = self.process_job(x_train, y_train, **data["kwargs"])
fitness = self.process_job(x_train, y_train, x_test, y_test, **data["kwargs"])
result = {"id": data["id"], "name": self.name, "fitness": fitness}
self.client.rpush(self.results_queue, json.dumps(result))
self.client.lpush(self.results_queue, json.dumps(result))
else:
# Job not used, do not dump
self.client.lpush(self.job_queue, job_data)
else:
logging.debug("No jobs in queue, sleeping for a while...")
time.sleep(1)
except KeyboardInterrupt:
if job_data:
self.client.lpush(self.job_queue, job_data)
logging.info("Bye!")
26 changes: 16 additions & 10 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, param1: int, param2: str = "default"):
self.param1 = param1
self.param2 = param2

def evaluate(self, x_train, y_train):
def create_train_evaluate(self, x_train, y_train, x_test, y_test):
return 0.9


Expand Down Expand Up @@ -47,14 +47,14 @@ def test_redis_controller_send_job(mock_redis):
# Send first job
job_id = controller.send_job(MockHandler, param1=1, param2="value")
assert isinstance(job_id, str)
job = json.loads(mock_redis.return_value.rpush.call_args[0][1])
job = json.loads(mock_redis.return_value.lpush.call_args[0][1])
assert job["name"] == "test"
assert job["handler"] == "MockHandler"
assert job["kwargs"] == {"param1": 1, "param2": "value"}
# Send a second job
job_id = controller.send_job(MockHandler, param1=2, param2="value2")
assert isinstance(job_id, str)
job = json.loads(mock_redis.return_value.rpush.call_args[0][1])
job = json.loads(mock_redis.return_value.lpush.call_args[0][1])
assert job["name"] == "test"
assert job["handler"] == "MockHandler"
assert job["kwargs"] == {"param1": 2, "param2": "value2"}
Expand All @@ -66,7 +66,7 @@ def test_redis_controller_wait_for_result(mock_redis):
job_id = "test_job_id"
result = {"id": job_id, "name": "test", "fitness": 0.9}
ignore_result = {"id": "not_test_job_id", "name": "test", "fitness": 0.9}
mock_redis.return_value.lpop.side_effect = [None, json.dumps(ignore_result), json.dumps(result)]
mock_redis.return_value.rpop.side_effect = [None, json.dumps(ignore_result), json.dumps(result)]
fitness = controller.wait_for_result(job_id)
assert fitness == 0.9

Expand All @@ -75,7 +75,7 @@ def test_redis_controller_wait_for_result(mock_redis):
def test_redis_controller_wait_for_result_timeout(mock_redis):
controller = RedisController("test", timeout=1)
job_id = "test_job_id"
mock_redis.return_value.lpop.return_value = None
mock_redis.return_value.rpop.return_value = None
with pytest.raises(TimeoutError):
controller.wait_for_result(job_id)

Expand All @@ -94,7 +94,7 @@ def test_redis_worker_init(mock_redis):
@patch("src.gentun.services.redis.StrictRedis")
def test_redis_worker_process_job(mock_redis):
worker = RedisWorker("test", MockHandler)
fitness = worker.process_job([1, 2, 3], [4, 5, 6], param1=1, param2="value")
fitness = worker.process_job([1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2], param1=1, param2="value")
assert fitness == 0.9


Expand All @@ -113,12 +113,18 @@ def test_redis_worker_run(mock_redis):
"handler": "NotMockHandler",
"kwargs": {"param1": 1, "param2": "value"},
}
mock_redis.return_value.lpop.side_effect = [json.dumps(ignore_job_data)] + [json.dumps(job_data)] + [None]
mock_redis.return_value.rpop.side_effect = [json.dumps(ignore_job_data), json.dumps(job_data), None]
with patch.object(worker, "process_job", return_value=0.9) as mock_process_job:
with patch("time.sleep", side_effect=KeyboardInterrupt):
worker.run([1, 2, 3], [4, 5, 6])
mock_process_job.assert_called_once_with([1, 2, 3], [4, 5, 6], param1=1, param2="value")
result = json.loads(mock_redis.return_value.rpush.call_args[0][1])
worker.run([1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2])
mock_process_job.assert_called_once_with(
[1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2], param1=1, param2="value"
)
result = json.loads(mock_redis.return_value.lpush.call_args[0][1])
assert result["id"] == "test_job_id"
assert result["name"] == "test"
assert result["fitness"] == 0.9
mock_redis.return_value.rpop.side_effect = [json.dumps(ignore_job_data)]
with patch("json.loads", side_effect=KeyboardInterrupt):
worker.run([1, 2, 3], [4, 5, 6], [7, 8, 9], [0, 1, 2])
mock_redis.return_value.lpush.assert_any_call(worker.job_queue, json.dumps(ignore_job_data))

0 comments on commit 7e1b64b

Please sign in to comment.