aboutsummaryrefslogtreecommitdiff
path: root/src/reader.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/reader.rs')
-rw-r--r--src/reader.rs294
1 files changed, 294 insertions, 0 deletions
diff --git a/src/reader.rs b/src/reader.rs
new file mode 100644
index 0000000..68b9a9b
--- /dev/null
+++ b/src/reader.rs
@@ -0,0 +1,294 @@
+use crate::cfb8::{setup_craft_cipher, CipherError, CraftCipher};
+use crate::util::{get_sized_buf, VAR_INT_BUF_SIZE};
+use crate::wrapper::{CraftIo, CraftWrapper};
+use flate2::{DecompressError, FlushDecompress, Status};
+use mcproto_rs::protocol::{Id, PacketDirection, RawPacket, State};
+use mcproto_rs::types::VarInt;
+use mcproto_rs::{Deserialize, Deserialized};
+use thiserror::Error;
+
+#[cfg(feature = "async")]
+use {async_trait::async_trait, futures::AsyncReadExt};
+
+#[derive(Debug, Error)]
+pub enum ReadError {
+ #[error("i/o failure during read")]
+ IoFailure(#[from] std::io::Error),
+ #[error("failed to read header VarInt")]
+ PacketHeaderErr(#[from] mcproto_rs::DeserializeErr),
+ #[error("failed to read packet")]
+ PacketErr(#[from] mcproto_rs::protocol::PacketErr),
+ #[error("failed to decompress packet")]
+ DecompressFailed(#[from] DecompressErr),
+}
+
+#[derive(Debug, Error)]
+pub enum DecompressErr {
+ #[error("buf error")]
+ BufError,
+ #[error("failure while decompressing")]
+ Failure(#[from] DecompressError),
+}
+
+pub type ReadResult<P> = Result<Option<P>, ReadError>;
+
+#[cfg(feature = "async")]
+#[async_trait]
+pub trait CraftAsyncReader {
+ async fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ deserialize_raw_packet(self.read_raw_packet::<P>().await)
+ }
+
+ async fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>;
+}
+
+pub trait CraftSyncReader {
+ fn read_packet<'a, P>(&'a mut self) -> ReadResult<<P as RawPacket<'a>>::Packet>
+ where
+ P: RawPacket<'a>,
+ {
+ deserialize_raw_packet(self.read_raw_packet::<'a, P>())
+ }
+
+ fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>;
+}
+
+pub struct CraftReader<R> {
+ inner: R,
+ raw_buf: Option<Vec<u8>>,
+ decompress_buf: Option<Vec<u8>>,
+ compression_threshold: Option<i32>,
+ state: State,
+ direction: PacketDirection,
+ encryption: Option<CraftCipher>,
+}
+
+impl<R> CraftWrapper<R> for CraftReader<R> {
+ fn into_inner(self) -> R {
+ self.inner
+ }
+}
+
+impl<R> CraftIo for CraftReader<R> {
+ fn set_state(&mut self, next: State) {
+ self.state = next;
+ }
+
+ fn set_compression_threshold(&mut self, threshold: Option<i32>) {
+ self.compression_threshold = threshold;
+ }
+
+ fn enable_encryption(&mut self, key: &[u8], iv: &[u8]) -> Result<(), CipherError> {
+ setup_craft_cipher(&mut self.encryption, key, iv)
+ }
+}
+
+macro_rules! rr_unwrap {
+ ($result: expr) => {
+ match $result {
+ Ok(Some(r)) => r,
+ Ok(None) => return Ok(None),
+ Err(err) => return Err(err),
+ }
+ };
+}
+
+macro_rules! check_unexpected_eof {
+ ($result: expr) => {
+ if let Err(err) = $result {
+ if err.kind() == std::io::ErrorKind::UnexpectedEof {
+ return Ok(None);
+ }
+
+ return Err(ReadError::IoFailure(err));
+ }
+ };
+}
+
+impl<R> CraftSyncReader for CraftReader<R>
+where
+ R: std::io::Read,
+{
+ fn read_raw_packet<'a, P>(&'a mut self) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ let primary_packet_len = rr_unwrap!(self.read_one_varint_sync()).0 as usize;
+ rr_unwrap!(self.read_n(primary_packet_len));
+ self.read_packet_in_buf::<'a, P>(primary_packet_len)
+ }
+}
+
+#[cfg(feature = "async")]
+#[async_trait]
+impl<R> CraftAsyncReader for CraftReader<R>
+where
+ R: futures::AsyncRead + Unpin + Sync + Send,
+{
+ async fn read_raw_packet<'a, P>(&'a mut self) -> Result<Option<P>, ReadError>
+ where
+ P: RawPacket<'a>,
+ {
+ let primary_packet_len = rr_unwrap!(self.read_one_varint_async().await).0 as usize;
+ rr_unwrap!(self.read_n_async(primary_packet_len).await);
+ self.read_packet_in_buf::<P>(primary_packet_len)
+ }
+}
+
+impl<R> CraftReader<R>
+where
+ R: std::io::Read,
+{
+ fn read_one_varint_sync(&mut self) -> ReadResult<VarInt> {
+ deserialize_varint(rr_unwrap!(self.read_n(VAR_INT_BUF_SIZE)))
+ }
+
+ fn read_n(&mut self, n: usize) -> ReadResult<&mut [u8]> {
+ let buf = get_sized_buf(&mut self.raw_buf, 0, n);
+ check_unexpected_eof!(self.inner.read_exact(buf));
+ Ok(Some(buf))
+ }
+}
+
+#[cfg(feature = "async")]
+impl<R> CraftReader<R>
+where
+ R: futures::io::AsyncRead + Unpin + Sync + Send,
+{
+ async fn read_one_varint_async(&mut self) -> ReadResult<VarInt> {
+ deserialize_varint(rr_unwrap!(self.read_n_async(VAR_INT_BUF_SIZE).await))
+ }
+
+ async fn read_n_async(&mut self, n: usize) -> ReadResult<&mut [u8]> {
+ let buf = get_sized_buf(&mut self.raw_buf, 0, n);
+ check_unexpected_eof!(self.inner.read_exact(buf).await);
+ Ok(Some(buf))
+ }
+}
+
+macro_rules! dsz_unwrap {
+ ($bnam: expr, $k: ty) => {
+ match <$k>::mc_deserialize($bnam) {
+ Ok(Deserialized {
+ value: val,
+ data: rest,
+ }) => (val, rest),
+ Err(err) => {
+ return Err(ReadError::PacketHeaderErr(err));
+ }
+ };
+ };
+}
+
+impl<R> CraftReader<R> {
+ pub fn wrap(inner: R, direction: PacketDirection) -> Self {
+ Self::wrap_with_state(inner, direction, State::Handshaking)
+ }
+
+ pub fn wrap_with_state(inner: R, direction: PacketDirection, state: State) -> Self {
+ Self {
+ inner,
+ raw_buf: None,
+ decompress_buf: None,
+ compression_threshold: None,
+ state,
+ direction,
+ encryption: None,
+ }
+ }
+
+ fn read_packet_in_buf<'a, P>(&'a mut self, size: usize) -> ReadResult<P>
+ where
+ P: RawPacket<'a>,
+ {
+ // find data in buf
+ let buf = &mut self.raw_buf.as_mut().expect("should exist right now")[..size];
+ // decrypt the packet if encryption is enabled
+ if let Some(encryption) = self.encryption.as_mut() {
+ encryption.decrypt(buf);
+ }
+
+ // try to get the packet body bytes... this boils down to:
+ // * check if compression enabled,
+ // * read data len (VarInt) which isn't compressed
+ // * if data len is 0, then rest of packet is not compressed, remaining data is body
+ // * otherwise, data len is decompressed length, so prepare a decompression buf and decompress from
+ // the buffer into the decompression buffer, and return the slice of the decompression buffer
+ // which contains this packet's data
+ // * if compression not enabled, then the buf contains only the packet body bytes
+
+ let packet_buf = if let Some(_) = self.compression_threshold {
+ let (data_len, rest) = dsz_unwrap!(buf, VarInt);
+ let data_len = data_len.0 as usize;
+ if data_len == 0 {
+ rest
+ } else {
+ decompress(rest, &mut self.decompress_buf, data_len)?
+ }
+ } else {
+ buf
+ };
+
+ let (raw_id, body_buf) = dsz_unwrap!(packet_buf, VarInt);
+
+ let id = Id {
+ id: raw_id.0,
+ state: self.state.clone(),
+ direction: self.direction.clone(),
+ };
+
+ match P::create(id, body_buf) {
+ Ok(raw) => Ok(Some(raw)),
+ Err(err) => Err(ReadError::PacketErr(err)),
+ }
+ }
+}
+
+fn deserialize_raw_packet<'a, P>(raw: ReadResult<P>) -> ReadResult<P::Packet>
+where
+ P: RawPacket<'a>,
+{
+ match raw {
+ Ok(Some(raw)) => match raw.deserialize() {
+ Ok(deserialized) => Ok(Some(deserialized)),
+ Err(err) => Err(ReadError::PacketErr(err)),
+ },
+ Ok(None) => Ok(None),
+ Err(err) => Err(err),
+ }
+}
+
+fn deserialize_varint(buf: &[u8]) -> ReadResult<VarInt> {
+ match VarInt::mc_deserialize(buf) {
+ Ok(v) => Ok(Some(v.value)),
+ Err(err) => Err(ReadError::PacketHeaderErr(err)),
+ }
+}
+
+fn decompress<'a>(
+ src: &'a [u8],
+ target: &'a mut Option<Vec<u8>>,
+ decompressed_len: usize,
+) -> Result<&'a mut [u8], ReadError> {
+ let mut decompress = flate2::Decompress::new(true);
+ let decompress_buf = get_sized_buf(target, 0, decompressed_len);
+ loop {
+ match decompress.decompress(src, decompress_buf, FlushDecompress::Finish) {
+ Ok(Status::StreamEnd) => break,
+ Ok(Status::Ok) => {}
+ Ok(Status::BufError) => {
+ return Err(ReadError::DecompressFailed(DecompressErr::BufError))
+ }
+ Err(err) => return Err(ReadError::DecompressFailed(DecompressErr::Failure(err))),
+ }
+ }
+
+ Ok(&mut decompress_buf[..(decompress.total_out() as usize)])
+}