Fix buffer allocations in server pinging code (#5751)
This commit is contained in:
@@ -31,6 +31,27 @@ pub enum ProtocolError {
|
||||
Timeout(#[from] tokio::time::error::Elapsed),
|
||||
}
|
||||
|
||||
const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
|
||||
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
|
||||
const MAX_PONG_PACKET_LENGTH: usize = 9;
|
||||
|
||||
/// Ensures the length of a packet as stated by a server is not longer than a
|
||||
/// hard-coded limit.
|
||||
///
|
||||
/// For example, if we ping a server that says its status packet is 2 billion
|
||||
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||
/// will OOM our machine.
|
||||
///
|
||||
/// Implemented as a function so that you can easily find callsites and see
|
||||
/// where we accept unvalidated input from servers.
|
||||
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
|
||||
if length > max_length {
|
||||
return Err(ProtocolError::InvalidPacketLength);
|
||||
}
|
||||
|
||||
Ok(length)
|
||||
}
|
||||
|
||||
/// State represents the desired next state of the
|
||||
/// exchange.
|
||||
///
|
||||
@@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
|
||||
}
|
||||
|
||||
async fn read_string(&mut self) -> Result<String, ProtocolError> {
|
||||
let length = self.read_varint().await?;
|
||||
let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
|
||||
|
||||
let mut buffer = vec![0; length];
|
||||
self.read_exact(&mut buffer).await?;
|
||||
@@ -157,6 +178,7 @@ pub trait PacketId {
|
||||
/// to generically get a packet's expected ID.
|
||||
pub trait ExpectedPacketId {
|
||||
fn get_expected_packet_id() -> usize;
|
||||
fn get_max_packet_length() -> usize;
|
||||
}
|
||||
|
||||
/// AsyncReadFromBuffer is used to allow
|
||||
@@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
||||
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
|
||||
&mut self,
|
||||
) -> Result<T, ProtocolError> {
|
||||
let length = self.read_varint().await?;
|
||||
let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
|
||||
|
||||
if length == 0 {
|
||||
return Err(ProtocolError::InvalidPacketLength);
|
||||
@@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
||||
});
|
||||
}
|
||||
|
||||
let mut buffer = vec![0; length - 1];
|
||||
let payload_length = length
|
||||
.checked_sub(1)
|
||||
.ok_or(ProtocolError::InvalidPacketLength)?;
|
||||
let mut buffer = vec![0; payload_length];
|
||||
self.read_exact(&mut buffer).await?;
|
||||
|
||||
T::read_from_buffer(buffer).await
|
||||
@@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
|
||||
fn get_expected_packet_id() -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
fn get_max_packet_length() -> usize {
|
||||
MAX_STATUS_RESPONSE_PACKET_LENGTH
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
|
||||
fn get_expected_packet_id() -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn get_max_packet_length() -> usize {
|
||||
MAX_PONG_PACKET_LENGTH
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -573,4 +606,32 @@ mod tests {
|
||||
let result = reader.read_varint().await;
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_oversized_string_length_is_rejected() {
|
||||
let mut writer = Cursor::new(Vec::new());
|
||||
writer
|
||||
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut reader = Cursor::new(writer.into_inner());
|
||||
let result = reader.read_string().await;
|
||||
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_oversized_packet_length_is_rejected() {
|
||||
let mut writer = Cursor::new(Vec::new());
|
||||
writer
|
||||
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut reader = Cursor::new(writer.into_inner());
|
||||
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;
|
||||
|
||||
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user