noteguard

The Nostr relay spam guardian.
git clone git://jb55.com/noteguard
Log | Files | Refs | README | LICENSE

main.rs (10368B)


      1 use noteguard::filters::{Content, Kinds, ProtectedEvents, RateLimit, Whitelist};
      2 
      3 #[cfg(feature = "forwarder")]
      4 use noteguard::filters::Forwarder;
      5 
      6 use log::info;
      7 use noteguard::{Action, InputMessage, NoteFilter, OutputMessage};
      8 use serde::de::DeserializeOwned;
      9 use serde::Deserialize;
     10 use std::collections::HashMap;
     11 use std::io::{self, Read};
     12 
     13 #[derive(Deserialize)]
     14 struct Config {
     15     pipeline: Vec<String>,
     16     filters: HashMap<String, toml::Value>,
     17 }
     18 
     19 type ConstructFilter = Box<fn(toml::Value) -> Result<Box<dyn NoteFilter>, toml::de::Error>>;
     20 
     21 #[derive(Default)]
     22 struct Noteguard {
     23     registered_filters: HashMap<String, ConstructFilter>,
     24     loaded_filters: Vec<Box<dyn NoteFilter>>,
     25 }
     26 
     27 impl Noteguard {
     28     pub fn new() -> Self {
     29         let mut noteguard = Noteguard::default();
     30         noteguard.register_builtin_filters();
     31         noteguard
     32     }
     33 
     34     pub fn register_filter<F: NoteFilter + 'static + Default + DeserializeOwned>(&mut self) {
     35         self.registered_filters.insert(
     36             F::name(&F::default()).to_string(),
     37             Box::new(|filter_config| {
     38                 filter_config
     39                     .try_into()
     40                     .map(|filter: F| Box::new(filter) as Box<dyn NoteFilter>)
     41             }),
     42         );
     43     }
     44 
     45     /// All builtin filters are registered here, and are made available with
     46     /// every new instance of [`Noteguard`]
     47     fn register_builtin_filters(&mut self) {
     48         self.register_filter::<RateLimit>();
     49         self.register_filter::<Whitelist>();
     50         self.register_filter::<ProtectedEvents>();
     51         self.register_filter::<Kinds>();
     52         self.register_filter::<Content>();
     53 
     54         #[cfg(feature = "forwarder")]
     55         self.register_filter::<Forwarder>();
     56     }
     57 
     58     /// Run the loaded filters. You must call `load_config` before calling this, otherwise
     59     /// not filters will be run.
     60     fn run(&mut self, input: InputMessage) -> OutputMessage {
     61         let mut mout: Option<OutputMessage> = None;
     62 
     63         let id = input.event.id.clone();
     64         for filter in &mut self.loaded_filters {
     65             let out = filter.filter_note(&input);
     66             match out.action {
     67                 Action::Accept => {
     68                     mout = Some(out);
     69                     continue;
     70                 }
     71                 Action::Reject => {
     72                     return out;
     73                 }
     74                 Action::ShadowReject => {
     75                     return out;
     76                 }
     77             }
     78         }
     79 
     80         mout.unwrap_or_else(|| OutputMessage::new(id, Action::Accept, None))
     81     }
     82 
     83     /// Initializes a noteguard config. If it finds any filter configurations
     84     /// matching the registered filters, it loads those into our filter pipeline.
     85     fn load_config(&mut self, config: &Config) -> Result<(), toml::de::Error> {
     86         self.loaded_filters.clear();
     87 
     88         for name in &config.pipeline {
     89             let config_value = config
     90                 .filters
     91                 .get(name)
     92                 .unwrap_or_else(|| panic!("could not find filter configuration for {}", name));
     93 
     94             if let Some(constructor) = self.registered_filters.get(name.as_str()) {
     95                 let filter = constructor(config_value.clone())?;
     96                 self.loaded_filters.push(filter);
     97             } else {
     98                 panic!("Found config settings with no matching filter: {}", name);
     99             }
    100         }
    101 
    102         Ok(())
    103     }
    104 }
    105 
    106 #[cfg(feature = "forwarder")]
    107 #[tokio::main]
    108 async fn main() {
    109     noteguard();
    110 }
    111 
    112 #[cfg(not(feature = "forwarder"))]
    113 fn main() {
    114     noteguard();
    115 }
    116 
    117 fn serialize_output_message(msg: &OutputMessage) -> String {
    118     serde_json::to_string(msg).expect("OutputMessage should always serialize correctly")
    119 }
    120 
    121 fn noteguard() {
    122     env_logger::init();
    123     info!("running noteguard");
    124 
    125     let config_path = "noteguard.toml";
    126     let mut noteguard = Noteguard::new();
    127 
    128     let config: Config = {
    129         let mut file = std::fs::File::open(config_path).expect("Failed to open config file");
    130         let mut contents = String::new();
    131         file.read_to_string(&mut contents)
    132             .expect("Failed to read config file");
    133         toml::from_str(&contents).expect("Failed to parse config file")
    134     };
    135 
    136     noteguard
    137         .load_config(&config)
    138         .expect("Expected filter config to be loaded ok");
    139 
    140     let stdin = io::stdin();
    141 
    142     for line in stdin.lines() {
    143         let line = match line {
    144             Ok(line) => line,
    145             Err(e) => {
    146                 eprintln!("Failed to get line: {}", e);
    147                 continue;
    148             }
    149         };
    150 
    151         let input_message: InputMessage = match serde_json::from_str(&line) {
    152             Ok(msg) => msg,
    153             Err(e) => {
    154                 eprintln!("Failed to parse input: {}", e);
    155                 continue;
    156             }
    157         };
    158 
    159         if input_message.message_type != "new" {
    160             let out = OutputMessage::new(
    161                 input_message.event.id.clone(),
    162                 Action::Reject,
    163                 Some("invalid strfry write policy input".to_string()),
    164             );
    165             println!("{}", serialize_output_message(&out));
    166             continue;
    167         }
    168 
    169         let out = noteguard.run(input_message);
    170         let json = serialize_output_message(&out);
    171 
    172         println!("{}", json);
    173     }
    174 }
    175 
    176 #[cfg(test)]
    177 mod tests {
    178     use super::*;
    179     use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist};
    180     use noteguard::{Action, Note};
    181     use serde_json::json;
    182 
    183     // Helper function to create a mock InputMessage
    184     fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage {
    185         InputMessage {
    186             message_type: message_type.to_string(),
    187             event: Note {
    188                 id: event_id.to_string(),
    189                 pubkey: "mock_pubkey".to_string(),
    190                 created_at: 0,
    191                 kind: 1,
    192                 tags: vec![vec!["-".to_string()]],
    193                 content: "mock_content".to_string(),
    194                 sig: "mock_signature".to_string(),
    195             },
    196             received_at: 0,
    197             source_type: "mock_source".to_string(),
    198             source_info: "mock_source_info".to_string(),
    199         }
    200     }
    201 
    202     // Helper function to create a mock OutputMessage
    203     fn create_mock_output_message(
    204         event_id: &str,
    205         action: Action,
    206         msg: Option<&str>,
    207     ) -> OutputMessage {
    208         OutputMessage {
    209             id: event_id.to_string(),
    210             action,
    211             msg: msg.map(|s| s.to_string()),
    212         }
    213     }
    214 
    215     #[test]
    216     fn test_register_builtin_filters() {
    217         let noteguard = Noteguard::new();
    218         assert!(noteguard.registered_filters.contains_key("ratelimit"));
    219         assert!(noteguard.registered_filters.contains_key("whitelist"));
    220         assert!(noteguard
    221             .registered_filters
    222             .contains_key("protected_events"));
    223         assert!(noteguard.registered_filters.contains_key("kinds"));
    224     }
    225 
    226     #[test]
    227     fn test_load_config() {
    228         let mut noteguard = Noteguard::new();
    229 
    230         // Create a mock config with one filter (RateLimit)
    231         let config: Config = toml::from_str(
    232             r#"
    233             pipeline = ["ratelimit"]
    234 
    235             [filters.ratelimit]
    236             posts_per_minute = 3
    237         "#,
    238         )
    239         .expect("Failed to parse config");
    240 
    241         assert!(noteguard.load_config(&config).is_ok());
    242         assert_eq!(noteguard.loaded_filters.len(), 1);
    243     }
    244 
    245     #[test]
    246     fn test_run_filters_accept() {
    247         let mut noteguard = Noteguard::new();
    248 
    249         // Create a mock config with one filter (RateLimit)
    250         let config: Config = toml::from_str(
    251             r#"
    252             pipeline = ["ratelimit"]
    253 
    254             [filters.ratelimit]
    255             posts_per_minute = 3
    256         "#,
    257         )
    258         .expect("Failed to parse config");
    259 
    260         noteguard
    261             .load_config(&config)
    262             .expect("Failed to load config");
    263 
    264         let input_message = create_mock_input_message("test_event_1", "new");
    265         let output_message = noteguard.run(input_message);
    266 
    267         assert_eq!(output_message.action, Action::Accept);
    268     }
    269 
    270     #[test]
    271     fn test_run_filters_shadow_reject() {
    272         let mut noteguard = Noteguard::new();
    273 
    274         // Create a mock config with one filter (ProtectedEvents) which will shadow reject the input
    275         let config: Config = toml::from_str(
    276             r#"
    277             pipeline = ["protected_events"]
    278 
    279             [filters.protected_events]
    280         "#,
    281         )
    282         .expect("Failed to parse config");
    283 
    284         noteguard
    285             .load_config(&config)
    286             .expect("Failed to load config");
    287 
    288         let input_message = create_mock_input_message("test_event_3", "new");
    289         let output_message = noteguard.run(input_message);
    290 
    291         assert_eq!(output_message.action, Action::Reject);
    292     }
    293 
    294     #[test]
    295     fn test_whitelist_reject() {
    296         let mut noteguard = Noteguard::new();
    297 
    298         // Create a mock config with one filter (Whitelist) which will reject the input
    299         let config: Config = toml::from_str(
    300             r#"
    301             pipeline = ["whitelist"]
    302             [filters.whitelist]
    303             pubkeys = ["something"]
    304         "#,
    305         )
    306         .expect("Failed to parse config");
    307 
    308         noteguard
    309             .load_config(&config)
    310             .expect("Failed to load config");
    311 
    312         let input_message = create_mock_input_message("test_event_2", "new");
    313         let output_message = noteguard.run(input_message);
    314 
    315         assert_eq!(output_message.action, Action::Reject);
    316     }
    317 
    318     #[test]
    319     fn test_deserialize_input_message() {
    320         let input_json = r#"
    321         {
    322             "type": "new",
    323             "event": {
    324                 "id": "test_event_5",
    325                 "pubkey": "mock_pubkey",
    326                 "created_at": 0,
    327                 "kind": 1,
    328                 "tags": [],
    329                 "content": "mock_content",
    330                 "sig": "mock_signature"
    331             },
    332             "receivedAt": 0,
    333             "sourceType": "mock_source",
    334             "sourceInfo": "mock_source_info"
    335         }
    336         "#;
    337 
    338         let input_message: InputMessage =
    339             serde_json::from_str(input_json).expect("Failed to deserialize input message");
    340         assert_eq!(input_message.event.id, "test_event_5");
    341         assert_eq!(input_message.message_type, "new");
    342     }
    343 }