use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher}; use crate::util::{get_sized_buf, VAR_INT_BUF_SIZE}; use crate::wrapper::{CraftIo, CraftWrapper}; use flate2::{DecompressError, FlushDecompress, Status}; use mcproto_rs::protocol::{Id, PacketDirection, RawPacket, State}; use mcproto_rs::types::VarInt; use mcproto_rs::{Deserialize, Deserialized}; use thiserror::Error; #[cfg(feature = "async")] use {async_trait::async_trait, futures::AsyncReadExt}; #[derive(Debug, Error)] pub enum ReadError { #[error("i/o failure during read")] IoFailure(#[from] std::io::Error), #[error("failed to read header VarInt")] PacketHeaderErr(#[from] mcproto_rs::DeserializeErr), #[error("failed to read packet")] PacketErr(#[from] mcproto_rs::protocol::PacketErr), #[error("failed to decompress packet")] DecompressFailed(#[from] DecompressErr), } #[derive(Debug, Error)] pub enum DecompressErr { #[error("buf error")] BufError, #[error("failure while decompressing")] Failure(#[from] DecompressError), } pub type ReadResult

= Result, ReadError>; #[cfg(feature = "async")] #[async_trait] pub trait CraftAsyncReader { async fn read_packet<'a, P>(&'a mut self) -> ReadResult<

>::Packet> where P: RawPacket<'a>, { deserialize_raw_packet(self.read_raw_packet::

().await) } async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult

where P: RawPacket<'a>; } pub trait CraftSyncReader { fn read_packet<'a, P>(&'a mut self) -> ReadResult<

>::Packet> where P: RawPacket<'a>, { deserialize_raw_packet(self.read_raw_packet::<'a, P>()) } fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult

where P: RawPacket<'a>; } pub struct CraftReader { inner: R, raw_buf: Option>, decompress_buf: Option>, compression_threshold: Option, state: State, direction: PacketDirection, encryption: Option, } impl CraftWrapper for CraftReader { fn into_inner(self) -> R { self.inner } } impl CraftIo for CraftReader { fn set_state(&mut self, next: State) { self.state = next; } fn set_compression_threshold(&mut self, threshold: Option) { self.compression_threshold = threshold; } fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> { setup_craft_cipher(&mut self.encryption, key, iv) } } macro_rules! rr_unwrap { ($result: expr) => { match $result { Ok(Some(r)) => r, Ok(None) => return Ok(None), Err(err) => return Err(err), } }; } macro_rules! check_unexpected_eof { ($result: expr) => { if let Err(err) = $result { if err.kind() == std::io::ErrorKind::UnexpectedEof { return Ok(None); } return Err(ReadError::IoFailure(err)); } }; } impl CraftSyncReader for CraftReader where R: std::io::Read, { fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult

where P: RawPacket<'a>, { let primary_packet_len = rr_unwrap!(self.read_one_varint_sync()).0 as usize; rr_unwrap!(self.read_n(primary_packet_len)); self.read_packet_in_buf::<'a, P>(primary_packet_len) } } #[cfg(feature = "async")] #[async_trait] impl CraftAsyncReader for CraftReader where R: futures::AsyncRead + Unpin + Sync + Send, { async fn read_raw_packet<'a, P>(&'a mut self) -> Result, ReadError> where P: RawPacket<'a>, { let primary_packet_len = rr_unwrap!(self.read_one_varint_async().await).0 as usize; rr_unwrap!(self.read_n_async(primary_packet_len).await); self.read_packet_in_buf::

