Fix buffer allocations in server pinging code (#5751)
This commit is contained in:
@@ -8,6 +8,33 @@ use tokio::net::ToSocketAddrs;
|
|||||||
use tokio::select;
|
use tokio::select;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
const MAX_MINECRAFT_STATUS_STRING_LENGTH: usize = 32_767;
|
||||||
|
const MAX_MODERN_STATUS_PACKET_LENGTH: usize =
|
||||||
|
MAX_MINECRAFT_STATUS_STRING_LENGTH + 4;
|
||||||
|
const MAX_LEGACY_STATUS_UTF16_LENGTH: usize =
|
||||||
|
MAX_MINECRAFT_STATUS_STRING_LENGTH;
|
||||||
|
|
||||||
|
/// Ensures the length of a packet as stated by a server is not longer than a
|
||||||
|
/// hard-coded limit.
|
||||||
|
///
|
||||||
|
/// For example, if we ping a server that says its status packet is 2 billion
|
||||||
|
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||||
|
/// will OOM our machine.
|
||||||
|
///
|
||||||
|
/// Implemented as a function so that you can easily find callsites and see
|
||||||
|
/// where we accept unvalidated input from servers.
|
||||||
|
fn cap_length(
|
||||||
|
length: usize,
|
||||||
|
max_length: usize,
|
||||||
|
context: &'static str,
|
||||||
|
) -> Result<usize> {
|
||||||
|
if length > max_length {
|
||||||
|
return Err(ErrorKind::InputError(context.to_string()).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(length)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct ServerStatus {
|
pub struct ServerStatus {
|
||||||
@@ -128,13 +155,11 @@ mod modern {
|
|||||||
stream.write_all(&[0x01, 0x00]).await?;
|
stream.write_all(&[0x01, 0x00]).await?;
|
||||||
stream.flush().await?;
|
stream.flush().await?;
|
||||||
|
|
||||||
let packet_length = varint::read(stream).await?;
|
let packet_length = cap_varint_length(
|
||||||
if packet_length < 0 {
|
varint::read(stream).await?,
|
||||||
return Err(ErrorKind::InputError(
|
super::MAX_MODERN_STATUS_PACKET_LENGTH,
|
||||||
"Invalid status response packet length".to_string(),
|
"invalid status response packet length",
|
||||||
)
|
)?;
|
||||||
.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut packet_stream = stream.take(packet_length as u64);
|
let mut packet_stream = stream.take(packet_length as u64);
|
||||||
let packet_id = varint::read(&mut packet_stream).await?;
|
let packet_id = varint::read(&mut packet_stream).await?;
|
||||||
@@ -144,8 +169,12 @@ mod modern {
|
|||||||
)
|
)
|
||||||
.into());
|
.into());
|
||||||
}
|
}
|
||||||
let response_length = varint::read(&mut packet_stream).await?;
|
let response_length = cap_varint_length(
|
||||||
let mut json_response = vec![0_u8; response_length as usize];
|
varint::read(&mut packet_stream).await?,
|
||||||
|
super::MAX_MINECRAFT_STATUS_STRING_LENGTH,
|
||||||
|
"invalid status response length",
|
||||||
|
)?;
|
||||||
|
let mut json_response = vec![0_u8; response_length];
|
||||||
packet_stream.read_exact(&mut json_response).await?;
|
packet_stream.read_exact(&mut json_response).await?;
|
||||||
|
|
||||||
if packet_stream.limit() > 0 {
|
if packet_stream.limit() > 0 {
|
||||||
@@ -155,6 +184,27 @@ mod modern {
|
|||||||
Ok(serde_json::from_slice(&json_response)?)
|
Ok(serde_json::from_slice(&json_response)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Ensures the length of a varint as stated by a server is not longer than a
|
||||||
|
/// hard-coded limit.
|
||||||
|
///
|
||||||
|
/// For example, if we ping a server that says its status packet is 2 billion
|
||||||
|
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||||
|
/// will OOM our machine.
|
||||||
|
///
|
||||||
|
/// Implemented as a function so that you can easily find callsites and see
|
||||||
|
/// where we accept unvalidated input from servers.
|
||||||
|
fn cap_varint_length(
|
||||||
|
length: i32,
|
||||||
|
max_length: usize,
|
||||||
|
context: &'static str,
|
||||||
|
) -> crate::Result<usize> {
|
||||||
|
if length < 0 {
|
||||||
|
return Err(ErrorKind::InputError(context.to_string()).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
super::cap_length(length as usize, max_length, context)
|
||||||
|
}
|
||||||
|
|
||||||
async fn ping(stream: &mut TcpStream) -> crate::Result<i64> {
|
async fn ping(stream: &mut TcpStream) -> crate::Result<i64> {
|
||||||
let ping_magic = chrono::Utc::now().timestamp_millis();
|
let ping_magic = chrono::Utc::now().timestamp_millis();
|
||||||
|
|
||||||
@@ -275,8 +325,17 @@ mod legacy {
|
|||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let data_length = stream.read_u16().await?;
|
let data_length = super::cap_length(
|
||||||
let mut data = vec![0u8; data_length as usize * 2];
|
stream.read_u16().await? as usize,
|
||||||
|
super::MAX_LEGACY_STATUS_UTF16_LENGTH,
|
||||||
|
"invalid legacy status response length",
|
||||||
|
)?;
|
||||||
|
let data_byte_length = data_length.checked_mul(2).ok_or_else(|| {
|
||||||
|
ErrorKind::InputError(
|
||||||
|
"invalid legacy status response length".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let mut data = vec![0u8; data_byte_length];
|
||||||
stream.read_exact(&mut data).await?;
|
stream.read_exact(&mut data).await?;
|
||||||
|
|
||||||
drop(stream);
|
drop(stream);
|
||||||
|
|||||||
@@ -31,6 +31,27 @@ pub enum ProtocolError {
|
|||||||
Timeout(#[from] tokio::time::error::Elapsed),
|
Timeout(#[from] tokio::time::error::Elapsed),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MAX_MINECRAFT_STRING_LENGTH: usize = 32_767;
|
||||||
|
const MAX_STATUS_RESPONSE_PACKET_LENGTH: usize = 32_771;
|
||||||
|
const MAX_PONG_PACKET_LENGTH: usize = 9;
|
||||||
|
|
||||||
|
/// Ensures the length of a packet as stated by a server is not longer than a
|
||||||
|
/// hard-coded limit.
|
||||||
|
///
|
||||||
|
/// For example, if we ping a server that says its status packet is 2 billion
|
||||||
|
/// bytes long, we don't try to allocate a 2 billion byte buffer, since that
|
||||||
|
/// will OOM our machine.
|
||||||
|
///
|
||||||
|
/// Implemented as a function so that you can easily find callsites and see
|
||||||
|
/// where we accept unvalidated input from servers.
|
||||||
|
fn cap_length(length: usize, max_length: usize) -> Result<usize, ProtocolError> {
|
||||||
|
if length > max_length {
|
||||||
|
return Err(ProtocolError::InvalidPacketLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(length)
|
||||||
|
}
|
||||||
|
|
||||||
/// State represents the desired next state of the
|
/// State represents the desired next state of the
|
||||||
/// exchange.
|
/// exchange.
|
||||||
///
|
///
|
||||||
@@ -98,7 +119,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncWireReadExt for R {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn read_string(&mut self) -> Result<String, ProtocolError> {
|
async fn read_string(&mut self) -> Result<String, ProtocolError> {
|
||||||
let length = self.read_varint().await?;
|
let length = cap_length(self.read_varint().await?, MAX_MINECRAFT_STRING_LENGTH)?;
|
||||||
|
|
||||||
let mut buffer = vec![0; length];
|
let mut buffer = vec![0; length];
|
||||||
self.read_exact(&mut buffer).await?;
|
self.read_exact(&mut buffer).await?;
|
||||||
@@ -157,6 +178,7 @@ pub trait PacketId {
|
|||||||
/// to generically get a packet's expected ID.
|
/// to generically get a packet's expected ID.
|
||||||
pub trait ExpectedPacketId {
|
pub trait ExpectedPacketId {
|
||||||
fn get_expected_packet_id() -> usize;
|
fn get_expected_packet_id() -> usize;
|
||||||
|
fn get_max_packet_length() -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// AsyncReadFromBuffer is used to allow
|
/// AsyncReadFromBuffer is used to allow
|
||||||
@@ -196,7 +218,7 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
|||||||
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
|
async fn read_packet<T: ExpectedPacketId + AsyncReadFromBuffer + Send + Sync>(
|
||||||
&mut self,
|
&mut self,
|
||||||
) -> Result<T, ProtocolError> {
|
) -> Result<T, ProtocolError> {
|
||||||
let length = self.read_varint().await?;
|
let length = cap_length(self.read_varint().await?, T::get_max_packet_length())?;
|
||||||
|
|
||||||
if length == 0 {
|
if length == 0 {
|
||||||
return Err(ProtocolError::InvalidPacketLength);
|
return Err(ProtocolError::InvalidPacketLength);
|
||||||
@@ -213,7 +235,10 @@ impl<R: AsyncRead + Unpin + Send + Sync> AsyncReadRawPacket for R {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buffer = vec![0; length - 1];
|
let payload_length = length
|
||||||
|
.checked_sub(1)
|
||||||
|
.ok_or(ProtocolError::InvalidPacketLength)?;
|
||||||
|
let mut buffer = vec![0; payload_length];
|
||||||
self.read_exact(&mut buffer).await?;
|
self.read_exact(&mut buffer).await?;
|
||||||
|
|
||||||
T::read_from_buffer(buffer).await
|
T::read_from_buffer(buffer).await
|
||||||
@@ -357,6 +382,10 @@ impl ExpectedPacketId for ResponsePacket {
|
|||||||
fn get_expected_packet_id() -> usize {
|
fn get_expected_packet_id() -> usize {
|
||||||
0
|
0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_max_packet_length() -> usize {
|
||||||
|
MAX_STATUS_RESPONSE_PACKET_LENGTH
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -411,6 +440,10 @@ impl ExpectedPacketId for PongPacket {
|
|||||||
fn get_expected_packet_id() -> usize {
|
fn get_expected_packet_id() -> usize {
|
||||||
1
|
1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_max_packet_length() -> usize {
|
||||||
|
MAX_PONG_PACKET_LENGTH
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -573,4 +606,32 @@ mod tests {
|
|||||||
let result = reader.read_varint().await;
|
let result = reader.read_varint().await;
|
||||||
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
|
assert!(matches!(result, Err(ProtocolError::InvalidVarInt)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_oversized_string_length_is_rejected() {
|
||||||
|
let mut writer = Cursor::new(Vec::new());
|
||||||
|
writer
|
||||||
|
.write_varint(MAX_MINECRAFT_STRING_LENGTH + 1)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut reader = Cursor::new(writer.into_inner());
|
||||||
|
let result = reader.read_string().await;
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_oversized_packet_length_is_rejected() {
|
||||||
|
let mut writer = Cursor::new(Vec::new());
|
||||||
|
writer
|
||||||
|
.write_varint(MAX_STATUS_RESPONSE_PACKET_LENGTH + 1)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut reader = Cursor::new(writer.into_inner());
|
||||||
|
let result: Result<ResponsePacket, ProtocolError> = reader.read_packet().await;
|
||||||
|
|
||||||
|
assert!(matches!(result, Err(ProtocolError::InvalidPacketLength)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user