use crate::env::ENV;

use super::models::DatabaseError;
use ariadne::ids::base62_impl::{parse_base62, to_base62};
use chrono::{TimeZone, Utc};
use dashmap::DashMap;
use deadpool_redis::{Config, Runtime};
use futures::TryStreamExt;
use futures::future::Either;
use futures::stream::{FuturesUnordered, StreamExt};
use prometheus::{IntGauge, Registry};
use redis::ToRedisArgs;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::future::Future;
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
use tracing::{Instrument, info, info_span};
use util::{cmd, redis_pipe};

pub mod util;

// Expiry, in seconds, passed to Redis (`EX` / `set_ex`) when a value is
// written and the caller supplies no explicit expiry.
const DEFAULT_EXPIRY: i64 = 60 * 60 * 12; // 12 hours
// Age, in seconds, after which a cached entry's `iat` timestamp makes the
// cached-key lookup treat it as expired, even though Redis still holds it.
const ACTUAL_EXPIRY: i64 = 60 * 30; // 30 minutes

// Bound how many commands we send in a single Redis pipeline. The multiplexed
// connection's BytesMut write buffer keeps its peak capacity for the life of
// the connection, so larger pipelines cause higher steady-state RSS.
const PIPELINE_CHUNK_SIZE: usize = 25;

// Bound how many keys we send in a single MGET. Each MGET response must fit
// into the connection's read buffer, which also retains its peak capacity. At
// ~1 MB per cached value, 32 keys caps any single response at ~32 MB.
const MGET_CHUNK_SIZE: usize = 32;

