//! This module defines various methods to read and //! write packets in Minecraft's //! [ServerListPing](https://wiki.vg/Server_List_Ping) //! protocol. use std::io::Cursor; use std::time::Duration; use async_trait::async_trait; use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; #[derive(Error, Debug)] pub enum ProtocolError { #[error("error reading or writing data")] Io(#[from] std::io::Error), #[error("invalid packet length")] InvalidPacketLength, #[error("invalid varint data")] InvalidVarInt, #[error("invalid packet (expected ID {expected:?}, actual ID {actual:?})")] InvalidPacketId { expected: usize, actual: usize }, #[error("invalid ServerListPing response body (invalid UTF-8)")] InvalidResponseBody, #[error("connection timed out")] Timeout(#[from] tokio::time::error::Elapsed), } const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767; const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771; const MAX_PONG_PACKET_LENGTH: usize = 9; /// Ensures the length of a packet as stated by a server is not longer than a /// hard-coded limit. /// /// For example, if we ping a server that says its status packet is 2 billion /// bytes long, we don't try to allocate a 2 billion byte buffer, since that /// will OOM our machine. /// /// Implemented as a function so that you can easily find callsites and see /// where we accept unvalidated input from servers. fn cap_length(length: usize, max_length: usize) -> Result { if length > max_length { return Err(ProtocolError::InvalidPacketLength); } Ok(length) } /// State represents the desired next state of the /// exchange. /// /// It's a bit silly now as there's only /// one entry, but technically there is more than /// one type that can be sent here. #[derive(Clone, Copy)] pub enum State { Status, } impl From for usize { fn from(state: State) -> Self { match state { State::Status => 1, } } } /// RawPacket is the underlying wrapper of data that /// gets read from and written to the socket. 
///
/// Typically, the flow looks like this:
/// 1. Construct a specific packet (HandshakePacket
///    for example).
/// 2. Write that packet's contents to a byte buffer.
/// 3. Construct a RawPacket using that byte buffer.
/// 4. Write the RawPacket to the socket.
struct RawPacket {
    id: usize,
    data: Box<[u8]>,
}

impl RawPacket {
    /// Pairs a packet ID with its already-serialized payload.
    fn new(id: usize, data: Box<[u8]>) -> Self {
        RawPacket { id, data }
    }
}

/// AsyncWireReadExt adds varint and varint-backed
/// string support to things that implement AsyncRead.
#[async_trait]
pub trait AsyncWireReadExt {
    /// Reads a Minecraft VarInt (7 bits per byte, least significant group
    /// first, high bit set while more bytes follow).
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::InvalidVarInt`] if the fifth byte still has its
    ///   continuation bit set.
    /// - [`ProtocolError::Io`] if the stream fails or ends mid-VarInt.
    async fn read_varint(&mut self) -> Result<usize, ProtocolError>;

    /// Reads a VarInt byte length followed by that many UTF-8 bytes.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::InvalidPacketLength`] if the announced length
    ///   exceeds `MAX_MINECRAFT_STRING_LENGTH`.
    /// - [`ProtocolError::InvalidResponseBody`] if the bytes are not UTF-8.
    /// - Any error from [`AsyncWireReadExt::read_varint`] or the stream.
    async fn read_string(&mut self) -> Result<String, ProtocolError>;
}

#[async_trait]
impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
    async fn read_varint(&mut self) -> Result<usize, ProtocolError> {
        // A 32-bit VarInt needs at most five 7-bit groups.
        const MAX_VARINT_BYTES: usize = 5;

        let mut result: usize = 0;
        for position in 0..MAX_VARINT_BYTES {
            let byte = self.read_u8().await?;
            // Shift amounts stay at or below 28, so this can never overflow,
            // not even with a 32-bit `usize`.
            result |= usize::from(byte & 0b0111_1111) << (7 * position);
            if byte & 0b1000_0000 == 0 {
                return Ok(result);
            }
        }
        // The last permitted byte still announced a continuation. Reject now
        // instead of consuming (and shifting in) a sixth byte.
        Err(ProtocolError::InvalidVarInt)
    }

    async fn read_string(&mut self) -> Result<String, ProtocolError> {
        // Validate the untrusted length before allocating the buffer.
        let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
        let mut buffer = vec![0; length];
        self.read_exact(&mut buffer).await?;
        String::from_utf8(buffer).map_err(|_| ProtocolError::InvalidResponseBody)
    }
}

