ratelimit.rs (2166B)
1 use crate::{Action, InputMessage, NoteFilter, OutputMessage}; 2 use serde::Deserialize; 3 use std::collections::HashMap; 4 use std::time::{Duration, Instant}; 5 6 pub struct Tokens { 7 pub tokens: i32, 8 pub last_post: Instant, 9 } 10 11 #[derive(Deserialize, Default)] 12 pub struct RateLimit { 13 pub posts_per_minute: i32, 14 pub whitelist: Option<Vec<String>>, 15 16 #[serde(skip)] 17 pub sources: HashMap<String, Tokens>, 18 } 19 20 impl NoteFilter for RateLimit { 21 fn name(&self) -> &'static str { 22 "ratelimit" 23 } 24 25 fn filter_note(&mut self, msg: &InputMessage) -> OutputMessage { 26 if let Some(whitelist) = &self.whitelist { 27 if whitelist.contains(&msg.source_info) { 28 return OutputMessage::new(msg.event.id.clone(), Action::Accept, None); 29 } 30 } 31 32 if !self.sources.contains_key(&msg.source_info) { 33 self.sources.insert( 34 msg.source_info.to_owned(), 35 Tokens { 36 last_post: Instant::now(), 37 tokens: self.posts_per_minute, 38 }, 39 ); 40 return OutputMessage::new(msg.event.id.clone(), Action::Accept, None); 41 } 42 43 let entry = self.sources.get_mut(&msg.source_info).expect("impossiburu"); 44 let now = Instant::now(); 45 let mut diff = now - entry.last_post; 46 47 let min = Duration::from_secs(60); 48 if diff > min { 49 diff = min; 50 } 51 52 let percent = (diff.as_secs() as f32) / 60.0; 53 let new_tokens = (percent * self.posts_per_minute as f32).floor() as i32; 54 entry.tokens += new_tokens - 1; 55 56 if entry.tokens <= 0 { 57 entry.tokens = 0; 58 } 59 60 if entry.tokens >= self.posts_per_minute { 61 entry.tokens = self.posts_per_minute - 1; 62 } 63 64 if entry.tokens == 0 { 65 return OutputMessage::new( 66 msg.event.id.clone(), 67 Action::Reject, 68 Some("rate-limited: you are noting too much".to_string()), 69 ); 70 } 71 72 entry.last_post = now; 73 OutputMessage::new(msg.event.id.clone(), Action::Accept, None) 74 } 75 }