noteguard

The Nostr relay spam guardian
git clone git://jb55.com/noteguard
Log | Files | Refs | README | LICENSE

forwarder.rs (3752B)


      1 use serde::Deserialize;
      2 use crate::{Note, Action, NoteFilter, InputMessage, OutputMessage};
      3 use futures_util::{SinkExt, StreamExt};
      4 use tokio::sync::mpsc::{self, Sender, Receiver};
      5 use tokio_tungstenite::connect_async;
      6 use tokio_tungstenite::tungstenite::Message;
      7 use tokio_tungstenite::WebSocketStream;
      8 use tokio::time::{sleep, timeout, Duration};
      9 use serde_json::json;
     10 use log::{error, info, debug};
     11 
     12 #[derive(Default, Deserialize)]
     13 pub struct Forwarder {
     14     relay: String,
     15 
     16     /// the size of our bounded queue
     17     queue_size: Option<u32>,
     18 
     19     /// The channel used for communicating with the forwarder thread
     20     #[serde(skip)]
     21     channel: Option<Sender<Note>>,
     22 }
     23 
     24 async fn client_reconnect(relay: &str) -> WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
     25     loop {
     26         match connect_async(relay).await {
     27             Err(e) => {
     28                 error!("failed to connect to relay {}: {}", relay, e);
     29                 sleep(Duration::from_secs(5)).await;
     30                 continue;
     31             }
     32             Ok((ws, _)) => {
     33                 info!("connected to relay: {}", relay);
     34                 return ws;
     35             }
     36         }
     37     }
     38 }
     39 
     40 async fn forwarder_task(relay: String, mut rx: Receiver<Note>) {
     41     let stream = client_reconnect(&relay).await;
     42     let (mut writer, mut reader) = stream.split();
     43 
     44     loop {
     45         tokio::select! {
     46             result = timeout(Duration::from_secs(10), rx.recv()) => {
     47                 match result {
     48                     Ok(Some(note)) => {
     49                         if let Err(e) = writer.send(Message::Text(serde_json::to_string(&json!(["EVENT", note])).unwrap())).await {
     50                             error!("got error: '{}', reconnecting...", e);
     51                             let (w, r) = client_reconnect(&relay).await.split();
     52                             writer = w;
     53                             reader = r;
     54                         }
     55                     },
     56                     Ok(None) => {
     57                         // Channel has been closed, exit the loop
     58                         error!("channel closed, stopping forwarder_task");
     59                         break;
     60                     }
     61                     Err(_) => {
     62                         // Timeout occurred, send a ping
     63                         // try reading for pongs, etc
     64                         let _r = reader.next();
     65                         debug!("timeout reading note queue, sending ping");
     66 
     67                         if let Err(e) = writer.send(Message::Ping(vec![])).await {
     68                             error!("error during ping ({}), reconnecting...", e);
     69                             let (w, r) = client_reconnect(&relay).await.split();
     70                             writer = w;
     71                             reader = r;
     72                         }
     73                     }
     74                 }
     75             }
     76         }
     77     }
     78 }
     79 
     80 impl NoteFilter for Forwarder {
     81     fn name(&self) -> &'static str {
     82         "forwarder"
     83     }
     84 
     85     fn filter_note(&mut self, input: &InputMessage) -> OutputMessage {
     86         if self.channel.is_none() {
     87             let (tx, rx) = mpsc::channel(self.queue_size.unwrap_or(1000) as usize);
     88             let relay = self.relay.clone();
     89 
     90             tokio::task::spawn(async move {
     91                 forwarder_task(relay, rx).await;
     92             });
     93 
     94             self.channel = Some(tx);
     95         }
     96 
     97         // Add code to process input and send through channel
     98         if let Some(ref channel) = self.channel {
     99             if let Err(e) = channel.try_send(input.event.clone()) {
    100                 eprintln!("could not forward note: {}", e);
    101             }
    102         }
    103 
    104         // Create and return an appropriate OutputMessage
    105         OutputMessage::new(input.event.id.clone(), Action::Accept, None)
    106     }
    107 }