aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cfb8.rs105
-rw-r--r--src/connection.rs125
-rw-r--r--src/lib.rs14
-rw-r--r--src/reader.rs294
-rw-r--r--src/tcp.rs69
-rw-r--r--src/util.rs55
-rw-r--r--src/wrapper.rs12
-rw-r--r--src/writer.rs420
8 files changed, 1094 insertions, 0 deletions
diff --git a/src/cfb8.rs b/src/cfb8.rs
new file mode 100644
index 0000000..6fc3bb6
--- /dev/null
+++ b/src/cfb8.rs
@@ -0,0 +1,105 @@
+use aes::{
+ cipher::{consts::U16, generic_array::GenericArray, BlockCipherMut, NewBlockCipher},
+ Aes128,
+};
+use thiserror::Error;
+
+pub type CraftCipherResult<T> = Result<T, CipherError>;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+pub enum CipherComponent {
+ Key,
+ Iv,
+}
+
+#[derive(Debug, Error)]
+pub enum CipherError {
+ #[error("encryption is already enabled and cannot be enabled again")]
+ AlreadyEnabled,
+ #[error("bad size '{1}' for '{0:?}'")]
+ BadSize(CipherComponent, usize),
+}
+
+const BYTES_SIZE: usize = 16;
+
+pub struct CraftCipher {
+ iv: GenericArray<u8, U16>,
+ tmp: GenericArray<u8, U16>,
+ cipher: Aes128,
+}
+
+impl CraftCipher {
+ pub fn new(key: &[u8], iv: &[u8]) -> CraftCipherResult<Self> {
+ if iv.len() != BYTES_SIZE {
+ return Err(CipherError::BadSize(CipherComponent::Iv, iv.len()));
+ }
+
+ if key.len() != BYTES_SIZE {
+ return Err(CipherError::BadSize(CipherComponent::Key, iv.len()));
+ }
+
+ let mut iv_out = [0u8; BYTES_SIZE];
+ iv_out.copy_from_slice(iv);
+
+ let mut key_out = [0u8; BYTES_SIZE];
+ key_out.copy_from_slice(key);
+
+ let tmp = [0u8; BYTES_SIZE];
+
+ Ok(Self {
+ iv: GenericArray::from(iv_out),
+ tmp: GenericArray::from(tmp),
+ cipher: Aes128::new(&GenericArray::from(key_out)),
+ })
+ }
+
+ pub fn encrypt(&mut self, data: &mut [u8]) {
+ unsafe { self.crypt(data, false) }
+ }
+
+ pub fn decrypt(&mut self, data: &mut [u8]) {
+ unsafe { self.crypt(data, true) }
+ }
+
+ unsafe fn crypt(&mut self, data: &mut [u8], decrypt: bool) {
+ let iv = &mut self.iv;
+ const IV_SIZE: usize = 16;
+ const IV_SIZE_MINUS_ONE: usize = IV_SIZE - 1;
+ let iv_ptr = iv.as_mut_ptr();
+ let iv_end_ptr = iv_ptr.offset(IV_SIZE_MINUS_ONE as isize);
+ let tmp_ptr = self.tmp.as_mut_ptr();
+ let tmp_offset_one_ptr = tmp_ptr.offset(1);
+ let cipher = &mut self.cipher;
+ let n = data.len();
+ let mut data_ptr = data.as_mut_ptr();
+ let data_end_ptr = data_ptr.offset(n as isize);
+
+ while data_ptr != data_end_ptr {
+ std::ptr::copy_nonoverlapping(iv_ptr, tmp_ptr, IV_SIZE);
+ cipher.encrypt_block(iv);
+ let orig = *data_ptr;
+ let updated = orig ^ *iv_ptr;
+ std::ptr::copy_nonoverlapping(tmp_offset_one_ptr, iv_ptr, IV_SIZE_MINUS_ONE);
+ if decrypt {
+ *iv_end_ptr = orig;
+ } else {
+ *iv_end_ptr = updated;
+ }
+ *data_ptr = updated;
+ data_ptr = data_ptr.offset(1);
+ }
+ }
+}
+
+pub(crate) fn setup_craft_cipher(
+ target: &mut Option<CraftCipher>,
+ key: &[u8],
+ iv: &[u8],
+) -> Result<(), CipherError> {
+ if target.is_some() {
+ Err(CipherError::AlreadyEnabled)
+ } else {
+ *target = Some(CraftCipher::new(key, iv)?);
+ Ok(())
+ }
+}
diff --git a/src/connection.rs b/src/connection.rs
new file mode 100644
index 0000000..af6f4a9
--- /dev/null
+++ b/src/connection.rs
@@ -0,0 +1,125 @@
+use crate::cfb8::CipherError;
+use crate::reader::{CraftAsyncReader, CraftReader, CraftSyncReader, ReadResult};
+use crate::wrapper::{CraftIo, CraftWrapper};
+use crate::writer::{CraftAsyncWriter, CraftSyncWriter, CraftWriter, WriteResult};
+use mcproto_rs::protocol::{Packet, RawPacket, State};
+
+#[cfg(feature = "async")]
+use async_trait::async_trait;
+
+pub struct CraftConnection<R, W> {
+ pub(crate) reader: CraftReader<R>,
+ pub(crate) writer: CraftWriter<W>,
+}
+
+impl<R, W> CraftWrapper<(CraftReader<R>, CraftWriter<W>)> for CraftConnection<R, W> {
+ fn into_inner(self) -> (CraftReader<R>, CraftWriter<W>) {
+ (self.reader, self.writer)
+ }
+}
+
+impl<R, W> CraftIo for CraftConnection<R, W> {
+ fn set_state(&mut self, next: State) {
+ self.reader.set_state(next);
+ self.writer.set_state(next);
+ }
+
+ fn set_compression_threshold(&mut self, threshold: Option<i32>) {
+ self.reader.set_compression_threshold(threshold);
+ self.writer.set_compression_threshold(threshold);
+ }
+
+ fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
+ self.reader.enable_encryption(key, iv)?;
+ self.writer.enable_encryption(key, iv)?;
+ Ok(())
+ }
+}
+
+impl<R, W> CraftSyncReader for CraftConnection<R, W>
+where
+ CraftReader<R>: CraftSyncReader,
+ CraftWriter<W>: CraftSyncWriter,
+{
+ fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ self.reader.read_packet::<P>()
+ }
+
+ fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ self.reader.read_raw_packet::<P>()
+ }
+}
+
+impl<R, W> CraftSyncWriter for CraftConnection<R, W>
+where
+ CraftReader<R>: CraftSyncReader,
+ CraftWriter<W>: CraftSyncWriter,
+{
+ fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet,
+ {
+ self.writer.write_packet(packet)
+ }
+
+ fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a>,
+ {
+ self.writer.write_raw_packet(packet)
+ }
+}
+
+#[cfg(feature = "async")]
+#[async_trait]
+impl<R, W> CraftAsyncReader for CraftConnection<R, W>
+where
+ CraftReader<R>: CraftAsyncReader,
+ R: Send + Sync,
+ CraftWriter<W>: CraftAsyncWriter,
+ W: Send + Sync,
+{
+ async fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ self.reader.read_packet::<P>().await
+ }
+
+ async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ self.reader.read_raw_packet::<P>().await
+ }
+}
+
+#[cfg(feature = "async")]
+#[async_trait]
+impl<R, W> CraftAsyncWriter for CraftConnection<R, W>
+where
+ CraftReader<R>: CraftAsyncReader,
+ R: Send + Sync,
+ CraftWriter<W>: CraftAsyncWriter,
+ W: Send + Sync,
+{
+ async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet + Send + Sync,
+ {
+ self.writer.write_packet(packet).await
+ }
+
+ async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a> + Send + Sync,
+ {
+ self.writer.write_raw_packet(packet).await
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..8fafec0
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,14 @@
+mod cfb8;
+mod connection;
+mod reader;
+mod tcp;
+mod util;
+mod wrapper;
+mod writer;
+
+pub use connection::CraftConnection;
+pub use reader::*;
+pub use writer::*;
+pub use tcp::*;
+pub use cfb8::CipherError;
+pub use wrapper::*;
diff --git a/src/reader.rs b/src/reader.rs
new file mode 100644
index 0000000..68b9a9b
--- /dev/null
+++ b/src/reader.rs
@@ -0,0 +1,294 @@
+use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher};
+use crate::util::{get_sized_buf, VAR_INT_BUF_SIZE};
+use crate::wrapper::{CraftIo, CraftWrapper};
+use flate2::{DecompressError, FlushDecompress, Status};
+use mcproto_rs::protocol::{Id, PacketDirection, RawPacket, State};
+use mcproto_rs::types::VarInt;
+use mcproto_rs::{Deserialize, Deserialized};
+use thiserror::Error;
+
+#[cfg(feature = "async")]
+use {async_trait::async_trait, futures::AsyncReadExt};
+
+#[derive(Debug, Error)]
+pub enum ReadError {
+ #[error("i/o failure during read")]
+ IoFailure(#[from] std::io::Error),
+ #[error("failed to read header VarInt")]
+ PacketHeaderErr(#[from] mcproto_rs::DeserializeErr),
+ #[error("failed to read packet")]
+ PacketErr(#[from] mcproto_rs::protocol::PacketErr),
+ #[error("failed to decompress packet")]
+ DecompressFailed(#[from] DecompressErr),
+}
+
+#[derive(Debug, Error)]
+pub enum DecompressErr {
+ #[error("buf error")]
+ BufError,
+ #[error("failure while decompressing")]
+ Failure(#[from] DecompressError),
+}
+
+pub type ReadResult<P> = Result<Option<P>, ReadError>;
+
+#[cfg(feature = "async")]
+#[async_trait]
+pub trait CraftAsyncReader {
+ async fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ deserialize_raw_packet(self.read_raw_packet::<P>().await)
+ }
+
+ async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>;
+}
+
+pub trait CraftSyncReader {
+ fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ deserialize_raw_packet(self.read_raw_packet::<'a, P>())
+ }
+
+ fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>;
+}
+
+pub struct CraftReader<R> {
+ inner: R,
+ raw_buf: Option<Vec<u8>>,
+ decompress_buf: Option<Vec<u8>>,
+ compression_threshold: Option<i32>,
+ state: State,
+ direction: PacketDirection,
+ encryption: Option<CraftCipher>,
+}
+
+impl<R> CraftWrapper<R> for CraftReader<R> {
+ fn into_inner(self) -> R {
+ self.inner
+ }
+}
+
+impl<R> CraftIo for CraftReader<R> {
+ fn set_state(&mut self, next: State) {
+ self.state = next;
+ }
+
+ fn set_compression_threshold(&mut self, threshold: Option<i32>) {
+ self.compression_threshold = threshold;
+ }
+
+ fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
+ setup_craft_cipher(&mut self.encryption, key, iv)
+ }
+}
+
+macro_rules! rr_unwrap {
+ ($result: expr) => {
+ match $result {
+ Ok(Some(r)) => r,
+ Ok(None) => return Ok(None),
+ Err(err) => return Err(err),
+ }
+ };
+}
+
+macro_rules! check_unexpected_eof {
+ ($result: expr) => {
+ if let Err(err) = $result {
+ if err.kind() == std::io::ErrorKind::UnexpectedEof {
+ return Ok(None);
+ }
+
+ return Err(ReadError::IoFailure(err));
+ }
+ };
+}
+
+impl<R> CraftSyncReader for CraftReader<R>
+where
+ R: std::io::Read,
+{
+ fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ let primary_packet_len = rr_unwrap!(self.read_one_varint_sync()).0 as usize;
+ rr_unwrap!(self.read_n(primary_packet_len));
+ self.read_packet_in_buf::<'a, P>(primary_packet_len)
+ }
+}
+
+#[cfg(feature = "async")]
+#[async_trait]
+impl<R> CraftAsyncReader for CraftReader<R>
+where
+ R: futures::AsyncRead + Unpin + Sync + Send,
+{
+ async fn read_raw_packet<'a, P>(&'a mut self) -> Result<Option<P>, ReadError>
+ where
+ P: RawPacket<'a>,
+ {
+ let primary_packet_len = rr_unwrap!(self.read_one_varint_async().await).0 as usize;
+ rr_unwrap!(self.read_n_async(primary_packet_len).await);
+ self.read_packet_in_buf::<P>(primary_packet_len)
+ }
+}
+
+impl<R> CraftReader<R>
+where
+ R: std::io::Read,
+{
+ fn read_one_varint_sync(&mut self) -> ReadResult<VarInt> {
+ deserialize_varint(rr_unwrap!(self.read_n(VAR_INT_BUF_SIZE)))
+ }
+
+ fn read_n(&mut self, n: usize) -> ReadResult<&mut [u8]> {
+ let buf = get_sized_buf(&mut self.raw_buf, 0, n);
+ check_unexpected_eof!(self.inner.read_exact(buf));
+ Ok(Some(buf))
+ }
+}
+
+#[cfg(feature = "async")]
+impl<R> CraftReader<R>
+where
+ R: futures::io::AsyncRead + Unpin + Sync + Send,
+{
+ async fn read_one_varint_async(&mut self) -> ReadResult<VarInt> {
+ deserialize_varint(rr_unwrap!(self.read_n_async(VAR_INT_BUF_SIZE).await))
+ }
+
+ async fn read_n_async(&mut self, n: usize) -> ReadResult<&mut [u8]> {
+ let buf = get_sized_buf(&mut self.raw_buf, 0, n);
+ check_unexpected_eof!(self.inner.read_exact(buf).await);
+ Ok(Some(buf))
+ }
+}
+
+macro_rules! dsz_unwrap {
+ ($bnam: expr, $k: ty) => {
+ match <$k>::mc_deserialize($bnam) {
+ Ok(Deserialized {
+ value: val,
+ data: rest,
+ }) => (val, rest),
+ Err(err) => {
+ return Err(ReadError::PacketHeaderErr(err));
+ }
+ };
+ };
+}
+
+impl<R> CraftReader<R> {
+ pub fn wrap(inner: R, direction: PacketDirection) -> Self {
+ Self::wrap_with_state(inner, direction, State::Handshaking)
+ }
+
+ pub fn wrap_with_state(inner: R, direction: PacketDirection, state: State) -> Self {
+ Self {
+ inner,
+ raw_buf: None,
+ decompress_buf: None,
+ compression_threshold: None,
+ state,
+ direction,
+ encryption: None,
+ }
+ }
+
+ fn read_packet_in_buf<'a, P>(&'a mut self, size: usize) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ // find data in buf
+ let buf = &mut self.raw_buf.as_mut().expect("should exist right now")[..size];
+ // decrypt the packet if encryption is enabled
+ if let Some(encryption) = self.encryption.as_mut() {
+ encryption.decrypt(buf);
+ }
+
+ // try to get the packet body bytes... this boils down to:
+ // * check if compression enabled,
+ // * read data len (VarInt) which isn't compressed
+ // * if data len is 0, then rest of packet is not compressed, remaining data is body
+ // * otherwise, data len is decompressed length, so prepare a decompression buf and decompress from
+ // the buffer into the decompression buffer, and return the slice of the decompression buffer
+ // which contains this packet's data
+ // * if compression not enabled, then the buf contains only the packet body bytes
+
+ let packet_buf = if let Some(_) = self.compression_threshold {
+ let (data_len, rest) = dsz_unwrap!(buf, VarInt);
+ let data_len = data_len.0 as usize;
+ if data_len == 0 {
+ rest
+ } else {
+ decompress(rest, &mut self.decompress_buf, data_len)?
+ }
+ } else {
+ buf
+ };
+
+ let (raw_id, body_buf) = dsz_unwrap!(packet_buf, VarInt);
+
+ let id = Id {
+ id: raw_id.0,
+ state: self.state.clone(),
+ direction: self.direction.clone(),
+ };
+
+ match P::create(id, body_buf) {
+ Ok(raw) => Ok(Some(raw)),
+ Err(err) => Err(ReadError::PacketErr(err)),
+ }
+ }
+}
+
+fn deserialize_raw_packet<'a, P>(raw: ReadResult<P>) -> ReadResult<P::Packet>
+where
+ P: RawPacket<'a>,
+{
+ match raw {
+ Ok(Some(raw)) => match raw.deserialize() {
+ Ok(deserialized) => Ok(Some(deserialized)),
+ Err(err) => Err(ReadError::PacketErr(err)),
+ },
+ Ok(None) => Ok(None),
+ Err(err) => Err(err),
+ }
+}
+
+fn deserialize_varint(buf: &[u8]) -> ReadResult<VarInt> {
+ match VarInt::mc_deserialize(buf) {
+ Ok(v) => Ok(Some(v.value)),
+ Err(err) => Err(ReadError::PacketHeaderErr(err)),
+ }
+}
+
+fn decompress<'a>(
+ src: &'a [u8],
+ target: &'a mut Option<Vec<u8>>,
+ decompressed_len: usize,
+) -> Result<&'a mut [u8], ReadError> {
+ let mut decompress = flate2::Decompress::new(true);
+ let decompress_buf = get_sized_buf(target, 0, decompressed_len);
+ loop {
+ match decompress.decompress(src, decompress_buf, FlushDecompress::Finish) {
+ Ok(Status::StreamEnd) => break,
+ Ok(Status::Ok) => {}
+ Ok(Status::BufError) => {
+ return Err(ReadError::DecompressFailed(DecompressErr::BufError))
+ }
+ Err(err) => return Err(ReadError::DecompressFailed(DecompressErr::Failure(err))),
+ }
+ }
+
+ Ok(&mut decompress_buf[..(decompress.total_out() as usize)])
+}
diff --git a/src/tcp.rs b/src/tcp.rs
new file mode 100644
index 0000000..e2ead90
--- /dev/null
+++ b/src/tcp.rs
@@ -0,0 +1,69 @@
+use crate::connection::CraftConnection;
+use crate::reader::CraftReader;
+use crate::writer::CraftWriter;
+use mcproto_rs::protocol::{PacketDirection, State};
+use std::convert::TryFrom;
+use std::io::BufReader as StdBufReader;
+use std::net::TcpStream;
+
+#[cfg(feature = "async")]
+use futures::io::{AsyncRead, AsyncWrite, BufReader as AsyncBufReader};
+
+pub const BUF_SIZE: usize = 8192;
+
+pub type CraftTcpConnection = CraftConnection<StdBufReader<TcpStream>, TcpStream>;
+
+impl CraftConnection<StdBufReader<TcpStream>, TcpStream> {
+ pub fn connect_server_std(to: String) -> Result<Self, std::io::Error> {
+ Self::from_std(TcpStream::connect(to)?, PacketDirection::ClientBound)
+ }
+
+ pub fn wrap_client_stream_std(stream: TcpStream) -> Result<Self, std::io::Error> {
+ Self::from_std(stream, PacketDirection::ServerBound)
+ }
+
+ pub fn from_std(
+ s1: TcpStream,
+ read_direction: PacketDirection,
+ ) -> Result<Self, std::io::Error> {
+ Self::from_std_with_state(s1, read_direction, State::Handshaking)
+ }
+
+ pub fn from_std_with_state(
+ s1: TcpStream,
+ read_direction: PacketDirection,
+ state: State,
+ ) -> Result<Self, std::io::Error> {
+ let write = s1.try_clone()?;
+ let read = StdBufReader::with_capacity(BUF_SIZE, s1);
+
+ Ok(Self {
+ reader: CraftReader::wrap_with_state(read, read_direction, state),
+ writer: CraftWriter::wrap_with_state(write, read_direction.opposite(), state),
+ })
+ }
+}
+
+#[cfg(feature = "async")]
+impl<R, W> CraftConnection<AsyncBufReader<R>, W>
+where
+ R: AsyncRead + Send + Sync + Unpin,
+ W: AsyncWrite + Send + Sync + Unpin,
+{
+ pub fn from_async(tuple: (R, W), read_direction: PacketDirection) -> Self {
+ Self::from_async_with_state(tuple, read_direction, State::Handshaking)
+ }
+
+ pub fn from_async_with_state(
+ tuple: (R, W),
+ read_direction: PacketDirection,
+ state: State,
+ ) -> Self {
+ let (reader, writer) = tuple;
+ let reader = AsyncBufReader::with_capacity(BUF_SIZE, reader);
+ Self {
+ reader: CraftReader::wrap_with_state(reader, read_direction, state),
+ writer: CraftWriter::wrap_with_state(writer, read_direction.opposite(), state),
+ }
+ }
+}
diff --git a/src/util.rs b/src/util.rs
new file mode 100644
index 0000000..a19922f
--- /dev/null
+++ b/src/util.rs
@@ -0,0 +1,55 @@
+pub(crate) const VAR_INT_BUF_SIZE: usize = 5;
+
+pub(crate) fn get_sized_buf(buf: &mut Option<Vec<u8>>, offset: usize, size: usize) -> &mut [u8] {
+ let end_at = offset + size;
+ loop {
+ match buf {
+ Some(v) => {
+ ensure_buf_has_size(v, end_at);
+ break &mut v[offset..end_at];
+ }
+ None => {
+ let new_buf = Vec::with_capacity(end_at);
+ *buf = Some(new_buf);
+ }
+ }
+ }
+}
+fn ensure_buf_has_size(buf: &mut Vec<u8>, total_size: usize) {
+ let cur_len = buf.len();
+ if cur_len >= total_size {
+ return;
+ }
+
+ let additional = total_size - cur_len;
+ buf.reserve(additional);
+ unsafe {
+ let start_at = buf.as_mut_ptr();
+ let start_write_at = start_at.offset(cur_len as isize);
+ std::ptr::write_bytes(start_write_at, 0, additional);
+ buf.set_len(total_size);
+ }
+}
+
+pub(crate) fn move_data_rightwards(target: &mut [u8], size: usize, shift_amount: usize) {
+ let required_len = size + shift_amount;
+ let actual_len = target.len();
+ if actual_len < required_len {
+ panic!(
+ "move of data to the right (0..{} -> {}..{}) exceeds size of buffer {}",
+ size, shift_amount, required_len, actual_len,
+ )
+ }
+
+ unsafe { move_data_rightwards_unchecked(target, size, shift_amount) }
+}
+
+unsafe fn move_data_rightwards_unchecked(target: &mut [u8], size: usize, shift_amount: usize) {
+ if shift_amount == 0 {
+ return;
+ }
+
+ let src_ptr = target.as_mut_ptr();
+ let dst_ptr = src_ptr.offset(shift_amount as isize);
+ std::ptr::copy(src_ptr, dst_ptr, size);
+}
diff --git a/src/wrapper.rs b/src/wrapper.rs
new file mode 100644
index 0000000..06f451b
--- /dev/null
+++ b/src/wrapper.rs
@@ -0,0 +1,12 @@
+use crate::cfb8::CipherError;
+use mcproto_rs::protocol::State;
+
+pub trait CraftWrapper<I> {
+ fn into_inner(self) -> I;
+}
+
+pub trait CraftIo {
+ fn set_state(&mut self, next: State);
+ fn set_compression_threshold(&mut self, threshold: Option<i32>);
+ fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError>;
+}
diff --git a/src/writer.rs b/src/writer.rs
new file mode 100644
index 0000000..6c76b23
--- /dev/null
+++ b/src/writer.rs
@@ -0,0 +1,420 @@
+use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher};
+use crate::util::{get_sized_buf, move_data_rightwards, VAR_INT_BUF_SIZE};
+use crate::wrapper::{CraftIo, CraftWrapper};
+use flate2::{CompressError, Compression, FlushCompress, Status};
+use mcproto_rs::protocol::{Id, Packet, PacketDirection, RawPacket, State};
+use mcproto_rs::types::VarInt;
+use mcproto_rs::{Serialize, SerializeErr, SerializeResult, Serializer};
+use thiserror::Error;
+
+#[cfg(feature = "async")]
+use {async_trait::async_trait, futures::AsyncWriteExt};
+
+#[derive(Debug, Error)]
+pub enum WriteError {
+ #[error("serialization of header data failed")]
+ HeaderSerializeFail(SerializeErr),
+ #[error("packet body serialization failed")]
+ BodySerializeFail(SerializeErr),
+ #[error("failed to compress packet")]
+ CompressFail(CompressError),
+ #[error("compression gave buf error")]
+ CompressBufError,
+ #[error("io error while writing data")]
+ IoFail(#[from] std::io::Error),
+ #[error("bad direction")]
+ BadDirection {
+ attempted: PacketDirection,
+ expected: PacketDirection,
+ },
+ #[error("bad state")]
+ BadState { attempted: State, expected: State },
+}
+
+pub type WriteResult<P> = Result<P, WriteError>;
+
+#[cfg(feature = "async")]
+#[async_trait]
+pub trait CraftAsyncWriter {
+ async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet + Send + Sync;
+
+ async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a> + Send + Sync;
+}
+
+pub trait CraftSyncWriter {
+ fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet;
+
+ fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a>;
+}
+
+pub struct CraftWriter<W> {
+ inner: W,
+
+ raw_buf: Option<Vec<u8>>,
+ compress_buf: Option<Vec<u8>>,
+ compression_threshold: Option<i32>,
+ state: State,
+ direction: PacketDirection,
+ encryption: Option<CraftCipher>,
+}
+
+impl<W> CraftWrapper<W> for CraftWriter<W> {
+ fn into_inner(self) -> W {
+ self.inner
+ }
+}
+
+impl<W> CraftIo for CraftWriter<W> {
+ fn set_state(&mut self, next: State) {
+ self.state = next;
+ }
+
+ fn set_compression_threshold(&mut self, threshold: Option<i32>) {
+ self.compression_threshold = threshold;
+ }
+
+ fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
+ setup_craft_cipher(&mut self.encryption, key, iv)
+ }
+}
+
+impl<W> CraftSyncWriter for CraftWriter<W>
+where
+ W: std::io::Write,
+{
+ fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet,
+ {
+ let prepared = self.serialize_packet_to_buf(packet)?;
+ write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?;
+ Ok(())
+ }
+
+ fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a>,
+ {
+ let prepared = self.serialize_raw_packet_to_buf(packet)?;
+ write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?;
+ Ok(())
+ }
+}
+
+fn write_data_to_target_sync<'a, W>(tuple: (&'a [u8], &'a mut W)) -> Result<(), std::io::Error>
+where
+ W: std::io::Write,
+{
+ let (data, target) = tuple;
+ target.write_all(data)
+}
+
+#[cfg(feature = "async")]
+#[async_trait]
+impl<W> CraftAsyncWriter for CraftWriter<W>
+where
+ W: futures::AsyncWrite + Unpin + Send + Sync,
+{
+ async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: Packet + Send + Sync,
+ {
+ let prepared = self.serialize_packet_to_buf(packet)?;
+ write_data_to_target_async(self.prepare_packet_in_buf(prepared)?).await?;
+ Ok(())
+ }
+
+ async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()>
+ where
+ P: RawPacket<'a> + Send + Sync,
+ {
+ let prepared = self.serialize_raw_packet_to_buf(packet)?;
+ write_data_to_target_async(self.prepare_packet_in_buf(prepared)?).await?;
+ Ok(())
+ }
+}
+
+#[cfg(feature = "async")]
+async fn write_data_to_target_async<'a, W>(
+ tuple: (&'a [u8], &'a mut W),
+) -> Result<(), std::io::Error>
+where
+ W: futures::AsyncWrite + Unpin + Send + Sync,
+{
+ let (data, target) = tuple;
+ target.write_all(data).await
+}
+
+const HEADER_OFFSET: usize = VAR_INT_BUF_SIZE * 2;
+
+struct PreparedPacketHandle {
+ id_size: usize,
+ data_size: usize,
+}
+
+impl<W> CraftWriter<W> {
+ pub fn wrap(inner: W, direction: PacketDirection) -> Self {
+ Self::wrap_with_state(inner, direction, State::Handshaking)
+ }
+
+ pub fn wrap_with_state(inner: W, direction: PacketDirection, state: State) -> Self {
+ Self {
+ inner,
+ raw_buf: None,
+ compression_threshold: None,
+ compress_buf: None,
+ state,
+ direction,
+ encryption: None,
+ }
+ }
+
+ fn prepare_packet_in_buf(
+ &mut self,
+ prepared: PreparedPacketHandle,
+ ) -> WriteResult<(&[u8], &mut W)> {
+ // assume id and body are in raw buf from HEADER_OFFSET .. size + HEADER_OFFSET
+ let body_size = prepared.id_size + prepared.data_size;
+ let buf = get_sized_buf(&mut self.raw_buf, 0, body_size);
+
+ let packet_data = if let Some(threshold) = self.compression_threshold {
+ if threshold >= 0 && (threshold as usize) <= body_size {
+ let compressed_size = compress(buf, &mut self.compress_buf, HEADER_OFFSET)?.len();
+ let compress_buf =
+ get_sized_buf(&mut self.compress_buf, 0, compressed_size + HEADER_OFFSET);
+
+ let data_len_target = &mut compress_buf[VAR_INT_BUF_SIZE..HEADER_OFFSET];
+ let mut data_len_serializer = SliceSerializer::create(data_len_target);
+ VarInt(body_size as i32)
+ .mc_serialize(&mut data_len_serializer)
+ .map_err(move |err| WriteError::HeaderSerializeFail(err))?;
+ let data_len_bytes = data_len_serializer.finish().len();
+
+ let packet_len_target = &mut compress_buf[..VAR_INT_BUF_SIZE];
+ let mut packet_len_serializer = SliceSerializer::create(packet_len_target);
+ VarInt((compressed_size + data_len_bytes) as i32)
+ .mc_serialize(&mut packet_len_serializer)
+ .map_err(move |err| WriteError::HeaderSerializeFail(err))?;
+ let packet_len_bytes = packet_len_serializer.finish().len();
+
+ let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes;
+ move_data_rightwards(
+ &mut compress_buf[..HEADER_OFFSET],
+ packet_len_bytes,
+ n_shift_packet_len,
+ );
+ let n_shift_data_len = VAR_INT_BUF_SIZE - data_len_bytes;
+ move_data_rightwards(
+ &mut compress_buf[n_shift_packet_len..HEADER_OFFSET],
+ packet_len_bytes + data_len_bytes,
+ n_shift_data_len,
+ );
+ let start_offset = n_shift_data_len + n_shift_packet_len;
+ let end_at = start_offset + data_len_bytes + packet_len_bytes + compressed_size;
+ &mut compress_buf[start_offset..end_at]
+ } else {
+ let packet_len_start_at = VAR_INT_BUF_SIZE - 1;
+ let packet_len_target = &mut buf[packet_len_start_at..HEADER_OFFSET - 1];
+ let mut packet_len_serializer = SliceSerializer::create(packet_len_target);
+ VarInt((body_size + 1) as i32)
+ .mc_serialize(&mut packet_len_serializer)
+ .map_err(move |err| WriteError::HeaderSerializeFail(err))?;
+
+ let packet_len_bytes = packet_len_serializer.finish().len();
+ let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes;
+ move_data_rightwards(
+ &mut buf[packet_len_start_at..HEADER_OFFSET - 1],
+ packet_len_bytes,
+ n_shift_packet_len,
+ );
+
+ let start_offset = packet_len_start_at + n_shift_packet_len;
+ let end_at = start_offset + packet_len_bytes + 1 + body_size;
+ &mut buf[start_offset..end_at]
+ }
+ } else {
+ let packet_len_target = &mut buf[VAR_INT_BUF_SIZE..HEADER_OFFSET];
+ let mut packet_len_serializer = SliceSerializer::create(packet_len_target);
+ VarInt(body_size as i32)
+ .mc_serialize(&mut packet_len_serializer)
+ .map_err(move |err| WriteError::HeaderSerializeFail(err))?;
+ let packet_len_bytes = packet_len_serializer.finish().len();
+ let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes;
+ move_data_rightwards(
+ &mut buf[VAR_INT_BUF_SIZE..HEADER_OFFSET],
+ packet_len_bytes,
+ n_shift_packet_len,
+ );
+ let start_offset = VAR_INT_BUF_SIZE + n_shift_packet_len;
+ let end_at = start_offset + packet_len_bytes + body_size;
+ &mut buf[start_offset..end_at]
+ };
+
+ if let Some(encryption) = &mut self.encryption {
+ encryption.encrypt(packet_data);
+ }
+
+ Ok((packet_data, &mut self.inner))
+ }
+
+ fn serialize_packet_to_buf<P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle>
+ where
+ P: Packet,
+ {
+ let id_size = self.serialize_id_to_buf(packet.id())?;
+ let data_size = self.serialize_to_buf(HEADER_OFFSET + id_size, move |serializer| {
+ packet
+ .mc_serialize_body(serializer)
+ .map_err(move |err| WriteError::BodySerializeFail(err))
+ })?;
+
+ Ok(PreparedPacketHandle { id_size, data_size })
+ }
+
+ fn serialize_raw_packet_to_buf<'a, P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle>
+ where
+ P: RawPacket<'a>,
+ {
+ let id_size = self.serialize_id_to_buf(packet.id())?;
+ let packet_data = packet.data();
+ let data_size = packet_data.len();
+ let buf = get_sized_buf(&mut self.raw_buf, HEADER_OFFSET, id_size + data_size);
+
+ (&mut buf[id_size..]).copy_from_slice(packet_data);
+
+ Ok(PreparedPacketHandle { id_size, data_size })
+ }
+
+ fn serialize_id_to_buf(&mut self, id: Id) -> WriteResult<usize> {
+ if id.direction != self.direction {
+ return Err(WriteError::BadDirection {
+ expected: self.direction,
+ attempted: id.direction,
+ });
+ }
+
+ if id.state != self.state {
+ return Err(WriteError::BadState {
+ expected: self.state,
+ attempted: id.state,
+ });
+ }
+
+ self.serialize_to_buf(HEADER_OFFSET, move |serializer| {
+ id.mc_serialize(serializer)
+ .map_err(move |err| WriteError::HeaderSerializeFail(err))
+ })
+ }
+
+ fn serialize_to_buf<'a, F>(&'a mut self, offset: usize, f: F) -> WriteResult<usize>
+ where
+ F: FnOnce(&mut GrowVecSerializer<'a>) -> Result<(), WriteError>,
+ {
+ let mut serializer = GrowVecSerializer::create(&mut self.raw_buf, offset);
+ f(&mut serializer)?;
+ Ok(serializer.finish().map(move |b| b.len()).unwrap_or(0))
+ }
+}
+
+#[derive(Debug)]
+struct GrowVecSerializer<'a> {
+ target: &'a mut Option<Vec<u8>>,
+ at: usize,
+ offset: usize,
+}
+
+impl<'a> Serializer for GrowVecSerializer<'a> {
+ fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult {
+ get_sized_buf(self.target, self.at + self.offset, data.len()).copy_from_slice(data);
+ Ok(())
+ }
+}
+
+impl<'a> GrowVecSerializer<'a> {
+ fn create(target: &'a mut Option<Vec<u8>>, offset: usize) -> Self {
+ Self {
+ target,
+ at: 0,
+ offset,
+ }
+ }
+
+ fn finish(self) -> Option<&'a mut [u8]> {
+ if let Some(buf) = self.target {
+ Some(&mut buf[self.offset..self.offset + self.at])
+ } else {
+ None
+ }
+ }
+}
+
+struct SliceSerializer<'a> {
+ target: &'a mut [u8],
+ at: usize,
+}
+
+impl<'a> Serializer for SliceSerializer<'a> {
+ fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult {
+ let end_at = self.at + data.len();
+ if end_at >= self.target.len() {
+ panic!(
+ "cannot fit data in slice ({} exceeds length {} at {})",
+ data.len(),
+ self.target.len(),
+ self.at
+ );
+ }
+
+ (&mut self.target[self.at..end_at]).copy_from_slice(data);
+ self.at = end_at;
+ Ok(())
+ }
+}
+
+impl<'a> SliceSerializer<'a> {
+ fn create(target: &'a mut [u8]) -> Self {
+ Self { target, at: 0 }
+ }
+
+ fn finish(self) -> &'a [u8] {
+ &self.target[..self.at]
+ }
+}
+
+fn compress<'a, 'b>(
+ src: &'b [u8],
+ output: &'a mut Option<Vec<u8>>,
+ offset: usize,
+) -> Result<&'a mut [u8], WriteError> {
+ let target = get_sized_buf(output, offset, src.len());
+ let mut compressor = flate2::Compress::new_with_window_bits(Compression::fast(), true, 15);
+ loop {
+ let input = &src[(compressor.total_in() as usize)..];
+ let eof = input.is_empty();
+ let output = &mut target[(compressor.total_out() as usize)..];
+ let flush = if eof {
+ FlushCompress::Finish
+ } else {
+ FlushCompress::None
+ };
+
+ match compressor
+ .compress(input, output, flush)
+ .map_err(move |err| WriteError::CompressFail(err))?
+ {
+ Status::Ok => {}
+ Status::BufError => return Err(WriteError::CompressBufError),
+ Status::StreamEnd => break,
+ }
+ }
+
+ Ok(&mut target[..(compressor.total_out() as usize)])
+}