/// AsyncWireWriteExt adds varint and varint-backed
/// string support to things that implement AsyncWrite.
///
/// Integers are written as Minecraft VarInts (low 32 bits only) and strings
/// as a VarInt byte length followed by the UTF-8 bytes.
#[async_trait]
pub trait AsyncWireWriteExt {
    /// Writes the low 32 bits of `int` as a VarInt of one to five bytes.
    async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError>;
    /// Writes `string` prefixed with its UTF-8 byte length as a VarInt.
    async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError>;
}

#[async_trait]
impl<W: AsyncWrite + Unpin + Send + Sync> AsyncWireWriteExt for W {
    async fn write_varint(&mut self, int: usize) -> Result<(), ProtocolError> {
        // Only 32 bits are representable on the wire; the cast intentionally
        // keeps just the low 32 bits of `int`.
        let mut remaining = int as u32;
        // ceil(32 / 7) = 5 bytes is the longest possible encoding.
        let mut encoded = [0u8; 5];
        let mut len = 0;

        // Emit 7-bit groups, least significant first, flagging every byte
        // that has further groups after it.
        while remaining >= 0x80 {
            encoded[len] = ((remaining & 0x7F) as u8) | 0x80;
            remaining >>= 7;
            len += 1;
        }
        encoded[len] = remaining as u8;
        len += 1;

        self.write_all(&encoded[..len]).await?;
        Ok(())
    }

    async fn write_string(&mut self, string: &str) -> Result<(), ProtocolError> {
        let bytes = string.as_bytes();
        self.write_varint(bytes.len()).await?;
        self.write_all(bytes).await?;
        Ok(())
    }
}

/// Implemented by outgoing packets so that [`AsyncWriteRawPacket`] can ask
/// any of them for the ID to put in front of the payload.
pub trait PacketId {
    fn get_packet_id(&self) -> usize;
}

/// Implemented by incoming packets so that [`AsyncReadRawPacket`] knows which
/// ID to expect and how large a frame it may accept.
pub trait ExpectedPacketId {
    fn get_expected_packet_id() -> usize;
    fn get_max_packet_length() -> usize;
}

/// Implemented by incoming packets so that [`AsyncReadRawPacket`] can turn
/// the payload bytes (everything after the packet ID) into a concrete packet.
#[async_trait]
pub trait AsyncReadFromBuffer: Sized {
    async fn read_from_buffer(buffer: Vec<u8>) -> Result<Self, ProtocolError>;
}

/// Implemented by outgoing packets so that [`AsyncWriteRawPacket`] can obtain
/// their serialized payload (excluding the packet ID and length prefix).
#[async_trait]
pub trait AsyncWriteToBuffer {
    async fn write_to_buffer(&self) -> Result<Vec<u8>, ProtocolError>;
}

/// AsyncReadRawPacket is the core piece of
/// the read side of the protocol. It allows
/// the user to construct a specific packet
/// from something that implements AsyncRead.
///
/// A frame is a VarInt length followed by that many bytes, which hold the
/// VarInt packet ID and then the packet's payload.
#[async_trait]
pub trait AsyncReadRawPacket {
    /// Reads one frame and decodes it as a `T`.
    ///
    /// # Errors
    ///
    /// - [`ProtocolError::InvalidPacketLength`] if the announced length is
    ///   zero, larger than `T::get_max_packet_length()`, or too short to hold
    ///   the packet ID.
    /// - [`ProtocolError::InvalidPacketId`] if the ID differs from
    ///   `T::get_expected_packet_id()`.
    /// - Any error from the stream or from `T::read_from_buffer`.
    async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
    ) -> Result<T, ProtocolError>;

    /// Same as [`AsyncReadRawPacket::read_packet`], but fails with
    /// [`ProtocolError::Timeout`] if it does not complete within `timeout`.
    async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
        timeout: Duration,
    ) -> Result<T, ProtocolError>;
}

