Fix buffer allocations in server pinging code (#5751)

This commit is contained in:
aecsocket
2026-04-04 18:22:40 +01:00
committed by GitHub
parent 4a7525d0a1
commit e5f600ddd7
2 changed files with 134 additions and 14 deletions

View File

@@ -8,6 +8,33 @@ use tokio::net::ToSocketAddrs;
use tokio::select; use tokio::select;
use url::Url; use url::Url;
const MAX_MINECRAFT_STATUS_STRING_LENGTH: usize = 32_767;
const MAX_MODERN_STATUS_PACKET_LENGTH: usize =
MAX_MINECRAFT_STATUS_STRING_LENGTH + 4;
const MAX_LEGACY_STATUS_UTF16_LENGTH: usize =
MAX_MINECRAFT_STATUS_STRING_LENGTH;
/// Ensures the length of a packet as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_length(
length: usize,
max_length: usize,
context: &'static str,
) -> Result<usize> {
if length > max_length {
return Err(ErrorKind::InputError(context.to_string()).into());
}
Ok(length)
}
#[derive(Deserialize, Serialize, Debug, Clone)] #[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct ServerStatus { pub struct ServerStatus {
@@ -128,13 +155,11 @@ mod modern {
stream.write_all(&[0x01, 0x00]).await?; stream.write_all(&[0x01, 0x00]).await?;
stream.flush().await?; stream.flush().await?;
let packet_length = varint::read(stream).await?; let packet_length = cap_varint_length(
if packet_length < 0 { varint::read(stream).await?,
return Err(ErrorKind::InputError( super::MAX_MODERN_STATUS_PACKET_LENGTH,
"Invalid status response packet length".to_string(), "invalid status response packet length",
) )?;
.into());
}
let mut packet_stream = stream.take(packet_length as u64); let mut packet_stream = stream.take(packet_length as u64);
let packet_id = varint::read(&mut packet_stream).await?; let packet_id = varint::read(&mut packet_stream).await?;
@@ -144,8 +169,12 @@ mod modern {
) )
.into()); .into());
} }
let response_length = varint::read(&mut packet_stream).await?; let response_length = cap_varint_length(
let mut json_response = vec![0_u8; response_length as usize]; varint::read(&mut packet_stream).await?,
super::MAX_MINECRAFT_STATUS_STRING_LENGTH,
"invalid status response length",
)?;
let mut json_response = vec![0_u8; response_length];
packet_stream.read_exact(&mut json_response).await?; packet_stream.read_exact(&mut json_response).await?;
if packet_stream.limit() > 0 { if packet_stream.limit() > 0 {
@@ -155,6 +184,27 @@ mod modern {
Ok(serde_json::from_slice(&json_response)?) Ok(serde_json::from_slice(&json_response)?)
} }
/// Ensures the length of a varint as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_varint_length(
length: i32,
max_length: usize,
context: &'static str,
) -> crate::Result<usize> {
if length < 0 {
return Err(ErrorKind::InputError(context.to_string()).into());
}
super::cap_length(length as usize, max_length, context)
}
async fn ping(stream: &mut TcpStream) -> crate::Result<i64> { async fn ping(stream: &mut TcpStream) -> crate::Result<i64> {
let ping_magic = chrono::Utc::now().timestamp_millis(); let ping_magic = chrono::Utc::now().timestamp_millis();
@@ -275,8 +325,17 @@ mod legacy {
))); )));
} }
let data_length = stream.read_u16().await?; let data_length = super::cap_length(
let mut data = vec![0u8; data_length as usize * 2]; stream.read_u16().await? as usize,
super::MAX_LEGACY_STATUS_UTF16_LENGTH,
"invalid legacy status response length",
)?;
let data_byte_length = data_length.checked_mul(2).ok_or_else(|| {
ErrorKind::InputError(
"invalid legacy status response length".to_string(),
)
})?;
let mut data = vec![0u8; data_byte_length];
stream.read_exact(&mut data).await?; stream.read_exact(&mut data).await?;
drop(stream); drop(stream);

View File

@@ -31,6 +31,27 @@ pub enum ProtocolError {
Timeout(#[from] tokio::time::error::Elapsed), Timeout(#[from] tokio::time::error::Elapsed),
} }
const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
const MAX_PONG_PACKET_LENGTH: usize = 9;
/// Ensures the length of a packet as stated by a server is not longer than a
/// hard-coded limit.
///
/// For example, if we ping a server that says its status packet is 2 billion
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
/// will OOM our machine.
///
/// Implemented as a function so that you can easily find callsites and see
/// where we accept unvalidated input from servers.
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
if length > max_length {
return Err(ProtocolError::InvalidPacketLength);
}
Ok(length)
}
/// State represents the desired next state of the /// State represents the desired next state of the
/// exchange. /// exchange.
/// ///
@@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
} }
async fn read_string(&mut self) -> Result<String, ProtocolError> { async fn read_string(&mut self) -> Result<String, ProtocolError> {
let length = self.read_varint().await?; let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
let mut buffer = vec![0; length]; let mut buffer = vec![0; length];
self.read_exact(&mut buffer).await?; self.read_exact(&mut buffer).await?;
@@ -157,6 +178,7 @@ pub trait PacketId {
/// to generically get a packet's expected ID. /// to generically get a packet's expected ID.
pub trait ExpectedPacketId { pub trait ExpectedPacketId {
fn get_expected_packet_id() -> usize; fn get_expected_packet_id() -> usize;
fn get_max_packet_length() -> usize;
} }
/// AsyncReadFromBuffer is used to allow /// AsyncReadFromBuffer is used to allow
@@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>( async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
&mut self, &mut self,
) -> Result<T, ProtocolError> { ) -> Result<T, ProtocolError> {
let length = self.read_varint().await?; let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
if length == 0 { if length == 0 {
return Err(ProtocolError::InvalidPacketLength); return Err(ProtocolError::InvalidPacketLength);
@@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
}); });
} }
let mut buffer = vec![0; length - 1]; let payload_length = length
.checked_sub(1)
.ok_or(ProtocolError::InvalidPacketLength)?;
let mut buffer = vec![0; payload_length];
self.read_exact(&mut buffer).await?; self.read_exact(&mut buffer).await?;
T::read_from_buffer(buffer).await T::read_from_buffer(buffer).await
@@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
fn get_expected_packet_id() -> usize { fn get_expected_packet_id() -> usize {
0 0
} }
fn get_max_packet_length() -> usize {
MAX_STATUS_RESPONSE_PACKET_LENGTH
}
} }
#[async_trait] #[async_trait]
@@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
fn get_expected_packet_id() -> usize { fn get_expected_packet_id() -> usize {
1 1
} }
fn get_max_packet_length() -> usize {
MAX_PONG_PACKET_LENGTH
}
} }
#[async_trait] #[async_trait]
@@ -573,4 +606,32 @@ mod tests {
let result = reader.read_varint().await; let result = reader.read_varint().await;
assert!(matches!(result, Err(ProtocolError::InvalidVarInt))); assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
} }
#[tokio::test]
async fn test_oversized_string_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
.await
.unwrap();
let mut reader = Cursor::new(writer.into_inner());
let result = reader.read_string().await;
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}
#[tokio::test]
async fn test_oversized_packet_length_is_rejected() {
let mut writer = Cursor::new(Vec::new());
writer
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
.await
.unwrap();
let mut reader = Cursor::new(writer.into_inner());
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
}
} }