Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add full compression support #5

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub struct PlayerNet {
pub peer_addr: SocketAddr,
pub local_addr: SocketAddr,
pub state: RwLock<State>,
pub compression: RwLock<Option<usize>>,
pub compression: Option<usize>,
pub _explode: mpsc::Sender<()>,
}

Expand Down Expand Up @@ -156,7 +156,7 @@ impl PlayerNet {
let Err::<!, _>(e) = async {
loop {
let packet: SerializedPacket = r_send.recv_async().await?;
let data = packet.serialize()?;
let data = packet.serialize_compressing(compression)?;
write.write_all(&data).await?;
}
}
Expand All @@ -176,7 +176,8 @@ impl PlayerNet {
if read_bytes == 0 {
return Ok::<(), crate::error::Error>(());
}
let spack = SerializedPacket::deserialize.parse(buf.as_ref())?;
let spack = SerializedPacket::deserialize_compressing(compression)
.parse(buf.as_ref())?;
s_recv.send_async(spack).await?;
buf.clear();
}
Expand Down Expand Up @@ -212,7 +213,7 @@ impl PlayerNet {
peer_addr,
local_addr,
state: RwLock::new(State::Handshaking),
compression: RwLock::new(None),
compression,
_explode: shit,
}
}
Expand Down
63 changes: 57 additions & 6 deletions protocol/src/model/packets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::ops::Deref;
use crate::error::Error;
use crate::ser::*;
use ::bytes::{BufMut, Bytes, BytesMut};
use aott::prelude::*;
use tracing::debug;
use aott::{pfn_type, prelude::*};
use tracing::{debug, warn};

use super::{State, VarInt};

Expand Down Expand Up @@ -84,6 +84,48 @@ impl SerializedPacket {
Some(format!("packet_{:#x}.bin", self.id.0)),
))
}

pub fn serialize_compressing(&self, compression: Option<usize>) -> Result<Bytes, Error> {
if compression.filter(|x| self.length >= *x).is_some() {
let data_length = self.id.length_of() + self.data.len();
let datalength = VarInt::<i32>(data_length.try_into().unwrap());
let length =
datalength.length_of() + Compress((&self.id, &self.data), Zlib).serialize()?.len();
let pack = SerializedPacketCompressed {
length,
data_length,
id: self.id,
data: self.data.clone(),
};
pack.serialize()
} else {
self.serialize()
}
}

pub fn deserialize_compressing<'a>(
compression: Option<usize>,
) -> pfn_type!(&'a [u8], Self, Extra<<Self as Deserialize>::Context>) {
move |input| {
if let Some(cmp) = compression {
SerializedPacketCompressed::deserialize(input)
.map(Self::from)
.map(|v| {
if v.length < cmp {
warn!(
packet_length = v.length,
compresssion_threshold = cmp,
"Packet length was less than compression threshold"
);
}

v
})
} else {
Self::deserialize(input)
}
}
}
}

impl Deserialize for SerializedPacket {
Expand Down Expand Up @@ -187,9 +229,8 @@ impl Deserialize for SerializedPacketCompressed {
packet_length,
data_length, actual_data_length, "decompressing"
);
let Compress(real_data, Zlib): Compress<Bytes, Zlib> = Compress::decompress(
Bytes::copy_from_slice(input.input.slice_from(input.offset..)),
)?;
let Compress(real_data, Zlib): Compress<Bytes, Zlib> =
Compress::decompress(input.input.slice_from(input.offset..))?;
assert_eq!(real_data.len(), data_length);
let data_slice = real_data.deref();
let mut data_input = Input::new(&data_slice);
Expand All @@ -204,7 +245,7 @@ impl Deserialize for SerializedPacketCompressed {
length: packet_length,
data_length,
id,
data: Bytes::copy_from_slice(data),
data: Zlib::decode(data)?,
}
}
}
Expand All @@ -225,6 +266,16 @@ impl Serialize for SerializedPacketCompressed {
}
}

impl From<SerializedPacketCompressed> for SerializedPacket {
fn from(value: SerializedPacketCompressed) -> Self {
Self {
length: value.data_length,
id: value.id,
data: value.data,
}
}
}

#[derive(Debug, Clone)]
pub struct PluginMessage {
pub channel: Identifier,
Expand Down
6 changes: 3 additions & 3 deletions protocol/src/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ pub trait Compression {
thing: &T,
buf: &mut BytesMut,
) -> Result<(), crate::error::Error>;
fn decode(data: Bytes) -> Result<Bytes, crate::error::Error>;
fn decode(data: &[u8]) -> Result<Bytes, crate::error::Error>;
}
impl Compression for Zlib {
fn encode(data: Bytes) -> Result<Bytes, crate::error::Error> {
Expand Down Expand Up @@ -1087,7 +1087,7 @@ impl Compression for Zlib {
Ok(())
}

fn decode(data: Bytes) -> Result<Bytes, crate::error::Error> {
fn decode(data: &[u8]) -> Result<Bytes, crate::error::Error> {
use std::io::Read;
let mut dec = flate2::read::ZlibDecoder::new(std::io::Cursor::new(data));
let mut buf = Vec::new();
Expand All @@ -1105,7 +1105,7 @@ impl<T: Serialize, C: Compression> Serialize for Compress<T, C> {
}

impl<T: Deserialize<Context = ()>, C: Compression + Default> Compress<T, C> {
pub fn decompress(buffer: Bytes) -> Result<Self, crate::error::Error> {
pub fn decompress(buffer: &[u8]) -> Result<Self, crate::error::Error> {
let buffer = C::decode(buffer)?;
Ok(Self(T::deserialize.parse(&buffer)?, C::default()))
}
Expand Down