#[async_trait]
impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
    async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
    ) -> Result<T, ProtocolError> {
        // Validate the untrusted length before allocating anything for it.
        let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
        if length == 0 {
            return Err(ProtocolError::InvalidPacketLength);
        }

        // Read the whole frame up front. The packet ID is a VarInt whose size
        // is not fixed (IDs >= 128, or non-canonical encodings, use several
        // bytes), so the payload size is only known after counting the bytes
        // the ID actually occupied. Assuming one byte would desynchronise the
        // stream.
        let mut frame = vec![0; length];
        self.read_exact(&mut frame).await?;
        let mut frame_reader = Cursor::new(frame);

        let packet_id = frame_reader
            .read_varint()
            .await
            .map_err(|err| match err {
                // The ID ran past the end of the frame: the length was bogus.
                ProtocolError::Io(io_err)
                    if io_err.kind() == std::io::ErrorKind::UnexpectedEof =>
                {
                    ProtocolError::InvalidPacketLength
                }
                other => other,
            })?;
        let expected_packet_id = T::get_expected_packet_id();
        if packet_id != expected_packet_id {
            return Err(ProtocolError::InvalidPacketId {
                expected: expected_packet_id,
                actual: packet_id,
            });
        }

        // The cursor position is at most `length` (a usize), so the cast is
        // lossless. Everything after the ID bytes is the payload.
        let id_length = frame_reader.position() as usize;
        let payload = frame_reader.into_inner().split_off(id_length);
        T::read_from_buffer(payload).await
    }

    async fn read_packet_with_timeout<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
        &mut self,
        timeout: Duration,
    ) -> Result<T, ProtocolError> {
        // The outer `?` turns an elapsed timer into `ProtocolError::Timeout`.
        tokio::time::timeout(timeout, self.read_packet::<T>()).await?
    }
}