// How long a pooled Redis connection lives before being recycled, regardless
// of activity. Forced recycling is the only way to release the per-connection
// BytesMut peak capacity that builds up under steady load.
const REDIS_MAX_CONN_AGE: Duration = Duration::from_secs(120); #[derive(Clone)] pub struct RedisPool { pub url: String, pub pool: deadpool_redis::Pool, cache_list: Arc>, meta_namespace: Arc, } pub struct RedisConnection { pub connection: deadpool_redis::Connection, meta_namespace: Arc, } impl RedisPool { // initiate a new redis pool // testing pool uses a hashmap to mimic redis behaviour for very small data sizes (ie: tests) // PANICS: production pool will panic if redis url is not set pub fn new(meta_namespace: impl Into>) -> Self { let wait_timeout = Duration::from_millis(ENV.REDIS_WAIT_TIMEOUT_MS); let url = &ENV.REDIS_URL; let pool = Config::from_url(url.clone()) .builder() .expect("Error building Redis pool") .max_size(ENV.REDIS_MAX_CONNECTIONS as usize) .wait_timeout(Some(wait_timeout)) .runtime(Runtime::Tokio1) .build() .expect("Redis connection failed"); let pool = RedisPool { url: url.clone(), pool, cache_list: Arc::new(DashMap::with_capacity(2048)), meta_namespace: meta_namespace.into(), }; let redis_min_connections = ENV.REDIS_MIN_CONNECTIONS; let spawn_min_connections = (0..redis_min_connections) .map(|_| { let pool = pool.clone(); tokio::spawn(async move { pool.pool.get().await }) }) .collect::>(); tokio::spawn({ let pool = pool.clone(); async move { // collect the connections into a buffer while we're spawning them, // to make sure that we're not `get`ing any connections we previously took let _connections = spawn_min_connections.try_collect::>().await; info!( pool_status = ?pool.pool.status(), "Finished getting {redis_min_connections} initial Redis connections" ); } }); let interval = Duration::from_secs(30); let max_idle = Duration::from_secs(5 * 60); // 5 minutes let pool_ref = pool.clone(); tokio::spawn(async move { loop { tokio::time::sleep(interval).await; pool_ref.pool.retain(|_, metrics| { // Drop connections that have been idle too long, OR that // are older than REDIS_MAX_CONN_AGE regardless of use. 
// The age-based recycle is what releases the per-connection // BytesMut peak capacity under steady traffic. metrics.last_used() < max_idle && metrics.created.elapsed() < REDIS_MAX_CONN_AGE }); } }); pool } pub async fn register_and_set_metrics( &self, registry: &Registry, ) -> Result<(), prometheus::Error> { let redis_max_size = IntGauge::new( "labrinth_redis_pool_max_size", "Maximum size of Redis pool", )?; let redis_size = IntGauge::new( "labrinth_redis_pool_size", "Current size of Redis pool", )?; let redis_available = IntGauge::new( "labrinth_redis_pool_available", "Available connections in Redis pool", )?; let redis_waiting = IntGauge::new( "labrinth_redis_pool_waiting", "Number of futures waiting for a Redis connection", )?; registry.register(Box::new(redis_max_size.clone()))?; registry.register(Box::new(redis_size.clone()))?; registry.register(Box::new(redis_available.clone()))?; registry.register(Box::new(redis_waiting.clone()))?; let redis_pool_ref = self.pool.clone(); tokio::spawn(async move { loop { let status = redis_pool_ref.status(); redis_max_size.set(status.max_size as i64); redis_size.set(status.size as i64); redis_available.set(status.available as i64); redis_waiting.set(status.waiting as i64); tokio::time::sleep(Duration::from_secs(5)).await; } }); Ok(()) } #[tracing::instrument(skip(self))] pub async fn connect(&self) -> Result { Ok(RedisConnection { connection: self.pool.get().await?, meta_namespace: self.meta_namespace.clone(), }) } #[tracing::instrument(skip(self, closure))] pub async fn get_cached_keys( &self, namespace: &str, keys: &[K], closure: F, ) -> Result, DatabaseError> where F: FnOnce(Vec) -> Fut, Fut: Future, DatabaseError>>, T: Serialize + DeserializeOwned, K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize + Debug, { Ok(self .get_cached_keys_raw(namespace, keys, closure) .await? 
.into_iter() .map(|x| x.1) .collect()) } #[tracing::instrument(skip(self, closure))] pub async fn get_cached_keys_raw( &self, namespace: &str, keys: &[K], closure: F, ) -> Result, DatabaseError> where F: FnOnce(Vec) -> Fut, Fut: Future, DatabaseError>>, T: Serialize + DeserializeOwned, K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize + Debug, { self.get_cached_keys_raw_with_slug( namespace, None, false, keys, |ids| async move { Ok(closure(ids) .await? .into_iter() .map(|(key, val)| (key, (None::, val))) .collect()) }, ) .await } #[tracing::instrument(skip(self, closure))] pub async fn get_cached_keys_with_slug( &self, namespace: &str, slug_namespace: &str, case_sensitive: bool, keys: &[I], closure: F, ) -> Result, DatabaseError> where F: FnOnce(Vec) -> Fut, Fut: Future, T)>, DatabaseError>>, T: Serialize + DeserializeOwned, I: Display + Hash + Eq + PartialEq + Clone + Debug, K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize, S: Display + Clone + DeserializeOwned + Serialize + Debug, { Ok(self .get_cached_keys_raw_with_slug( namespace, Some(slug_namespace), case_sensitive, keys, closure, ) .await? 
.into_iter() .map(|x| x.1) .collect()) } #[tracing::instrument(skip(self, closure))] pub async fn get_cached_keys_raw_with_slug( &self, namespace: &str, slug_namespace: Option<&str>, case_sensitive: bool, keys: &[I], closure: F, ) -> Result, DatabaseError> where F: FnOnce(Vec) -> Fut, Fut: Future, T)>, DatabaseError>>, T: Serialize + DeserializeOwned, I: Display + Hash + Eq + PartialEq + Clone + Debug, K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize, S: Display + Clone + DeserializeOwned + Serialize + Debug, { let ids = keys .iter() .map(|x| (x.to_string(), x.clone())) .collect::>(); if ids.is_empty() { return Ok(HashMap::new()); } let get_cached_values = |ids: DashMap| { async move { let slug_ids = if let Some(slug_namespace) = slug_namespace { async { let mut connection = self.pool.get().await?; let args = ids .iter() .map(|x| { format!( "{}_{slug_namespace}:{}", self.meta_namespace, if case_sensitive { x.value().to_string() } else { x.value().to_string().to_lowercase() } ) }) .collect::>(); let mut v = Vec::new(); for chunk in args.chunks(MGET_CHUNK_SIZE) { let part = cmd("MGET") .arg(chunk) .query_async::>>( &mut connection, ) .await?; v.extend(part.into_iter().flatten()); } Ok::<_, DatabaseError>(v) } .instrument(info_span!("get slug ids")) .await? 
} else { Vec::new() }; let mut connection = self.pool.get().await?; let args = ids .iter() .map(|x| x.value().to_string()) .chain(ids.iter().filter_map(|x| { parse_base62(&x.value().to_string()) .ok() .map(|x| x.to_string()) })) .chain(slug_ids) .map(|x| format!("{}_{namespace}:{x}", self.meta_namespace)) .collect::>(); let mut cached_values = HashMap::new(); for chunk in args.chunks(MGET_CHUNK_SIZE) { let part = cmd("MGET") .arg(chunk) .query_async::>>(&mut connection) .await?; cached_values.extend(part.into_iter().filter_map(|x| { x.and_then(|val| { serde_json::from_str::>(&val) .ok() }) .map(|val| (val.key.clone(), val)) })); } Ok::<_, DatabaseError>((cached_values, ids)) } .instrument(info_span!("get cached values")) }; let current_time = Utc::now(); let mut expired_values = HashMap::new(); let (cached_values_raw, ids) = get_cached_values(ids).await?; let mut cached_values = cached_values_raw .into_iter() .filter_map(|(key, val)| { if Utc.timestamp_opt(val.iat + ACTUAL_EXPIRY, 0).unwrap() < current_time { expired_values.insert(val.key.to_string(), val); None } else { let key_str = val.key.to_string(); ids.remove(&key_str); if let Ok(value) = key_str.parse::() { let base62 = to_base62(value); ids.remove(&base62); } if let Some(ref alias) = val.alias { ids.remove(&alias.to_string()); } Some((key, val)) } }) .collect::>(); let subscribe_ids = DashMap::new(); let mut cache_writers = HashMap::new(); if !ids.is_empty() { let fetch_ids = ids.iter().map(|x| x.key().clone()).collect::>(); fetch_ids.into_iter().for_each(|key| { let ns_key_value = if case_sensitive { key.to_lowercase() } else { key.clone() }; let namespaced_key = format!( "{}_{namespace}:{ns_key_value}", self.meta_namespace, ); let either = self.acquire_lock(namespaced_key); match either { Either::Left(sentinel) => { cache_writers.insert(key, sentinel); } Either::Right(subscriber) => { if let Some((key, raw_key)) = ids.remove(&key) { if let Some(val) = expired_values.remove(&key) { if let Some(ref alias) 
= val.alias { ids.remove(&alias.to_string()); } if let Ok(value) = val.key.to_string().parse::() { let base62 = to_base62(value); ids.remove(&base62); } cached_values.insert(val.key.clone(), val); } else { subscribe_ids.insert(raw_key, subscriber); } } } } }); } let mut fetch_tasks = Vec::new(); if !ids.is_empty() { fetch_tasks.push(Either::Left(async { let fetch_ids = ids.iter().map(|x| x.value().clone()).collect::>(); let vals = closure(fetch_ids).await?; let mut return_values = HashMap::new(); let mut pipe = redis_pipe(); let mut pipe_cmds: usize = 0; let mut connection = self.pool.get().await?; // Doesn't need to be atomic if !vals.is_empty() { for (key, (slug, value)) in vals { let value = RedisValue { key: key.clone(), iat: Utc::now().timestamp(), val: value, alias: slug.clone(), }; pipe.set_ex( format!( "{}_{namespace}:{key}", self.meta_namespace ), serde_json::to_string(&value)?, DEFAULT_EXPIRY as u64, ); pipe_cmds += 1; if let Some(slug) = slug { ids.remove(&slug.to_string()); if let Some(slug_namespace) = slug_namespace { let actual_slug = if case_sensitive { slug.to_string() } else { slug.to_string().to_lowercase() }; pipe.set_ex( format!( "{}_{slug_namespace}:{}", self.meta_namespace, actual_slug ), key.to_string(), DEFAULT_EXPIRY as u64, ); pipe_cmds += 1; } } let key_str = key.to_string(); ids.remove(&key_str); if let Ok(value) = key_str.parse::() { let base62 = to_base62(value); ids.remove(&base62); } return_values.insert(key, value); if pipe_cmds >= PIPELINE_CHUNK_SIZE { pipe.query_async::<()>(&mut connection).await?; pipe = redis_pipe(); pipe_cmds = 0; } } } if pipe_cmds > 0 { pipe.query_async::<()>(&mut connection).await?; } drop(cache_writers); Result::<_, DatabaseError>::Ok(return_values) })); } if !subscribe_ids.is_empty() { fetch_tasks.push(Either::Right(async move { let mut futures = FuturesUnordered::new(); let len = subscribe_ids.len(); for (key, subscriber) in subscribe_ids { futures.push(async move { ( key, subscriber 
.wait_timeout(Duration::from_secs(5)) .await, ) }); } let fetch_ids = DashMap::with_capacity(len); while let Some((key, result)) = futures.next().await { result?; fetch_ids.insert(key.to_string(), key); } let (return_values, _) = get_cached_values(fetch_ids).await?; Ok(return_values) })); } if !fetch_tasks.is_empty() { for map in futures::future::try_join_all(fetch_tasks).await? { for (key, value) in map { cached_values.insert(key, value); } } } Ok(cached_values.into_iter().map(|x| (x.0, x.1.val)).collect()) } /// Acquire or create a cache lock onto the given key. fn acquire_lock( &self, key: String, ) -> Either, util::CacheSubscriber> { let mut out_writer = None; let subscriber = self.cache_list.entry(key.clone()).or_insert_with(|| { let (writer, subscriber) = util::cache(); out_writer = Some(writer); subscriber }); match out_writer { Some(writer) => Either::Left(LockSentinel { pool: self, key, writer, }), None => Either::Right(subscriber.clone()), } } } struct LockSentinel<'a> { pool: &'a RedisPool, key: String, writer: util::CacheWriter, } impl<'a> Drop for LockSentinel<'a> { fn drop(&mut self) { self.writer.write(); self.pool.cache_list.remove(&self.key); } } impl RedisConnection { #[tracing::instrument(skip(self))] pub async fn set( &mut self, namespace: &str, id: &str, data: &str, expiry: Option, ) -> Result<(), DatabaseError> { let mut cmd = cmd("SET"); redis_args( &mut cmd, vec![ format!("{}_{}:{}", self.meta_namespace, namespace, id), data.to_string(), "EX".to_string(), expiry.unwrap_or(DEFAULT_EXPIRY).to_string(), ] .as_slice(), ); redis_execute::<()>(&mut cmd, &mut self.connection).await?; Ok(()) } #[tracing::instrument(skip(self, id, data))] pub async fn set_serialized_to_json( &mut self, namespace: &str, id: Id, data: D, expiry: Option, ) -> Result<(), DatabaseError> where Id: Display, D: serde::Serialize, { self.set( namespace, &id.to_string(), &serde_json::to_string(&data)?, expiry, ) .await } #[tracing::instrument(skip(self))] pub async fn get( &mut 
self, namespace: &str, id: &str, ) -> Result, DatabaseError> { let mut cmd = cmd("GET"); redis_args( &mut cmd, vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)] .as_slice(), ); let res = redis_execute(&mut cmd, &mut self.connection).await?; Ok(res) } #[tracing::instrument(skip(self))] pub async fn get_many( &mut self, namespace: &str, ids: &[String], ) -> Result>, DatabaseError> { let mut cmd = cmd("MGET"); redis_args( &mut cmd, ids.iter() .map(|x| format!("{}_{}:{}", self.meta_namespace, namespace, x)) .collect::>() .as_slice(), ); let res = redis_execute(&mut cmd, &mut self.connection).await?; Ok(res) } #[tracing::instrument(skip(self))] pub async fn get_deserialized_from_json( &mut self, namespace: &str, id: &str, ) -> Result, DatabaseError> where R: for<'a> serde::Deserialize<'a>, { Ok(self .get(namespace, id) .await? .and_then(|x| serde_json::from_str(&x).ok())) } #[tracing::instrument(skip(self))] pub async fn get_many_deserialized_from_json( &mut self, namespace: &str, ids: &[String], ) -> Result>, DatabaseError> where R: for<'a> serde::Deserialize<'a>, { Ok(self .get_many(namespace, ids) .await? 
.into_iter() .map(|x| x.and_then(|val| serde_json::from_str::(&val).ok())) .collect::>()) } #[tracing::instrument(skip(self, id))] pub async fn delete( &mut self, namespace: &str, id: T1, ) -> Result<(), DatabaseError> where T1: Display, { let mut cmd = cmd("DEL"); redis_args( &mut cmd, vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)] .as_slice(), ); redis_execute::<()>(&mut cmd, &mut self.connection).await?; Ok(()) } #[tracing::instrument(skip(self, iter))] pub async fn delete_many( &mut self, iter: impl IntoIterator)>, ) -> Result<(), DatabaseError> { let mut cmd = cmd("DEL"); let mut any = false; for (namespace, id) in iter { if let Some(id) = id { redis_args( &mut cmd, [format!("{}_{}:{}", self.meta_namespace, namespace, id)] .as_slice(), ); any = true; } } if any { redis_execute::<()>(&mut cmd, &mut self.connection).await?; } Ok(()) } #[tracing::instrument(skip(self, value))] pub async fn lpush( &mut self, namespace: &str, key: &str, value: impl ToRedisArgs + Send + Sync + Debug, ) -> Result<(), DatabaseError> { let key = format!("{}_{namespace}:{key}", self.meta_namespace); cmd("LPUSH") .arg(key) .arg(value) .query_async::<()>(&mut self.connection) .await?; Ok(()) } #[tracing::instrument(skip(self))] pub async fn brpop( &mut self, namespace: &str, key: &str, timeout: Option, ) -> Result, DatabaseError> { let key = format!("{}_{namespace}:{key}", self.meta_namespace); // a timeout of 0 is infinite let timeout = timeout.unwrap_or(0.0); let values = cmd("BRPOP") .arg(key) .arg(timeout) .query_async(&mut self.connection) .await?; Ok(values) } #[tracing::instrument(skip(self))] pub async fn incr( &mut self, namespace: &str, id: &str, ) -> Result, DatabaseError> { let key = format!("{}_{namespace}:{id}", self.meta_namespace); let value = cmd("INCR") .arg(key) .query_async(&mut self.connection) .await?; Ok(value) } } #[derive(Serialize, Deserialize)] pub struct RedisValue { key: K, #[serde(skip_serializing_if = "Option::is_none")] alias: Option, iat: 
i64, val: T, } pub fn redis_args(cmd: &mut util::InstrumentedCmd, args: &[String]) { for arg in args { cmd.arg(arg); } } pub async fn redis_execute( cmd: &mut util::InstrumentedCmd, redis: &mut deadpool_redis::Connection, ) -> Result where T: redis::FromRedisValue, { let res = cmd.query_async::(redis).await?; Ok(res) }