diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cfb8.rs | 105 | ||||
-rw-r--r-- | src/connection.rs | 125 | ||||
-rw-r--r-- | src/lib.rs | 14 | ||||
-rw-r--r-- | src/reader.rs | 294 | ||||
-rw-r--r-- | src/tcp.rs | 69 | ||||
-rw-r--r-- | src/util.rs | 55 | ||||
-rw-r--r-- | src/wrapper.rs | 12 | ||||
-rw-r--r-- | src/writer.rs | 420 |
8 files changed, 1094 insertions, 0 deletions
diff --git a/src/cfb8.rs b/src/cfb8.rs new file mode 100644 index 0000000..6fc3bb6 --- /dev/null +++ b/src/cfb8.rs @@ -0,0 +1,105 @@ +use aes::{ + cipher::{consts::U16, generic_array::GenericArray, BlockCipherMut, NewBlockCipher}, + Aes128, +}; +use thiserror::Error; + +pub type CraftCipherResult<T> = Result<T, CipherError>; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum CipherComponent { + Key, + Iv, +} + +#[derive(Debug, Error)] +pub enum CipherError { + #[error("encryption is already enabled and cannot be enabled again")] + AlreadyEnabled, + #[error("bad size '{1}' for '{0:?}'")] + BadSize(CipherComponent, usize), +} + +const BYTES_SIZE: usize = 16; + +pub struct CraftCipher { + iv: GenericArray<u8, U16>, + tmp: GenericArray<u8, U16>, + cipher: Aes128, +} + +impl CraftCipher { + pub fn new(key: &[u8], iv: &[u8]) -> CraftCipherResult<Self> { + if iv.len() != BYTES_SIZE { + return Err(CipherError::BadSize(CipherComponent::Iv, iv.len())); + } + + if key.len() != BYTES_SIZE { + return Err(CipherError::BadSize(CipherComponent::Key, iv.len())); + } + + let mut iv_out = [0u8; BYTES_SIZE]; + iv_out.copy_from_slice(iv); + + let mut key_out = [0u8; BYTES_SIZE]; + key_out.copy_from_slice(key); + + let tmp = [0u8; BYTES_SIZE]; + + Ok(Self { + iv: GenericArray::from(iv_out), + tmp: GenericArray::from(tmp), + cipher: Aes128::new(&GenericArray::from(key_out)), + }) + } + + pub fn encrypt(&mut self, data: &mut [u8]) { + unsafe { self.crypt(data, false) } + } + + pub fn decrypt(&mut self, data: &mut [u8]) { + unsafe { self.crypt(data, true) } + } + + unsafe fn crypt(&mut self, data: &mut [u8], decrypt: bool) { + let iv = &mut self.iv; + const IV_SIZE: usize = 16; + const IV_SIZE_MINUS_ONE: usize = IV_SIZE - 1; + let iv_ptr = iv.as_mut_ptr(); + let iv_end_ptr = iv_ptr.offset(IV_SIZE_MINUS_ONE as isize); + let tmp_ptr = self.tmp.as_mut_ptr(); + let tmp_offset_one_ptr = tmp_ptr.offset(1); + let cipher = &mut self.cipher; + let n = data.len(); + let mut data_ptr = data.as_mut_ptr(); + let data_end_ptr = data_ptr.offset(n as isize); + + while data_ptr != data_end_ptr { + std::ptr::copy_nonoverlapping(iv_ptr, tmp_ptr, IV_SIZE); + cipher.encrypt_block(iv); + let orig = *data_ptr; + let updated = orig ^ *iv_ptr; + std::ptr::copy_nonoverlapping(tmp_offset_one_ptr, iv_ptr, IV_SIZE_MINUS_ONE); + if decrypt { + *iv_end_ptr = orig; + } else { + *iv_end_ptr = updated; + } + *data_ptr = updated; + data_ptr = data_ptr.offset(1); + } + } +} + +pub(crate) fn setup_craft_cipher( + target: &mut Option<CraftCipher>, + key: &[u8], + iv: &[u8], +) -> Result<(), CipherError> { + if target.is_some() { + Err(CipherError::AlreadyEnabled) + } else { + *target = Some(CraftCipher::new(key, iv)?); + Ok(()) + } +} diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 0000000..af6f4a9 --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,125 @@ +use crate::cfb8::CipherError; +use crate::reader::{CraftAsyncReader, CraftReader, CraftSyncReader, ReadResult}; +use crate::wrapper::{CraftIo, CraftWrapper}; +use crate::writer::{CraftAsyncWriter, CraftSyncWriter, CraftWriter, WriteResult}; +use mcproto_rs::protocol::{Packet, RawPacket, State}; + +#[cfg(feature = "async")] +use async_trait::async_trait; + +pub struct CraftConnection<R, W> { + pub(crate) reader: CraftReader<R>, + pub(crate) writer: CraftWriter<W>, +} + +impl<R, W> CraftWrapper<(CraftReader<R>, CraftWriter<W>)> for CraftConnection<R, W> { + fn into_inner(self) -> (CraftReader<R>, CraftWriter<W>) { + (self.reader, self.writer) + } +} + +impl<R, W> CraftIo for CraftConnection<R, W> { + fn set_state(&mut self, next: State) { + self.reader.set_state(next); + self.writer.set_state(next); + } + + fn set_compression_threshold(&mut self, threshold: Option<i32>) { + self.reader.set_compression_threshold(threshold); + self.writer.set_compression_threshold(threshold); + } + + fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> { + self.reader.enable_encryption(key, iv)?; + self.writer.enable_encryption(key, iv)?; + Ok(()) + } +} + +impl<R, W> CraftSyncReader for CraftConnection<R, W> +where + CraftReader<R>: CraftSyncReader, + CraftWriter<W>: CraftSyncWriter, +{ + fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet> + where + P: RawPacket<'a>, + { + self.reader.read_packet::<P>() + } + + fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P> + where + P: RawPacket<'a>, + { + self.reader.read_raw_packet::<P>() + } +} + +impl<R, W> CraftSyncWriter for CraftConnection<R, W> +where + CraftReader<R>: CraftSyncReader, + CraftWriter<W>: CraftSyncWriter, +{ + fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet, + { + self.writer.write_packet(packet) + } + + fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a>, + { + self.writer.write_raw_packet(packet) + } +} + +#[cfg(feature = "async")] +#[async_trait] +impl<R, W> CraftAsyncReader for CraftConnection<R, W> +where + CraftReader<R>: CraftAsyncReader, + R: Send + Sync, + CraftWriter<W>: CraftAsyncWriter, + W: Send + Sync, +{ + async fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet> + where + P: RawPacket<'a>, + { + self.reader.read_packet::<P>().await + } + + async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P> + where + P: RawPacket<'a>, + { + self.reader.read_raw_packet::<P>().await + } +} + +#[cfg(feature = "async")] +#[async_trait] +impl<R, W> CraftAsyncWriter for CraftConnection<R, W> +where + CraftReader<R>: CraftAsyncReader, + R: Send + Sync, + CraftWriter<W>: CraftAsyncWriter, + W: Send + Sync, +{ + async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet + Send + Sync, + { + self.writer.write_packet(packet).await + } + + async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a> + Send + Sync, + { + self.writer.write_raw_packet(packet).await + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..8fafec0 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,14 @@ +mod cfb8; +mod connection; +mod reader; +mod tcp; +mod util; +mod wrapper; +mod writer; + +pub use connection::CraftConnection; +pub use reader::*; +pub use writer::*; +pub use tcp::*; +pub use cfb8::CipherError; +pub use wrapper::*; diff --git a/src/reader.rs b/src/reader.rs new file mode 100644 index 0000000..68b9a9b --- /dev/null +++ b/src/reader.rs @@ -0,0 +1,294 @@ +use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher}; +use crate::util::{get_sized_buf, VAR_INT_BUF_SIZE}; +use crate::wrapper::{CraftIo, CraftWrapper}; +use flate2::{DecompressError, FlushDecompress, Status}; +use mcproto_rs::protocol::{Id, PacketDirection, RawPacket, State}; +use mcproto_rs::types::VarInt; +use mcproto_rs::{Deserialize, Deserialized}; +use thiserror::Error; + +#[cfg(feature = "async")] +use {async_trait::async_trait, futures::AsyncReadExt}; + +#[derive(Debug, Error)] +pub enum ReadError { + #[error("i/o failure during read")] + IoFailure(#[from] std::io::Error), + #[error("failed to read header VarInt")] + PacketHeaderErr(#[from] mcproto_rs::DeserializeErr), + #[error("failed to read packet")] + PacketErr(#[from] mcproto_rs::protocol::PacketErr), + #[error("failed to decompress packet")] + DecompressFailed(#[from] DecompressErr), +} + +#[derive(Debug, Error)] +pub enum DecompressErr { + #[error("buf error")] + BufError, + #[error("failure while decompressing")] + Failure(#[from] DecompressError), +} + +pub type ReadResult<P> = Result<Option<P>, ReadError>; + +#[cfg(feature = "async")] +#[async_trait] +pub trait CraftAsyncReader { + async fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet> + where + P: RawPacket<'a>, + { + deserialize_raw_packet(self.read_raw_packet::<P>().await) + } + + async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P> + where + P: RawPacket<'a>; +} + +pub trait CraftSyncReader { + fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet> + where + P: RawPacket<'a>, + { + deserialize_raw_packet(self.read_raw_packet::<'a, P>()) + } + + fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P> + where + P: RawPacket<'a>; +} + +pub struct CraftReader<R> { + inner: R, + raw_buf: Option<Vec<u8>>, + decompress_buf: Option<Vec<u8>>, + compression_threshold: Option<i32>, + state: State, + direction: PacketDirection, + encryption: Option<CraftCipher>, +} + +impl<R> CraftWrapper<R> for CraftReader<R> { + fn into_inner(self) -> R { + self.inner + } +} + +impl<R> CraftIo for CraftReader<R> { + fn set_state(&mut self, next: State) { + self.state = next; + } + + fn set_compression_threshold(&mut self, threshold: Option<i32>) { + self.compression_threshold = threshold; + } + + fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> { + setup_craft_cipher(&mut self.encryption, key, iv) + } +} + +macro_rules! rr_unwrap { + ($result: expr) => { + match $result { + Ok(Some(r)) => r, + Ok(None) => return Ok(None), + Err(err) => return Err(err), + } + }; +} + +macro_rules! check_unexpected_eof { + ($result: expr) => { + if let Err(err) = $result { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + return Ok(None); + } + + return Err(ReadError::IoFailure(err)); + } + }; +} + +impl<R> CraftSyncReader for CraftReader<R> +where + R: std::io::Read, +{ + fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P> + where + P: RawPacket<'a>, + { + let primary_packet_len = rr_unwrap!(self.read_one_varint_sync()).0 as usize; + rr_unwrap!(self.read_n(primary_packet_len)); + self.read_packet_in_buf::<'a, P>(primary_packet_len) + } +} + +#[cfg(feature = "async")] +#[async_trait] +impl<R> CraftAsyncReader for CraftReader<R> +where + R: futures::AsyncRead + Unpin + Sync + Send, +{ + async fn read_raw_packet<'a, P>(&'a mut self) -> Result<Option<P>, ReadError> + where + P: RawPacket<'a>, + { + let primary_packet_len = rr_unwrap!(self.read_one_varint_async().await).0 as usize; + rr_unwrap!(self.read_n_async(primary_packet_len).await); + self.read_packet_in_buf::<P>(primary_packet_len) + } +} + +impl<R> CraftReader<R> +where + R: std::io::Read, +{ + fn read_one_varint_sync(&mut self) -> ReadResult<VarInt> { + deserialize_varint(rr_unwrap!(self.read_n(VAR_INT_BUF_SIZE))) + } + + fn read_n(&mut self, n: usize) -> ReadResult<&mut [u8]> { + let buf = get_sized_buf(&mut self.raw_buf, 0, n); + check_unexpected_eof!(self.inner.read_exact(buf)); + Ok(Some(buf)) + } +} + +#[cfg(feature = "async")] +impl<R> CraftReader<R> +where + R: futures::io::AsyncRead + Unpin + Sync + Send, +{ + async fn read_one_varint_async(&mut self) -> ReadResult<VarInt> { + deserialize_varint(rr_unwrap!(self.read_n_async(VAR_INT_BUF_SIZE).await)) + } + + async fn read_n_async(&mut self, n: usize) -> ReadResult<&mut [u8]> { + let buf = get_sized_buf(&mut self.raw_buf, 0, n); + check_unexpected_eof!(self.inner.read_exact(buf).await); + Ok(Some(buf)) + } +} + +macro_rules! dsz_unwrap { + ($bnam: expr, $k: ty) => { + match <$k>::mc_deserialize($bnam) { + Ok(Deserialized { + value: val, + data: rest, + }) => (val, rest), + Err(err) => { + return Err(ReadError::PacketHeaderErr(err)); + } + }; + }; +} + +impl<R> CraftReader<R> { + pub fn wrap(inner: R, direction: PacketDirection) -> Self { + Self::wrap_with_state(inner, direction, State::Handshaking) + } + + pub fn wrap_with_state(inner: R, direction: PacketDirection, state: State) -> Self { + Self { + inner, + raw_buf: None, + decompress_buf: None, + compression_threshold: None, + state, + direction, + encryption: None, + } + } + + fn read_packet_in_buf<'a, P>(&'a mut self, size: usize) -> ReadResult<P> + where + P: RawPacket<'a>, + { + // find data in buf + let buf = &mut self.raw_buf.as_mut().expect("should exist right now")[..size]; + // decrypt the packet if encryption is enabled + if let Some(encryption) = self.encryption.as_mut() { + encryption.decrypt(buf); + } + + // try to get the packet body bytes... this boils down to: + // * check if compression enabled, + // * read data len (VarInt) which isn't compressed + // * if data len is 0, then rest of packet is not compressed, remaining data is body + // * otherwise, data len is decompressed length, so prepare a decompression buf and decompress from + // the buffer into the decompression buffer, and return the slice of the decompression buffer + // which contains this packet's data + // * if compression not enabled, then the buf contains only the packet body bytes + + let packet_buf = if let Some(_) = self.compression_threshold { + let (data_len, rest) = dsz_unwrap!(buf, VarInt); + let data_len = data_len.0 as usize; + if data_len == 0 { + rest + } else { + decompress(rest, &mut self.decompress_buf, data_len)? + } + } else { + buf + }; + + let (raw_id, body_buf) = dsz_unwrap!(packet_buf, VarInt); + + let id = Id { + id: raw_id.0, + state: self.state.clone(), + direction: self.direction.clone(), + }; + + match P::create(id, body_buf) { + Ok(raw) => Ok(Some(raw)), + Err(err) => Err(ReadError::PacketErr(err)), + } + } +} + +fn deserialize_raw_packet<'a, P>(raw: ReadResult<P>) -> ReadResult<P::Packet> +where + P: RawPacket<'a>, +{ + match raw { + Ok(Some(raw)) => match raw.deserialize() { + Ok(deserialized) => Ok(Some(deserialized)), + Err(err) => Err(ReadError::PacketErr(err)), + }, + Ok(None) => Ok(None), + Err(err) => Err(err), + } +} + +fn deserialize_varint(buf: &[u8]) -> ReadResult<VarInt> { + match VarInt::mc_deserialize(buf) { + Ok(v) => Ok(Some(v.value)), + Err(err) => Err(ReadError::PacketHeaderErr(err)), + } +} + +fn decompress<'a>( + src: &'a [u8], + target: &'a mut Option<Vec<u8>>, + decompressed_len: usize, +) -> Result<&'a mut [u8], ReadError> { + let mut decompress = flate2::Decompress::new(true); + let decompress_buf = get_sized_buf(target, 0, decompressed_len); + loop { + match decompress.decompress(src, decompress_buf, FlushDecompress::Finish) { + Ok(Status::StreamEnd) => break, + Ok(Status::Ok) => {} + Ok(Status::BufError) => { + return Err(ReadError::DecompressFailed(DecompressErr::BufError)) + } + Err(err) => return Err(ReadError::DecompressFailed(DecompressErr::Failure(err))), + } + } + + Ok(&mut decompress_buf[..(decompress.total_out() as usize)]) +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..e2ead90 --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,69 @@ +use crate::connection::CraftConnection; +use crate::reader::CraftReader; +use crate::writer::CraftWriter; +use mcproto_rs::protocol::{PacketDirection, State}; +use std::convert::TryFrom; +use std::io::BufReader as StdBufReader; +use std::net::TcpStream; + +#[cfg(feature = "async")] +use futures::io::{AsyncRead, AsyncWrite, BufReader as AsyncBufReader}; + +pub const BUF_SIZE: usize = 8192; + +pub type CraftTcpConnection = CraftConnection<StdBufReader<TcpStream>, TcpStream>; + +impl CraftConnection<StdBufReader<TcpStream>, TcpStream> { + pub fn connect_server_std(to: String) -> Result<Self, std::io::Error> { + Self::from_std(TcpStream::connect(to)?, PacketDirection::ClientBound) + } + + pub fn wrap_client_stream_std(stream: TcpStream) -> Result<Self, std::io::Error> { + Self::from_std(stream, PacketDirection::ServerBound) + } + + pub fn from_std( + s1: TcpStream, + read_direction: PacketDirection, + ) -> Result<Self, std::io::Error> { + Self::from_std_with_state(s1, read_direction, State::Handshaking) + } + + pub fn from_std_with_state( + s1: TcpStream, + read_direction: PacketDirection, + state: State, + ) -> Result<Self, std::io::Error> { + let write = s1.try_clone()?; + let read = StdBufReader::with_capacity(BUF_SIZE, s1); + + Ok(Self { + reader: CraftReader::wrap_with_state(read, read_direction, state), + writer: CraftWriter::wrap_with_state(write, read_direction.opposite(), state), + }) + } +} + +#[cfg(feature = "async")] +impl<R, W> CraftConnection<AsyncBufReader<R>, W> +where + R: AsyncRead + Send + Sync + Unpin, + W: AsyncWrite + Send + Sync + Unpin, +{ + pub fn from_async(tuple: (R, W), read_direction: PacketDirection) -> Self { + Self::from_async_with_state(tuple, read_direction, State::Handshaking) + } + + pub fn from_async_with_state( + tuple: (R, W), + read_direction: PacketDirection, + state: State, + ) -> Self { + let (reader, writer) = tuple; + let reader = AsyncBufReader::with_capacity(BUF_SIZE, reader); + Self { + reader: CraftReader::wrap_with_state(reader, read_direction, state), + writer: CraftWriter::wrap_with_state(writer, read_direction.opposite(), state), + } + } +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..a19922f --- /dev/null +++ b/src/util.rs @@ -0,0 +1,55 @@ +pub(crate) const VAR_INT_BUF_SIZE: usize = 5; + +pub(crate) fn get_sized_buf(buf: &mut Option<Vec<u8>>, offset: usize, size: usize) -> &mut [u8] { + let end_at = offset + size; + loop { + match buf { + Some(v) => { + ensure_buf_has_size(v, end_at); + break &mut v[offset..end_at]; + } + None => { + let new_buf = Vec::with_capacity(end_at); + *buf = Some(new_buf); + } + } + } +} +fn ensure_buf_has_size(buf: &mut Vec<u8>, total_size: usize) { + let cur_len = buf.len(); + if cur_len >= total_size { + return; + } + + let additional = total_size - cur_len; + buf.reserve(additional); + unsafe { + let start_at = buf.as_mut_ptr(); + let start_write_at = start_at.offset(cur_len as isize); + std::ptr::write_bytes(start_write_at, 0, additional); + buf.set_len(total_size); + } +} + +pub(crate) fn move_data_rightwards(target: &mut [u8], size: usize, shift_amount: usize) { + let required_len = size + shift_amount; + let actual_len = target.len(); + if actual_len < required_len { + panic!( + "move of data to the right (0..{} -> {}..{}) exceeds size of buffer {}", + size, shift_amount, required_len, actual_len, + ) + } + + unsafe { move_data_rightwards_unchecked(target, size, shift_amount) } +} + +unsafe fn move_data_rightwards_unchecked(target: &mut [u8], size: usize, shift_amount: usize) { + if shift_amount == 0 { + return; + } + + let src_ptr = target.as_mut_ptr(); + let dst_ptr = src_ptr.offset(shift_amount as isize); + std::ptr::copy(src_ptr, dst_ptr, size); +} diff --git a/src/wrapper.rs b/src/wrapper.rs new file mode 100644 index 0000000..06f451b --- /dev/null +++ b/src/wrapper.rs @@ -0,0 +1,12 @@ +use crate::cfb8::CipherError; +use mcproto_rs::protocol::State; + +pub trait CraftWrapper<I> { + fn into_inner(self) -> I; +} + +pub trait CraftIo { + fn set_state(&mut self, next: State); + fn set_compression_threshold(&mut self, threshold: Option<i32>); + fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError>; +} diff --git a/src/writer.rs b/src/writer.rs new file mode 100644 index 0000000..6c76b23 --- /dev/null +++ b/src/writer.rs @@ -0,0 +1,420 @@ +use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher}; +use crate::util::{get_sized_buf, move_data_rightwards, VAR_INT_BUF_SIZE}; +use crate::wrapper::{CraftIo, CraftWrapper}; +use flate2::{CompressError, Compression, FlushCompress, Status}; +use mcproto_rs::protocol::{Id, Packet, PacketDirection, RawPacket, State}; +use mcproto_rs::types::VarInt; +use mcproto_rs::{Serialize, SerializeErr, SerializeResult, Serializer}; +use thiserror::Error; + +#[cfg(feature = "async")] +use {async_trait::async_trait, futures::AsyncWriteExt}; + +#[derive(Debug, Error)] +pub enum WriteError { + #[error("serialization of header data failed")] + HeaderSerializeFail(SerializeErr), + #[error("packet body serialization failed")] + BodySerializeFail(SerializeErr), + #[error("failed to compress packet")] + CompressFail(CompressError), + #[error("compression gave buf error")] + CompressBufError, + #[error("io error while writing data")] + IoFail(#[from] std::io::Error), + #[error("bad direction")] + BadDirection { + attempted: PacketDirection, + expected: PacketDirection, + }, + #[error("bad state")] + BadState { attempted: State, expected: State }, +} + +pub type WriteResult<P> = Result<P, WriteError>; + +#[cfg(feature = "async")] +#[async_trait] +pub trait CraftAsyncWriter { + async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet + Send + Sync; + + async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a> + Send + Sync; +} + +pub trait CraftSyncWriter { + fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet; + + fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a>; +} + +pub struct CraftWriter<W> { + inner: W, + + raw_buf: Option<Vec<u8>>, + compress_buf: Option<Vec<u8>>, + compression_threshold: Option<i32>, + state: State, + direction: PacketDirection, + encryption: Option<CraftCipher>, +} + +impl<W> CraftWrapper<W> for CraftWriter<W> { + fn into_inner(self) -> W { + self.inner + } +} + +impl<W> CraftIo for CraftWriter<W> { + fn set_state(&mut self, next: State) { + self.state = next; + } + + fn set_compression_threshold(&mut self, threshold: Option<i32>) { + self.compression_threshold = threshold; + } + + fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> { + setup_craft_cipher(&mut self.encryption, key, iv) + } +} + +impl<W> CraftSyncWriter for CraftWriter<W> +where + W: std::io::Write, +{ + fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet, + { + let prepared = self.serialize_packet_to_buf(packet)?; + write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?; + Ok(()) + } + + fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a>, + { + let prepared = self.serialize_raw_packet_to_buf(packet)?; + write_data_to_target_sync(self.prepare_packet_in_buf(prepared)?)?; + Ok(()) + } +} + +fn write_data_to_target_sync<'a, W>(tuple: (&'a [u8], &'a mut W)) -> Result<(), std::io::Error> +where + W: std::io::Write, +{ + let (data, target) = tuple; + target.write_all(data) +} + +#[cfg(feature = "async")] +#[async_trait] +impl<W> CraftAsyncWriter for CraftWriter<W> +where + W: futures::AsyncWrite + Unpin + Send + Sync, +{ + async fn write_packet<P>(&mut self, packet: P) -> WriteResult<()> + where + P: Packet + Send + Sync, + { + let prepared = self.serialize_packet_to_buf(packet)?; + write_data_to_target_async(self.prepare_packet_in_buf(prepared)?).await?; + Ok(()) + } + + async fn write_raw_packet<'a, P>(&mut self, packet: P) -> WriteResult<()> + where + P: RawPacket<'a> + Send + Sync, + { + let prepared = self.serialize_raw_packet_to_buf(packet)?; + write_data_to_target_async(self.prepare_packet_in_buf(prepared)?).await?; + Ok(()) + } +} + +#[cfg(feature = "async")] +async fn write_data_to_target_async<'a, W>( + tuple: (&'a [u8], &'a mut W), +) -> Result<(), std::io::Error> +where + W: futures::AsyncWrite + Unpin + Send + Sync, +{ + let (data, target) = tuple; + target.write_all(data).await +} + +const HEADER_OFFSET: usize = VAR_INT_BUF_SIZE * 2; + +struct PreparedPacketHandle { + id_size: usize, + data_size: usize, +} + +impl<W> CraftWriter<W> { + pub fn wrap(inner: W, direction: PacketDirection) -> Self { + Self::wrap_with_state(inner, direction, State::Handshaking) + } + + pub fn wrap_with_state(inner: W, direction: PacketDirection, state: State) -> Self { + Self { + inner, + raw_buf: None, + compression_threshold: None, + compress_buf: None, + state, + direction, + encryption: None, + } + } + + fn prepare_packet_in_buf( + &mut self, + prepared: PreparedPacketHandle, + ) -> WriteResult<(&[u8], &mut W)> { + // assume id and body are in raw buf from HEADER_OFFSET .. size + HEADER_OFFSET + let body_size = prepared.id_size + prepared.data_size; + let buf = get_sized_buf(&mut self.raw_buf, 0, body_size); + + let packet_data = if let Some(threshold) = self.compression_threshold { + if threshold >= 0 && (threshold as usize) <= body_size { + let compressed_size = compress(buf, &mut self.compress_buf, HEADER_OFFSET)?.len(); + let compress_buf = + get_sized_buf(&mut self.compress_buf, 0, compressed_size + HEADER_OFFSET); + + let data_len_target = &mut compress_buf[VAR_INT_BUF_SIZE..HEADER_OFFSET]; + let mut data_len_serializer = SliceSerializer::create(data_len_target); + VarInt(body_size as i32) + .mc_serialize(&mut data_len_serializer) + .map_err(move |err| WriteError::HeaderSerializeFail(err))?; + let data_len_bytes = data_len_serializer.finish().len(); + + let packet_len_target = &mut compress_buf[..VAR_INT_BUF_SIZE]; + let mut packet_len_serializer = SliceSerializer::create(packet_len_target); + VarInt((compressed_size + data_len_bytes) as i32) + .mc_serialize(&mut packet_len_serializer) + .map_err(move |err| WriteError::HeaderSerializeFail(err))?; + let packet_len_bytes = packet_len_serializer.finish().len(); + + let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes; + move_data_rightwards( + &mut compress_buf[..HEADER_OFFSET], + packet_len_bytes, + n_shift_packet_len, + ); + let n_shift_data_len = VAR_INT_BUF_SIZE - data_len_bytes; + move_data_rightwards( + &mut compress_buf[n_shift_packet_len..HEADER_OFFSET], + packet_len_bytes + data_len_bytes, + n_shift_data_len, + ); + let start_offset = n_shift_data_len + n_shift_packet_len; + let end_at = start_offset + data_len_bytes + packet_len_bytes + compressed_size; + &mut compress_buf[start_offset..end_at] + } else { + let packet_len_start_at = VAR_INT_BUF_SIZE - 1; + let packet_len_target = &mut buf[packet_len_start_at..HEADER_OFFSET - 1]; + let mut packet_len_serializer = SliceSerializer::create(packet_len_target); + VarInt((body_size + 1) as i32) + .mc_serialize(&mut packet_len_serializer) + .map_err(move |err| WriteError::HeaderSerializeFail(err))?; + + let packet_len_bytes = packet_len_serializer.finish().len(); + let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes; + move_data_rightwards( + &mut buf[packet_len_start_at..HEADER_OFFSET - 1], + packet_len_bytes, + n_shift_packet_len, + ); + + let start_offset = packet_len_start_at + n_shift_packet_len; + let end_at = start_offset + packet_len_bytes + 1 + body_size; + &mut buf[start_offset..end_at] + } + } else { + let packet_len_target = &mut buf[VAR_INT_BUF_SIZE..HEADER_OFFSET]; + let mut packet_len_serializer = SliceSerializer::create(packet_len_target); + VarInt(body_size as i32) + .mc_serialize(&mut packet_len_serializer) + .map_err(move |err| WriteError::HeaderSerializeFail(err))?; + let packet_len_bytes = packet_len_serializer.finish().len(); + let n_shift_packet_len = VAR_INT_BUF_SIZE - packet_len_bytes; + move_data_rightwards( + &mut buf[VAR_INT_BUF_SIZE..HEADER_OFFSET], + packet_len_bytes, + n_shift_packet_len, + ); + let start_offset = VAR_INT_BUF_SIZE + n_shift_packet_len; + let end_at = start_offset + packet_len_bytes + body_size; + &mut buf[start_offset..end_at] + }; + + if let Some(encryption) = &mut self.encryption { + encryption.encrypt(packet_data); + } + + Ok((packet_data, &mut self.inner)) + } + + fn serialize_packet_to_buf<P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle> + where + P: Packet, + { + let id_size = self.serialize_id_to_buf(packet.id())?; + let data_size = self.serialize_to_buf(HEADER_OFFSET + id_size, move |serializer| { + packet + .mc_serialize_body(serializer) + .map_err(move |err| WriteError::BodySerializeFail(err)) + })?; + + Ok(PreparedPacketHandle { id_size, data_size }) + } + + fn serialize_raw_packet_to_buf<'a, P>(&mut self, packet: P) -> WriteResult<PreparedPacketHandle> + where + P: RawPacket<'a>, + { + let id_size = self.serialize_id_to_buf(packet.id())?; + let packet_data = packet.data(); + let data_size = packet_data.len(); + let buf = get_sized_buf(&mut self.raw_buf, HEADER_OFFSET, id_size + data_size); + + (&mut buf[id_size..]).copy_from_slice(packet_data); + + Ok(PreparedPacketHandle { id_size, data_size }) + } + + fn serialize_id_to_buf(&mut self, id: Id) -> WriteResult<usize> { + if id.direction != self.direction { + return Err(WriteError::BadDirection { + expected: self.direction, + attempted: id.direction, + }); + } + + if id.state != self.state { + return Err(WriteError::BadState { + expected: self.state, + attempted: id.state, + }); + } + + self.serialize_to_buf(HEADER_OFFSET, move |serializer| { + id.mc_serialize(serializer) + .map_err(move |err| WriteError::HeaderSerializeFail(err)) + }) + } + + fn serialize_to_buf<'a, F>(&'a mut self, offset: usize, f: F) -> WriteResult<usize> + where + F: FnOnce(&mut GrowVecSerializer<'a>) -> Result<(), WriteError>, + { + let mut serializer = GrowVecSerializer::create(&mut self.raw_buf, offset); + f(&mut serializer)?; + Ok(serializer.finish().map(move |b| b.len()).unwrap_or(0)) + } +} + +#[derive(Debug)] +struct GrowVecSerializer<'a> { + target: &'a mut Option<Vec<u8>>, + at: usize, + offset: usize, +} + +impl<'a> Serializer for GrowVecSerializer<'a> { + fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult { + get_sized_buf(self.target, self.at + self.offset, data.len()).copy_from_slice(data); + Ok(()) + } +} + +impl<'a> GrowVecSerializer<'a> { + fn create(target: &'a mut Option<Vec<u8>>, offset: usize) -> Self { + Self { + target, + at: 0, + offset, + } + } + + fn finish(self) -> Option<&'a mut [u8]> { + if let Some(buf) = self.target { + Some(&mut buf[self.offset..self.offset + self.at]) + } else { + None + } + } +} + +struct SliceSerializer<'a> { + target: &'a mut [u8], + at: usize, +} + +impl<'a> Serializer for SliceSerializer<'a> { + fn serialize_bytes(&mut self, data: &[u8]) -> SerializeResult { + let end_at = self.at + data.len(); + if end_at >= self.target.len() { + panic!( + "cannot fit data in slice ({} exceeds length {} at {})", + data.len(), + self.target.len(), + self.at + ); + } + + (&mut self.target[self.at..end_at]).copy_from_slice(data); + self.at = end_at; + Ok(()) + } +} + +impl<'a> SliceSerializer<'a> { + fn create(target: &'a mut [u8]) -> Self { + Self { target, at: 0 } + } + + fn finish(self) -> &'a [u8] { + &self.target[..self.at] + } +} + +fn compress<'a, 'b>( + src: &'b [u8], + output: &'a mut Option<Vec<u8>>, + offset: usize, +) -> Result<&'a mut [u8], WriteError> { + let target = get_sized_buf(output, offset, src.len()); + let mut compressor = flate2::Compress::new_with_window_bits(Compression::fast(), true, 15); + loop { + let input = &src[(compressor.total_in() as usize)..]; + let eof = input.is_empty(); + let output = &mut target[(compressor.total_out() as usize)..]; + let flush = if eof { + FlushCompress::Finish + } else { + FlushCompress::None + }; + + match compressor + .compress(input, output, flush) + .map_err(move |err| WriteError::CompressFail(err))? + { + Status::Ok => {} + Status::BufError => return Err(WriteError::CompressBufError), + Status::StreamEnd => break, + } + } + + Ok(&mut target[..(compressor.total_out() as usize)]) +} |