/// AsyncWriteRawPacket is the core piece of
/// the write side of the protocol. It allows
/// the user to write a specific packet to
/// something that implements AsyncWrite.
#[async_trait] pub trait AsyncWriteRawPacket { async fn write_packet( &mut self, packet: T, ) -> Result<(), ProtocolError>; async fn write_packet_with_timeout( &mut self, packet: T, timeout: Duration, ) -> Result<(), ProtocolError>; } #[async_trait] impl AsyncWriteRawPacket for W { async fn write_packet( &mut self, packet: T, ) -> Result<(), ProtocolError> { let packet_buffer = packet.write_to_buffer().await?; let raw_packet = RawPacket::new(packet.get_packet_id(), packet_buffer.into_boxed_slice()); let mut buffer: Cursor> = Cursor::new(Vec::new()); buffer.write_varint(raw_packet.id).await?; buffer.write_all(&raw_packet.data).await?; let inner = buffer.into_inner(); self.write_varint(inner.len()).await?; self.write_all(&inner).await?; Ok(()) } async fn write_packet_with_timeout( &mut self, packet: T, timeout: Duration, ) -> Result<(), ProtocolError> { tokio::time::timeout(timeout, self.write_packet(packet)).await? } } /// HandshakePacket is the first of two packets /// to be sent during a status check for /// ServerListPing. pub struct HandshakePacket { pub packet_id: usize, pub protocol_version: usize, pub server_address: String, pub server_port: u16, pub next_state: State, } impl HandshakePacket { pub fn new(protocol_version: usize, server_address: String, server_port: u16) -> Self { Self { packet_id: 0, protocol_version, server_address, server_port, next_state: State::Status, } } } #[async_trait] impl AsyncWriteToBuffer for HandshakePacket { async fn write_to_buffer(&self) -> Result, ProtocolError> { let mut buffer = Cursor::new(Vec::::new()); buffer.write_varint(self.protocol_version).await?; buffer.write_string(&self.server_address).await?; buffer.write_u16(self.server_port).await?; buffer.write_varint(self.next_state.into()).await?; Ok(buffer.into_inner()) } } impl PacketId for HandshakePacket { fn get_packet_id(&self) -> usize { self.packet_id } } /// RequestPacket is the second of two packets /// to be sent during a status check for /// ServerListPing. 
pub struct RequestPacket { pub packet_id: usize, } impl RequestPacket { pub fn new() -> Self { Self { packet_id: 0 } } } #[async_trait] impl AsyncWriteToBuffer for RequestPacket { async fn write_to_buffer(&self) -> Result, ProtocolError> { Ok(Vec::new()) } } impl PacketId for RequestPacket { fn get_packet_id(&self) -> usize { self.packet_id } } /// ResponsePacket is the response from the /// server to a status check for /// ServerListPing. pub struct ResponsePacket { #[allow(dead_code)] pub packet_id: usize, pub body: String, } impl ExpectedPacketId for ResponsePacket { fn get_expected_packet_id() -> usize { 0 } fn get_max_packet_length() -> usize { MAX_STATUS_RESPONSE_PACKET_LENGTH } } #[async_trait] impl AsyncReadFromBuffer for ResponsePacket { async fn read_from_buffer(buffer: Vec) -> Result { let mut reader = Cursor::new(buffer); let body = reader.read_string().await?; Ok(ResponsePacket { packet_id: 0, body }) } } pub struct PingPacket { pub packet_id: usize, pub payload: u64, } impl PingPacket { pub fn new(payload: u64) -> Self { Self { packet_id: 1, payload, } } } #[async_trait] impl AsyncWriteToBuffer for PingPacket { async fn write_to_buffer(&self) -> Result, ProtocolError> { let mut buffer = Cursor::new(Vec::::new()); buffer.write_u64(self.payload).await?; Ok(buffer.into_inner()) } } impl PacketId for PingPacket { fn get_packet_id(&self) -> usize { self.packet_id } } pub struct PongPacket { #[allow(dead_code)] pub packet_id: usize, pub payload: u64, } impl ExpectedPacketId for PongPacket { fn get_expected_packet_id() -> usize { 1 } fn get_max_packet_length() -> usize { MAX_PONG_PACKET_LENGTH } } #[async_trait] impl AsyncReadFromBuffer for PongPacket { async fn read_from_buffer(buffer: Vec) -> Result { let mut reader = Cursor::new(buffer); let payload = reader.read_u64().await?; Ok(PongPacket { packet_id: 0, payload, }) } } #[cfg(test)] mod tests { use super::*; use std::io::Cursor; #[tokio::test] async fn test_varint_roundtrip() { let test_cases = vec![ 
0usize, 1, 127, 128, 255, 256, 16383, 16384, 2097151, 2097152, 268435455, ]; for value in test_cases { let mut buffer = Cursor::new(Vec::new()); buffer.write_varint(value).await.unwrap(); let mut reader = Cursor::new(buffer.into_inner()); let result = reader.read_varint().await.unwrap(); assert_eq!(value, result, "Varint roundtrip failed for {}", value); } } #[tokio::test] async fn test_varint_encoding() { // Test specific known encodings let cases = vec![ (0usize, vec![0x00]), (1, vec![0x01]), (127, vec![0x7f]), (128, vec![0x80, 0x01]), (255, vec![0xff, 0x01]), (25565, vec![0xdd, 0xc7, 0x01]), (2097151, vec![0xff, 0xff, 0x7f]), ]; for (value, expected) in cases { let mut buffer = Cursor::new(Vec::new()); buffer.write_varint(value).await.unwrap(); assert_eq!( buffer.into_inner(), expected, "Varint encoding failed for {}", value ); } } #[tokio::test] async fn test_string_roundtrip() { let test_cases = vec![ "", "hello", "localhost", "mc.example.com", "こんにちは", // Unicode ]; for s in test_cases { let mut buffer = Cursor::new(Vec::new()); buffer.write_string(s).await.unwrap(); let mut reader = Cursor::new(buffer.into_inner()); let result = reader.read_string().await.unwrap(); assert_eq!(s, result, "String roundtrip failed for {:?}", s); } } #[tokio::test] async fn test_handshake_packet_serialization() { let packet = HandshakePacket::new(578, "localhost".to_string(), 25565); let buffer = packet.write_to_buffer().await.unwrap(); // Verify the buffer contains expected data let mut reader = Cursor::new(buffer); // Protocol version (578 as varint) let protocol = reader.read_varint().await.unwrap(); assert_eq!(protocol, 578); // Server address let address = reader.read_string().await.unwrap(); assert_eq!(address, "localhost"); // Server port (big-endian u16) let port = reader.read_u16().await.unwrap(); assert_eq!(port, 25565); // Next state (1 for status) let state = reader.read_varint().await.unwrap(); assert_eq!(state, 1); } #[tokio::test] async fn 
test_request_packet_serialization() { let packet = RequestPacket::new(); let buffer = packet.write_to_buffer().await.unwrap(); // Request packet has no data assert!(buffer.is_empty()); } #[tokio::test] async fn test_ping_packet_serialization() { let packet = PingPacket::new(12345678); let buffer = packet.write_to_buffer().await.unwrap(); // Ping packet contains a u64 payload (8 bytes, big-endian) assert_eq!(buffer.len(), 8); let mut reader = Cursor::new(buffer); let payload = reader.read_u64().await.unwrap(); assert_eq!(payload, 12345678); } #[tokio::test] async fn test_response_packet_deserialization() { // Create a buffer with a JSON string let json = r#"{"version":{"name":"1.20.4","protocol":765}}"#; let mut buffer = Cursor::new(Vec::new()); buffer.write_string(json).await.unwrap(); let packet = ResponsePacket::read_from_buffer(buffer.into_inner()) .await .unwrap(); assert_eq!(packet.body, json); } #[tokio::test] async fn test_pong_packet_deserialization() { let payload: u64 = 987654321; let buffer = payload.to_be_bytes().to_vec(); let packet = PongPacket::read_from_buffer(buffer).await.unwrap(); assert_eq!(packet.payload, payload); } #[tokio::test] async fn test_invalid_varint() { // A varint with more than 5 continuation bytes is invalid let invalid = vec![0x80, 0x80, 0x80, 0x80, 0x80, 0x80]; let mut reader = Cursor::new(invalid); let result = reader.read_varint().await; assert!(matches!(result, Err(ProtocolError::InvalidVarInt))); } #[tokio::test] async fn test_oversized_string_length_is_rejected() { let mut writer = Cursor::new(Vec::new()); writer .write_varint(MAX_MINECRAFT_STRING_LENGTH + 1) .await .unwrap(); let mut reader = Cursor::new(writer.into_inner()); let result = reader.read_string().await; assert!(matches!(result, Err(ProtocolError::InvalidPacketLength))); } #[tokio::test] async fn test_oversized_packet_length_is_rejected() { let mut writer = Cursor::new(Vec::new()); writer .write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1) .await .unwrap(); 
let mut reader = Cursor::new(writer.into_inner()); let result: Result = reader.read_packet().await; assert!(matches!(result, Err(ProtocolError::InvalidPacketLength))); } }