rate_limit.rs (1741B)
1 use crate::{Action, InputMessage, NoteFilter, OutputMessage}; 2 use serde::Deserialize; 3 use std::collections::HashMap; 4 use std::time::{Duration, Instant}; 5 6 pub struct RateInfo { 7 pub last_note: Instant, 8 } 9 10 #[derive(Deserialize, Default)] 11 pub struct RateLimit { 12 pub notes_per_second: u64, 13 pub whitelist: Option<Vec<String>>, 14 15 #[serde(skip)] 16 pub sources: HashMap<String, RateInfo>, 17 } 18 19 impl NoteFilter for RateLimit { 20 fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage { 21 if let Some(whitelist) = &self.whitelist { 22 if whitelist.contains(&msg.source_info) { 23 return OutputMessage::new(msg.event.id.clone(), Action::Accept, None); 24 } 25 } 26 27 if self.sources.contains_key(&msg.source_info) { 28 let now = Instant::now(); 29 let entry = self.sources.get_mut(&msg.source_info).expect("impossiburu"); 30 if now - entry.last_note < Duration::from_secs(self.notes_per_second) { 31 return OutputMessage::new( 32 msg.event.id.clone(), 33 Action::Reject, 34 Some("rate-limited: you are noting too fast".to_string()), 35 ); 36 } else { 37 entry.last_note = Instant::now(); 38 return OutputMessage::new(msg.event.id.clone(), Action::Accept, None); 39 } 40 } else { 41 self.sources.insert( 42 msg.source_info.to_owned(), 43 RateInfo { 44 last_note: Instant::now(), 45 }, 46 ); 47 return OutputMessage::new(msg.event.id.clone(), Action::Accept, None); 48 } 49 } 50 51 fn name(&self) -> &'static str { 52 "ratelimit" 53 } 54 }