Skip to content

Commit

Permalink
feat: async
Browse files Browse the repository at this point in the history
  • Loading branch information
Elvis339 committed Oct 8, 2024
1 parent 5f0a8dd commit 968906a
Show file tree
Hide file tree
Showing 7 changed files with 433 additions and 217 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "threadsafe_zmq"
version = "1.0.0"
version = "1.0.1"
edition = "2021"
authors = ["Elvis Sabanovic <[email protected]>"]
description = "Threadsafe zeromq"
Expand Down
3 changes: 2 additions & 1 deletion example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ env_logger = "0.11.5"
zmq = "0.10.0"
rand = "0.8.5"

threadsafe_zmq = { path = ".." }
threadsafe_zmq = { path = ".." }
tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros"] }
94 changes: 47 additions & 47 deletions example/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,67 +1,67 @@
use env_logger;
use crate::utils::to_string;
use log::{error, info};
use rand::Rng;
use zmq::Context;

mod utils;

fn main() {
env_logger::init();

let clients = 4;
let mut handles = Vec::with_capacity(clients);
let addr = "tcp://localhost:5555";
let ctx = Context::new();
let socket = ctx
.socket(zmq::DEALER)
.expect("Failed to create DEALER socket");

for i in 0..clients {
let client_id = i;
let handle = std::thread::spawn(move || {
let addr = "tcp://localhost:5555";
let ctx = Context::new();
let socket = ctx
.socket(zmq::DEALER)
.expect("Failed to create PAIR socket");
let id = format!("Client-{}", 1);
socket
.set_identity(id.clone().as_bytes())
.expect("Failed to set identity");
socket.connect(addr).expect("Failed to connect to server");

let rand_id = client_id as u8 + generate_random_number();
let id = format!("client-{}", rand_id);
socket
.set_identity(id.clone().as_bytes())
.expect("Failed to set identity");
socket.connect(addr).expect("Failed to connect to server");
info!("Connected to: {}", addr);
loop {
let rand_num = to_string(generate_random_number());

info!("{} connected to: {}", id, addr);
loop {
let rand_num = generate_random_number();
let rand_num_bytes = rand_num.to_le_bytes().to_vec();
// Send the message to the server
match socket.send(rand_num.as_bytes(), 0) {
Ok(_) => {
info!("SND: fib({})=?", rand_num.clone());
}
Err(snd_err) => {
error!("SND: failed to send message: {:?}", snd_err);
continue;
}
}

match socket.send_multipart(vec![rand_num_bytes], 0) {
Ok(_) => info!("{}, sent number: {}", id, rand_num),
Err(snd_err) => {
error!("{}, failed to send message: {:?}", id, snd_err);
continue;
}
}
// Receive the response from the server
match socket.recv_multipart(0) {
Ok(message) => {
for (_, frame) in message.iter().enumerate() {
match String::from_utf8(frame.clone()) {
Ok(str_frame) => {
if str_frame.is_empty() {
continue;
}

match socket.recv_multipart(0) {
Ok(message) => {
info!("Client {}, received result: {:?}", client_id, message);
}
Err(rcv_err) => {
error!(
"Client {}, failed to receive message: {:?}",
client_id, rcv_err
);
info!("RCV: fib({})={}", rand_num, str_frame);
info!("--------------------------------------------------")
}
Err(e) => {
error!("RCV: failed to convert frame to string: {:?}", e);
}
}
}

std::thread::sleep(std::time::Duration::from_millis(100));
}
});
handles.push(handle);
}

loop {
std::thread::sleep(std::time::Duration::from_secs(1));
Err(rcv_err) => {
error!("RCV: failed to receive message: {:?}", rcv_err);
}
}
}
}

fn generate_random_number() -> u8 {
fn generate_random_number() -> u64 {
let mut rng = rand::thread_rng();
rng.gen_range(0..=30)
rng.gen_range(1..=80)
}
109 changes: 55 additions & 54 deletions example/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
use env_logger;
use log::{debug, error, info};
use threadsafe_zmq::{ChannelPair, Sender, ZmqByteStream};
use log::{error, info};
use std::sync::Arc;
use threadsafe_zmq::{ChannelPair, ZmqByteStream};
use tokio::task;
use zmq::Context;

fn main() {
mod utils;
use crate::utils::to_string;

#[tokio::main]
async fn main() {
env_logger::init();

let addr = "tcp://*:5555";
Expand All @@ -13,72 +19,67 @@ fn main() {
.expect("Failed to create ROUTER socket");
socket.bind(addr).expect("Failed to bind to address");

let channel_pair = ChannelPair::new(socket).expect("Failed to create channel pair");
let channel_pair = ChannelPair::new(&ctx, socket).expect("Failed to create channel pair");
info!("Server listening on {}", addr);

loop {
debug!("Waiting to receive messages...");

match channel_pair.rx_chan().recv() {
Ok(message) => {
if message.len() < 2 {
error!("Received malformed message: {:?}", message);
continue;
}

println!("Received message: {:?}", message);
let cp = channel_pair.clone();
std::thread::spawn(move || {
calculate_fib(message, cp.tx_chan());
});
let channel_pair_clone = Arc::clone(&channel_pair);
match task::spawn_blocking(move || channel_pair_clone.rx().recv()).await {
Ok(Ok(message)) => {
let channel_pair_for_task = Arc::clone(&channel_pair);
task::spawn(handle_message(message, channel_pair_for_task));
}
Err(rcv_err) => {
error!("Failed to receive message: {:?}", rcv_err);
Ok(Err(e)) => {
error!("Failed to receive message: {:?}", e);
break;
}
Err(e) => {
error!("Task join error: {:?}", e);
break;
}
}
}
}

fn calculate_fib(messages: ZmqByteStream, sender: &Sender) {
// The first part is the identity, and the second part is the actual message
async fn handle_message(messages: ZmqByteStream, channel_pair: Arc<ChannelPair>) {
let identity = messages[0].clone();
let payload = messages[1].clone();

let id_str = String::from_utf8_lossy(&identity);
if let Ok(str_num) = String::from_utf8(messages.last().unwrap().clone()) {
if let Ok(number) = str_num.parse::<u64>() {
let result = task::spawn_blocking(move || fibonacci(number)).await;

if payload.is_empty() {
error!("Received an empty payload, skipping Fibonacci calculation.");
return;
match result {
Ok(result) => {
let result_bytes = to_string(result);
if let Err(err) = channel_pair
.tx()
.send(vec![identity.clone(), result_bytes.as_bytes().to_vec()])
{
error!("Failed to send response: {:?}", err);
} else {
info!("SND: fib({})={}", number, result);
}
}
Err(e) => {
error!("Task join error while calculating fibonacci({}): {:?}", number, e);
}
}
}
}
}

info!("Received message from: {:?}", id_str);

// Deserialize the message into u32
// let number = match payload.as_slice().try_into() {
// Ok(bytes) => u32::from_le_bytes(bytes),
// Err(_) => {
// error!("Failed to deserialize payload, skipping.");
// return;
// }
// };
let number = 13;
fn fibonacci(n: u64) -> u64 {
if n <= 1 {
return n;
}

info!("Calculating Fibonacci for number: {}", number);
let result = fibonacci_recursive(number);
let result_bytes = result.to_le_bytes().to_vec();
let mut a: u64 = 0;
let mut b: u64 = 1;

// The response must include the identity frame, followed by the result
let response = vec![identity.clone(), result_bytes];
match sender.send(response) {
Ok(_) => info!("Successfully sent response: {:?} to: {:?}", result, id_str),
Err(err) => error!("Failed to send response: {:?}", err),
for _ in 2..=n {
let temp = a + b;
a = b;
b = temp;
}
}

fn fibonacci_recursive(n: u32) -> u32 {
if n <= 1 {
n
} else {
fibonacci_recursive(n - 1) + fibonacci_recursive(n - 2)
}
b
}
62 changes: 62 additions & 0 deletions example/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use std::num::ParseIntError;

#[derive(Debug)]
pub enum NumberConversionError {
InvalidLength(String),
}

// Converts a `u64` number to a `Vec<u8>` byte representation.
pub fn to_bytes(num: u64) -> Vec<u8> {
num.to_le_bytes().to_vec()
}

// Converts a `Vec<u8>` to a `u64` number.
pub fn from_bytes(bytes: &[u8]) -> Result<u64, NumberConversionError> {
match bytes.len() {
8 => Ok(u64::from_le_bytes(bytes.try_into().map_err(|_| {
NumberConversionError::InvalidLength(format!(
"Invalid payload length: expected 8 bytes, got {}",
bytes.len()
))
})?)),
_ => Err(NumberConversionError::InvalidLength(format!(
"Invalid payload length: expected 8 bytes, got {}",
bytes.len()
))),
}
}

pub fn to_string(num: u64) -> String {
num.to_string()
}

pub fn from_string(num: &str) -> Result<u64, ParseIntError> {
u64::from_str_radix(num, 16)
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_to_bytes() {
let num: u64 = 42;
let bytes = to_bytes(num);
assert_eq!(bytes, num.to_le_bytes().to_vec());
}

#[test]
fn test_from_bytes() {
let num: u64 = 42;
let bytes = num.to_le_bytes().to_vec();
let result = from_bytes(&bytes);
assert_eq!(result.unwrap(), num);
}

#[test]
fn test_invalid_length() {
let bytes = vec![1, 2, 3]; // Length is not 8 bytes for a u64 number
let result = from_bytes(&bytes);
assert!(result.is_err());
}
}
Loading

0 comments on commit 968906a

Please sign in to comment.