use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::time::{sleep_until, Duration, Instant, Sleep};
/// Total number of commands budgeted per [`PERIOD`], shared between heartbeats
/// and all other commands (see `nonreserved_commands_per_reset`).
const COMMANDS_PER_PERIOD: u8 = 120;
/// Length of the window a permit stays in use for: an acquired permit expires
/// `PERIOD` after it was acquired.
const PERIOD: Duration = Duration::from_secs(60);
/// Ratelimiter for commands.
///
/// Every acquired permit is stored as the instant it expires (acquisition time
/// plus [`PERIOD`]). A permit becomes free again once that instant has passed.
/// The total number of permits is the capacity of `instants`.
#[derive(Debug)]
pub struct CommandRatelimiter {
    /// Timer re-armed for the earliest permit expiry while every permit is in
    /// use; polling it is how a waiting task gets woken.
    delay: Pin<Box<Sleep>>,
    /// Expiry instants of acquired permits, earliest first (permits are
    /// appended in acquisition order).
    instants: Vec<Instant>,
}
impl CommandRatelimiter {
    /// Creates a ratelimiter whose permit count is derived from
    /// `heartbeat_interval` via [`nonreserved_commands_per_reset`].
    pub(crate) fn new(heartbeat_interval: Duration) -> Self {
        let allotted = nonreserved_commands_per_reset(heartbeat_interval);

        let now = Instant::now();
        let mut delay = Box::pin(sleep_until(now));
        // Resetting right away registers the timer entry with the runtime's
        // timer driver up front (workaround kept from the original code).
        delay.as_mut().reset(now);

        Self {
            delay,
            // The vector's capacity doubles as the permit count (see `max`).
            // NOTE(review): `Vec::with_capacity` only promises *at least*
            // `allotted` slots; this type relies on it being exact — confirm.
            instants: Vec::with_capacity(allotted.into()),
        }
    }

    /// Number of permits that can be acquired right now.
    ///
    /// Permits whose expiry instant has passed count as free even if
    /// `poll_ready` has not pruned them from `instants` yet.
    // The cast is lossless: `used_permits <= instants.len() <= capacity`, and the
    // capacity was requested from a `u8`.
    #[allow(clippy::cast_possible_truncation)]
    pub fn available(&self) -> u8 {
        let now = Instant::now();
        let elapsed_permits = self.instants.partition_point(|&expiry| expiry <= now);
        let used_permits = self.instants.len() - elapsed_permits;

        self.max() - used_permits as u8
    }

    /// Total number of permits per [`PERIOD`].
    // The capacity was requested from a `u8` in `new`, so it fits.
    #[allow(clippy::cast_possible_truncation)]
    pub fn max(&self) -> u8 {
        self.instants.capacity() as u8
    }

    /// Duration until a permit can be acquired; [`Duration::ZERO`] if one is
    /// available now.
    pub fn next_available(&self) -> Duration {
        // A free slot means a permit is available immediately, regardless of
        // when the oldest used permit expires.
        if self.instants.len() < self.instants.capacity() {
            return Duration::ZERO;
        }

        self.instants.first().map_or(Duration::ZERO, |&expiry| {
            expiry.saturating_duration_since(Instant::now())
        })
    }

    /// Attempts to acquire a permit.
    ///
    /// Returns `Poll::Pending` (with `cx` registered for a wake-up) while every
    /// permit is in use, and `Poll::Ready(())` once a permit has been taken.
    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        ready!(self.poll_ready(cx));

        self.instants.push(Instant::now() + PERIOD);

