main.rs (10368B)
1 use noteguard::filters::{Content, Kinds, ProtectedEvents, RateLimit, Whitelist}; 2 3 #[cfg(feature = "forwarder")] 4 use noteguard::filters::Forwarder; 5 6 use log::info; 7 use noteguard::{Action, InputMessage, NoteFilter, OutputMessage}; 8 use serde::de::DeserializeOwned; 9 use serde::Deserialize; 10 use std::collections::HashMap; 11 use std::io::{self, Read}; 12 13 #[derive(Deserialize)] 14 struct Config { 15 pipeline: Vec<String>, 16 filters: HashMap<String, toml::Value>, 17 } 18 19 type ConstructFilter = Box<fn(toml::Value) -> Result<Box<dyn NoteFilter>, toml::de::Error>>; 20 21 #[derive(Default)] 22 struct Noteguard { 23 registered_filters: HashMap<String, ConstructFilter>, 24 loaded_filters: Vec<Box<dyn NoteFilter>>, 25 } 26 27 impl Noteguard { 28 pub fn new() -> Self { 29 let mut noteguard = Noteguard::default(); 30 noteguard.register_builtin_filters(); 31 noteguard 32 } 33 34 pub fn register_filter<F: NoteFilter + 'static + Default + DeserializeOwned>(&mut self) { 35 self.registered_filters.insert( 36 F::name(&F::default()).to_string(), 37 Box::new(|filter_config| { 38 filter_config 39 .try_into() 40 .map(|filter: F| Box::new(filter) as Box<dyn NoteFilter>) 41 }), 42 ); 43 } 44 45 /// All builtin filters are registered here, and are made available with 46 /// every new instance of [`Noteguard`] 47 fn register_builtin_filters(&mut self) { 48 self.register_filter::<RateLimit>(); 49 self.register_filter::<Whitelist>(); 50 self.register_filter::<ProtectedEvents>(); 51 self.register_filter::<Kinds>(); 52 self.register_filter::<Content>(); 53 54 #[cfg(feature = "forwarder")] 55 self.register_filter::<Forwarder>(); 56 } 57 58 /// Run the loaded filters. You must call `load_config` before calling this, otherwise 59 /// not filters will be run. 60 fn run(&mut self, input: InputMessage) -> OutputMessage { 61 let mut mout: Option<OutputMessage> = None; 62 63 let id = input.event.id.clone(); 64 for filter in &mut self.loaded_filters { 65 let out = filter.filter_note(&input); 66 match out.action { 67 Action::Accept => { 68 mout = Some(out); 69 continue; 70 } 71 Action::Reject => { 72 return out; 73 } 74 Action::ShadowReject => { 75 return out; 76 } 77 } 78 } 79 80 mout.unwrap_or_else(|| OutputMessage::new(id, Action::Accept, None)) 81 } 82 83 /// Initializes a noteguard config. If it finds any filter configurations 84 /// matching the registered filters, it loads those into our filter pipeline. 85 fn load_config(&mut self, config: &Config) -> Result<(), toml::de::Error> { 86 self.loaded_filters.clear(); 87 88 for name in &config.pipeline { 89 let config_value = config 90 .filters 91 .get(name) 92 .unwrap_or_else(|| panic!("could not find filter configuration for {}", name)); 93 94 if let Some(constructor) = self.registered_filters.get(name.as_str()) { 95 let filter = constructor(config_value.clone())?; 96 self.loaded_filters.push(filter); 97 } else { 98 panic!("Found config settings with no matching filter: {}", name); 99 } 100 } 101 102 Ok(()) 103 } 104 } 105 106 #[cfg(feature = "forwarder")] 107 #[tokio::main] 108 async fn main() { 109 noteguard(); 110 } 111 112 #[cfg(not(feature = "forwarder"))] 113 fn main() { 114 noteguard(); 115 } 116 117 fn serialize_output_message(msg: &OutputMessage) -> String { 118 serde_json::to_string(msg).expect("OutputMessage should always serialize correctly") 119 } 120 121 fn noteguard() { 122 env_logger::init(); 123 info!("running noteguard"); 124 125 let config_path = "noteguard.toml"; 126 let mut noteguard = Noteguard::new(); 127 128 let config: Config = { 129 let mut file = std::fs::File::open(config_path).expect("Failed to open config file"); 130 let mut contents = String::new(); 131 file.read_to_string(&mut contents) 132 .expect("Failed to read config file"); 133 toml::from_str(&contents).expect("Failed to parse config file") 134 }; 135 136 noteguard 137 .load_config(&config) 138 .expect("Expected filter config to be loaded ok"); 139 140 let stdin = io::stdin(); 141 142 for line in stdin.lines() { 143 let line = match line { 144 Ok(line) => line, 145 Err(e) => { 146 eprintln!("Failed to get line: {}", e); 147 continue; 148 } 149 }; 150 151 let input_message: InputMessage = match serde_json::from_str(&line) { 152 Ok(msg) => msg, 153 Err(e) => { 154 eprintln!("Failed to parse input: {}", e); 155 continue; 156 } 157 }; 158 159 if input_message.message_type != "new" { 160 let out = OutputMessage::new( 161 input_message.event.id.clone(), 162 Action::Reject, 163 Some("invalid strfry write policy input".to_string()), 164 ); 165 println!("{}", serialize_output_message(&out)); 166 continue; 167 } 168 169 let out = noteguard.run(input_message); 170 let json = serialize_output_message(&out); 171 172 println!("{}", json); 173 } 174 } 175 176 #[cfg(test)] 177 mod tests { 178 use super::*; 179 use noteguard::filters::{Kinds, ProtectedEvents, RateLimit, Whitelist}; 180 use noteguard::{Action, Note}; 181 use serde_json::json; 182 183 // Helper function to create a mock InputMessage 184 fn create_mock_input_message(event_id: &str, message_type: &str) -> InputMessage { 185 InputMessage { 186 message_type: message_type.to_string(), 187 event: Note { 188 id: event_id.to_string(), 189 pubkey: "mock_pubkey".to_string(), 190 created_at: 0, 191 kind: 1, 192 tags: vec![vec!["-".to_string()]], 193 content: "mock_content".to_string(), 194 sig: "mock_signature".to_string(), 195 }, 196 received_at: 0, 197 source_type: "mock_source".to_string(), 198 source_info: "mock_source_info".to_string(), 199 } 200 } 201 202 // Helper function to create a mock OutputMessage 203 fn create_mock_output_message( 204 event_id: &str, 205 action: Action, 206 msg: Option<&str>, 207 ) -> OutputMessage { 208 OutputMessage { 209 id: event_id.to_string(), 210 action, 211 msg: msg.map(|s| s.to_string()), 212 } 213 } 214 215 #[test] 216 fn test_register_builtin_filters() { 217 let noteguard = Noteguard::new(); 218 assert!(noteguard.registered_filters.contains_key("ratelimit")); 219 assert!(noteguard.registered_filters.contains_key("whitelist")); 220 assert!(noteguard 221 .registered_filters 222 .contains_key("protected_events")); 223 assert!(noteguard.registered_filters.contains_key("kinds")); 224 } 225 226 #[test] 227 fn test_load_config() { 228 let mut noteguard = Noteguard::new(); 229 230 // Create a mock config with one filter (RateLimit) 231 let config: Config = toml::from_str( 232 r#" 233 pipeline = ["ratelimit"] 234 235 [filters.ratelimit] 236 posts_per_minute = 3 237 "#, 238 ) 239 .expect("Failed to parse config"); 240 241 assert!(noteguard.load_config(&config).is_ok()); 242 assert_eq!(noteguard.loaded_filters.len(), 1); 243 } 244 245 #[test] 246 fn test_run_filters_accept() { 247 let mut noteguard = Noteguard::new(); 248 249 // Create a mock config with one filter (RateLimit) 250 let config: Config = toml::from_str( 251 r#" 252 pipeline = ["ratelimit"] 253 254 [filters.ratelimit] 255 posts_per_minute = 3 256 "#, 257 ) 258 .expect("Failed to parse config"); 259 260 noteguard 261 .load_config(&config) 262 .expect("Failed to load config"); 263 264 let input_message = create_mock_input_message("test_event_1", "new"); 265 let output_message = noteguard.run(input_message); 266 267 assert_eq!(output_message.action, Action::Accept); 268 } 269 270 #[test] 271 fn test_run_filters_shadow_reject() { 272 let mut noteguard = Noteguard::new(); 273 274 // Create a mock config with one filter (ProtectedEvents) which will shadow reject the input 275 let config: Config = toml::from_str( 276 r#" 277 pipeline = ["protected_events"] 278 279 [filters.protected_events] 280 "#, 281 ) 282 .expect("Failed to parse config"); 283 284 noteguard 285 .load_config(&config) 286 .expect("Failed to load config"); 287 288 let input_message = create_mock_input_message("test_event_3", "new"); 289 let output_message = noteguard.run(input_message); 290 291 assert_eq!(output_message.action, Action::Reject); 292 } 293 294 #[test] 295 fn test_whitelist_reject() { 296 let mut noteguard = Noteguard::new(); 297 298 // Create a mock config with one filter (Whitelist) which will reject the input 299 let config: Config = toml::from_str( 300 r#" 301 pipeline = ["whitelist"] 302 [filters.whitelist] 303 pubkeys = ["something"] 304 "#, 305 ) 306 .expect("Failed to parse config"); 307 308 noteguard 309 .load_config(&config) 310 .expect("Failed to load config"); 311 312 let input_message = create_mock_input_message("test_event_2", "new"); 313 let output_message = noteguard.run(input_message); 314 315 assert_eq!(output_message.action, Action::Reject); 316 } 317 318 #[test] 319 fn test_deserialize_input_message() { 320 let input_json = r#" 321 { 322 "type": "new", 323 "event": { 324 "id": "test_event_5", 325 "pubkey": "mock_pubkey", 326 "created_at": 0, 327 "kind": 1, 328 "tags": [], 329 "content": "mock_content", 330 "sig": "mock_signature" 331 }, 332 "receivedAt": 0, 333 "sourceType": "mock_source", 334 "sourceInfo": "mock_source_info" 335 } 336 "#; 337 338 let input_message: InputMessage = 339 serde_json::from_str(input_json).expect("Failed to deserialize input message"); 340 assert_eq!(input_message.event.id, "test_event_5"); 341 assert_eq!(input_message.message_type, "new"); 342 } 343 }