Fix buffer allocations in server pinging code (#5751)

This commit is contained in:
aecsocket
2026-04-04 18:22:40 +01:00
committed by GitHub
parent 4a7525d0a1
commit e5f600ddd7
2 changed files with 134 additions and 14 deletions

View File

@@ -31,6 +31,27 @@ pub enum ProtocolError {
Timeout(#[from] tokio::time::error::Elapsed),
}
const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
const MAX_PONG_PACKET_LENGTH: usize = 9;
/// Ensures the length of a packet as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
if length > max_length {
return Err(ProtocolError::InvalidPacketLength);
}
Ok(length)
}
/// State represents the desired next state of the
/// exchange.
///
@@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
}
async fn read_string(&mut self) -> Result<String, ProtocolError> {
let length = self.read_varint().await?;
let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
let mut buffer = vec![0; length];
self.read_exact(&mut buffer).await?;
@@ -157,6 +178,7 @@ pub trait PacketId {
/// to generically get a packet's expected ID.
pub trait ExpectedPacketId {
fn get_expected_packet_id() -> usize;
fn get_max_packet_length() -> usize;
}
/// AsyncReadFromBuffer is used to allow
@@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
&mut self,
) -> Result<T, ProtocolError> {
let length = self.read_varint().await?;
let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
if length == 0 {
return Err(ProtocolError::InvalidPacketLength);
@@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
});
}
let mut buffer = vec![0; length - 1];
let payload_length = length
.checked_sub(1)
.ok_or(ProtocolError::InvalidPacketLength)?;
let mut buffer = vec![0; payload_length];
self.read_exact(&mut buffer).await?;
T::read_from_buffer(buffer).await
@@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
fn get_expected_packet_id() -> usize {
0
}
fn get_max_packet_length() -> usize {
MAX_STATUS_RESPONSE_PACKET_LENGTH
}
}
#[async_trait]
@@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
fn get_expected_packet_id() -> usize {
1
}
fn get_max_packet_length() -> usize {
MAX_PONG_PACKET_LENGTH
}
}
#[async_trait]
@@ -573,4 +606,32 @@ mod tests {
let result = reader.read_varint().await;
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
}
#[tokio::test]
async fn test_oversized_string_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
.await
.unwrap();
let mut reader = Cursor::new(writer.into_inner());
let result = reader.read_string().await;
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}
#[tokio::test]
async fn test_oversized_packet_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
.await
.unwrap();
let mut reader = Cursor::new(writer.into_inner());
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}
}