        Poll::Ready(())
    }

    /// Checks whether a permit is available without acquiring it.
    ///
    /// When every permit is in use, this waits on `delay`, re-armed for the
    /// earliest expiry. Once that passes, expired permits are pruned and
    /// `Poll::Ready(())` is returned.
    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if self.instants.len() != self.instants.capacity() {
            return Poll::Ready(());
        }

        let now = loop {
            // Poll the delay instead of only checking `is_elapsed`: this
            // registers the waker from *this* `cx`, so every `Pending` returned
            // here has a wake-up scheduled for the current task, as the
            // `Future::poll` contract requires.
            ready!(self.delay.as_mut().poll(cx));

            let now = Instant::now();
            // Every permit is in use, so `instants` is non-empty, and its first
            // entry is the earliest expiry.
            let oldest = self.instants[0];
            if oldest <= now {
                break now;
            }

            tracing::debug!(duration = ?(oldest - now), "ratelimited");
            // Re-arm for the earliest expiry. The next iteration polls the delay
            // with `cx`, which registers the waker and normally yields `Pending`.
            // If the timer already fired, the loop re-checks instead of
            // returning `Pending` with no wake-up scheduled.
            self.delay.as_mut().reset(oldest);
        };

        // Remove expired permits from the front. The still-used ones stay in
        // order and the capacity (= permit count) is unchanged.
        let elapsed_permits = self.instants.partition_point(|&expiry| expiry <= now);
        self.instants.drain(..elapsed_permits);

        Poll::Ready(())
    }
}
/// Number of non-heartbeat commands that may be sent per [`PERIOD`].
///
/// The heartbeats needed in one `PERIOD` at `heartbeat_interval` are estimated,
/// rounded up, plus one spare, and deducted from [`COMMANDS_PER_PERIOD`]. The
/// result never drops below `COMMANDS_PER_PERIOD - 10`, so a tiny or zero
/// interval cannot starve regular commands.
fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
    /// Lowest number of non-heartbeat commands ever allotted.
    const FLOOR: u8 = COMMANDS_PER_PERIOD - 10;

    let beats_per_period = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();

    // Float-to-int `as` saturates, so a zero interval (ratio = +inf) becomes
    // `u8::MAX` rather than wrapping.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    let reserved = (beats_per_period.ceil() as u8).saturating_add(1);

    let remaining = COMMANDS_PER_PERIOD.saturating_sub(reserved);
    if remaining >= FLOOR {
        remaining
    } else {
        FLOOR
    }
}
#[cfg(test)]
mod tests {
    use super::{nonreserved_commands_per_reset, CommandRatelimiter, PERIOD};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, future::poll_fn, task::Poll, time::Duration};
    use tokio::time;

    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    /// Heartbeat interval used by every ratelimiter test below.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    /// Acquires `count` permits from `limiter`, one after another.
    async fn acquire_n(limiter: &mut CommandRatelimiter, count: u8) {
        for _ in 0..count {
            poll_fn(|cx| limiter.poll_acquire(cx)).await;
        }
    }

    #[test]
    fn nonreserved_commands() {
        let cases = [
            (Duration::from_secs(u64::MAX), 118),
            (Duration::from_secs(60), 118),
            (Duration::from_millis(42_500), 117),
            (Duration::from_secs(30), 117),
            (Duration::from_millis(29_999), 116),
            (Duration::ZERO, 110),
        ];

        for (interval, expected) in cases {
            assert_eq!(expected, nonreserved_commands_per_reset(interval));
        }
    }

    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut limiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = limiter.max();
        assert_eq!(limiter.available(), max);

        acquire_n(&mut limiter, max).await;
        assert_eq!(limiter.available(), 0);

        time::advance(PERIOD - Duration::from_millis(100)).await;
        assert_eq!(limiter.available(), 0);

        time::advance(Duration::from_millis(100)).await;
        assert_eq!(limiter.available(), max);
    }

    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut limiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = limiter.max();
        let half = max / 2;
        assert_eq!(limiter.available(), max);

        acquire_n(&mut limiter, half).await;
        assert_eq!(limiter.available(), half);

        time::advance(PERIOD / 2).await;
        assert_eq!(limiter.available(), half);

        acquire_n(&mut limiter, half).await;
        assert_eq!(limiter.available(), 0);

        time::advance(PERIOD / 2).await;
        assert_eq!(limiter.available(), half);

        time::advance(PERIOD / 2).await;
        assert_eq!(limiter.available(), max);
    }

    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut limiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = limiter.max();

        acquire_n(&mut limiter, max).await;
        assert_eq!(limiter.available(), 0);

        // One more acquire must wait for a permit to expire rather than grow
        // the backing vector.
        acquire_n(&mut limiter, 1).await;
        assert_eq!(max, limiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn spurious_poll() {
        let mut limiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = limiter.max();

        acquire_n(&mut limiter, max).await;
        assert_eq!(limiter.available(), 0);

        poll_fn(|cx| {
            if limiter.poll_ready(cx).is_ready() {
                return Poll::Ready(());
            }

            // A second poll while still rate-limited must not re-arm the timer.
            let armed_deadline = limiter.delay.deadline();
            assert!(limiter.poll_ready(cx).is_pending());
            assert_eq!(
                armed_deadline,
                limiter.delay.deadline(),
                "deadline was reset"
            );

            Poll::Pending
        })
        .await;
    }
}