(primary_packet_len) } } impl CraftReader where R: std::io::Read, { fn read_one_varint_sync(&mut self) -> ReadResult { deserialize_varint(rr_unwrap!(self.read_n(VAR_INT_BUF_SIZE))) } fn read_n(&mut self, n: usize) -> ReadResult<&mut [u8]> { let buf = get_sized_buf(&mut self.raw_buf, 0, n); check_unexpected_eof!(self.inner.read_exact(buf)); Ok(Some(buf)) } } #[cfg(feature = "async")] impl CraftReader where R: futures::io::AsyncRead + Unpin + Sync + Send, { async fn read_one_varint_async(&mut self) -> ReadResult { deserialize_varint(rr_unwrap!(self.read_n_async(VAR_INT_BUF_SIZE).await)) } async fn read_n_async(&mut self, n: usize) -> ReadResult<&mut [u8]> { let buf = get_sized_buf(&mut self.raw_buf, 0, n); check_unexpected_eof!(self.inner.read_exact(buf).await); Ok(Some(buf)) } } macro_rules! dsz_unwrap { ($bnam: expr, $k: ty) => { match <$k>::mc_deserialize($bnam) { Ok(Deserialized { value: val, data: rest, }) => (val, rest), Err(err) => { return Err(ReadError::PacketHeaderErr(err)); } }; }; } impl CraftReader { pub fn wrap(inner: R, direction: PacketDirection) -> Self { Self::wrap_with_state(inner, direction, State::Handshaking) } pub fn wrap_with_state(inner: R, direction: PacketDirection, state: State) -> Self { Self { inner, raw_buf: None, decompress_buf: None, compression_threshold: None, state, direction, encryption: None, } } fn read_packet_in_buf<'a, P>(&'a mut self, size: usize) -> ReadResult

where P: RawPacket<'a>, { // find data in buf let buf = &mut self.raw_buf.as_mut().expect("should exist right now")[..size]; // decrypt the packet if encryption is enabled if let Some(encryption) = self.encryption.as_mut() { encryption.decrypt(buf); } // try to get the packet body bytes... this boils down to: // * check if compression enabled, // * read data len (VarInt) which isn't compressed // * if data len is 0, then rest of packet is not compressed, remaining data is body // * otherwise, data len is decompressed length, so prepare a decompression buf and decompress from // the buffer into the decompression buffer, and return the slice of the decompression buffer // which contains this packet's data // * if compression not enabled, then the buf contains only the packet body bytes let packet_buf = if let Some(_) = self.compression_threshold { let (data_len, rest) = dsz_unwrap!(buf, VarInt); let data_len = data_len.0 as usize; if data_len == 0 { rest } else { decompress(rest, &mut self.decompress_buf, data_len)? } } else { buf }; let (raw_id, body_buf) = dsz_unwrap!(packet_buf, VarInt); let id = Id { id: raw_id.0, state: self.state.clone(), direction: self.direction.clone(), }; match P::create(id, body_buf) { Ok(raw) => Ok(Some(raw)), Err(err) => Err(ReadError::PacketErr(err)), } } } fn deserialize_raw_packet<'a, P>(raw: ReadResult

) -> ReadResult where P: RawPacket<'a>, { match raw { Ok(Some(raw)) => match raw.deserialize() { Ok(deserialized) => Ok(Some(deserialized)), Err(err) => Err(ReadError::PacketErr(err)), }, Ok(None) => Ok(None), Err(err) => Err(err), } } fn deserialize_varint(buf: &[u8]) -> ReadResult { match VarInt::mc_deserialize(buf) { Ok(v) => Ok(Some(v.value)), Err(err) => Err(ReadError::PacketHeaderErr(err)), } } fn decompress<'a>( src: &'a [u8], target: &'a mut Option>, decompressed_len: usize, ) -> Result<&'a mut [u8], ReadError> { let mut decompress = flate2::Decompress::new(true); let decompress_buf = get_sized_buf(target, 0, decompressed_len); loop { match decompress.decompress(src, decompress_buf, FlushDecompress::Finish) { Ok(Status::StreamEnd) => break, Ok(Status::Ok) => {} Ok(Status::BufError) => { return Err(ReadError::DecompressFailed(DecompressErr::BufError)) } Err(err) => return Err(ReadError::DecompressFailed(DecompressErr::Failure(err))), } } Ok(&mut decompress_buf[..(decompress.total_out() as usize)]) }