Fix buffer allocations in server pinging code (#5751)
This commit is contained in:
@@ -8,6 +8,33 @@ use tokio::net::ToSocketAddrs;
|
||||
use tokio::select;
|
||||
use url::Url;
|
||||
|
||||
const MAX_MINECRAFT_STATUS_STRING_LENGTH: usize = 32_767;
|
||||
const MAX_MODERN_STATUS_PACKET_LENGTH: usize =
|
||||
MAX_MINECRAFT_STATUS_STRING_LENGTH + 4;
|
||||
const MAX_LEGACY_STATUS_UTF16_LENGTH: usize =
|
||||
MAX_MINECRAFT_STATUS_STRING_LENGTH;
|
||||
|
||||
/// Ensures the length of a packet as stated by a server is not longer than a
|
||||
/// hard-coded limit.
|
||||
///
|
||||
/// For example, if we ping a server that says its status packet is 2 billion
|
||||
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||
/// will OOM our machine.
|
||||
///
|
||||
/// Implemented as a function so that you can easily find callsites and see
|
||||
/// where we accept unvalidated input from servers.
|
||||
fn cap_length(
|
||||
length: usize,
|
||||
max_length: usize,
|
||||
context: &'static str,
|
||||
) -> Result<usize> {
|
||||
if length > max_length {
|
||||
return Err(ErrorKind::InputError(context.to_string()).into());
|
||||
}
|
||||
|
||||
Ok(length)
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ServerStatus {
|
||||
@@ -128,13 +155,11 @@ mod modern {
|
||||
stream.write_all(&[0x01, 0x00]).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let packet_length = varint::read(stream).await?;
|
||||
if packet_length < 0 {
|
||||
return Err(ErrorKind::InputError(
|
||||
"Invalid status response packet length".to_string(),
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let packet_length = cap_varint_length(
|
||||
varint::read(stream).await?,
|
||||
super::MAX_MODERN_STATUS_PACKET_LENGTH,
|
||||
"invalid status response packet length",
|
||||
)?;
|
||||
|
||||
let mut packet_stream = stream.take(packet_length as u64);
|
||||
let packet_id = varint::read(&mut packet_stream).await?;
|
||||
@@ -144,8 +169,12 @@ mod modern {
|
||||
)
|
||||
.into());
|
||||
}
|
||||
let response_length = varint::read(&mut packet_stream).await?;
|
||||
let mut json_response = vec![0_u8; response_length as usize];
|
||||
let response_length = cap_varint_length(
|
||||
varint::read(&mut packet_stream).await?,
|
||||
super::MAX_MINECRAFT_STATUS_STRING_LENGTH,
|
||||
"invalid status response length",
|
||||
)?;
|
||||
let mut json_response = vec![0_u8; response_length];
|
||||
packet_stream.read_exact(&mut json_response).await?;
|
||||
|
||||
if packet_stream.limit() > 0 {
|
||||
@@ -155,6 +184,27 @@ mod modern {
|
||||
Ok(serde_json::from_slice(&json_response)?)
|
||||
}
|
||||
|
||||
/// Ensures the length of a varint as stated by a server is not longer than a
|
||||
/// hard-coded limit.
|
||||
///
|
||||
/// For example, if we ping a server that says its status packet is 2 billion
|
||||
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||
/// will OOM our machine.
|
||||
///
|
||||
/// Implemented as a function so that you can easily find callsites and see
|
||||
/// where we accept unvalidated input from servers.
|
||||
fn cap_varint_length(
|
||||
length: i32,
|
||||
max_length: usize,
|
||||
context: &'static str,
|
||||
) -> crate::Result<usize> {
|
||||
if length < 0 {
|
||||
return Err(ErrorKind::InputError(context.to_string()).into());
|
||||
}
|
||||
|
||||
super::cap_length(length as usize, max_length, context)
|
||||
}
|
||||
|
||||
async fn ping(stream: &mut TcpStream) -> crate::Result<i64> {
|
||||
let ping_magic = chrono::Utc::now().timestamp_millis();
|
||||
|
||||
@@ -275,8 +325,17 @@ mod legacy {
|
||||
)));
|
||||
}
|
||||
|
||||
let data_length = stream.read_u16().await?;
|
||||
let mut data = vec![0u8; data_length as usize * 2];
|
||||
let data_length = super::cap_length(
|
||||
stream.read_u16().await? as usize,
|
||||
super::MAX_LEGACY_STATUS_UTF16_LENGTH,
|
||||
"invalid legacy status response length",
|
||||
)?;
|
||||
let data_byte_length = data_length.checked_mul(2).ok_or_else(|| {
|
||||
ErrorKind::InputError(
|
||||
"invalid legacy status response length".to_string(),
|
||||
)
|
||||
})?;
|
||||
let mut data = vec![0u8; data_byte_length];
|
||||
stream.read_exact(&mut data).await?;
|
||||
|
||||
drop(stream);
|
||||
|
||||
@@ -31,6 +31,27 @@ pub enum ProtocolError {
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
|
||||
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
|
||||
const MAX_PONG_PACKET_LENGTH: usize = 9;
|
||||
|
||||
/// Ensures the length of a packet as stated by a server is not longer than a
|
||||
/// hard-coded limit.
|
||||
///
|
||||
/// For example, if we ping a server that says its status packet is 2 billion
|
||||
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||
/// will OOM our machine.
|
||||
///
|
||||
/// Implemented as a function so that you can easily find callsites and see
|
||||
/// where we accept unvalidated input from servers.
|
||||
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
|
||||
if length > max_length {
|
||||
return Err(ProtocolError::InvalidPacketLength);
|
||||
}
|
||||
|
||||
Ok(length)
|
||||
}
|
||||
|
||||
/// State represents the desired next state of the
|
||||
/// exchange.
|
||||
///
|
||||
@@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
|
||||
}
|
||||
|
||||
async fn read_string(&mut self) -> Result<String, ProtocolError> {
|
||||
let length = self.read_varint().await?;
|
||||
let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
|
||||
|
||||
let mut buffer = vec![0; length];
|
||||
self.read_exact(&mut buffer).await?;
|
||||
@@ -157,6 +178,7 @@ pub trait PacketId {
|
||||
/// to generically get a packet's expected ID.
|
||||
pub trait ExpectedPacketId {
|
||||
fn get_expected_packet_id() -> usize;
|
||||
fn get_max_packet_length() -> usize;
|
||||
}
|
||||
|
||||
/// AsyncReadFromBuffer is used to allow
|
||||
@@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
||||
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
|
||||
&mut self,
|
||||
) -> Result<T, ProtocolError> {
|
||||
let length = self.read_varint().await?;
|
||||
let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
|
||||
|
||||
if length == 0 {
|
||||
return Err(ProtocolError::InvalidPacketLength);
|
||||
@@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
||||
});
|
||||
}
|
||||
|
||||
let mut buffer = vec![0; length - 1];
|
||||
let payload_length = length
|
||||
.checked_sub(1)
|
||||
.ok_or(ProtocolError::InvalidPacketLength)?;
|
||||
let mut buffer = vec![0; payload_length];
|
||||
self.read_exact(&mut buffer).await?;
|
||||
|
||||
T::read_from_buffer(buffer).await
|
||||
@@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
|
||||
fn get_expected_packet_id() -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn get_max_packet_length() -> usize {
|
||||
MAX_STATUS_RESPONSE_PACKET_LENGTH
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
|
||||
fn get_expected_packet_id() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn get_max_packet_length() -> usize {
|
||||
MAX_PONG_PACKET_LENGTH
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -573,4 +606,32 @@ mod tests {
|
||||
let result = reader.read_varint().await;
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_oversized_string_length_is_rejected() {
|
||||
let mut writer = Cursor::new(Vec::new());
|
||||
writer
|
||||
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut reader = Cursor::new(writer.into_inner());
|
||||
let result = reader.read_string().await;
|
||||
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_oversized_packet_length_is_rejected() {
|
||||
let mut writer = Cursor::new(Vec::new());
|
||||
writer
|
||||
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut reader = Cursor::new(writer.into_inner());
|
||||
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;
|
||||
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user