nostr-rs-relay

My dev fork of nostr-rs-relay
git clone git://jb55.com/nostr-rs-relay
Log | Files | Refs | README | LICENSE

commit 92e9a5e639847f475b41df96607afb5e69617ec7
parent d0c2b242cd523707993bd4559633f3786a1e5335
Author: Greg Heartsfield <scsibug@imap.cc>
Date:   Sun,  5 Dec 2021 16:53:26 -0600

feat: parse and validate events from websockets

Establishes a websocket listener, parses events, and performs
validation to ensure valid signatures.

Diffstat:
Asrc/conn.rs | 22++++++++++++++++++++++
Asrc/error.rs | 39+++++++++++++++++++++++++++++++++++++++
Asrc/event.rs | 240+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Asrc/lib.rs | 4++++
Msrc/main.rs | 104+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
Asrc/protostream.rs | 99+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
6 files changed, 506 insertions(+), 2 deletions(-)

diff --git a/src/conn.rs b/src/conn.rs @@ -0,0 +1,22 @@ +//use std::collections::HashMap; +use uuid::Uuid; + +// state for a client connection +pub struct ClientConn { + _client_id: Uuid, + // current set of subscriptions + //subscriptions: HashMap<String, Subscription>, + // websocket + //stream: WebSocketStream<TcpStream>, + _max_subs: usize, +} + +impl ClientConn { + pub fn new() -> Self { + let client_id = Uuid::new_v4(); + ClientConn { + _client_id: client_id, + _max_subs: 128, + } + } +} diff --git a/src/error.rs b/src/error.rs @@ -0,0 +1,39 @@ +//! Error handling. + +use std::result; +use thiserror::Error; +use tungstenite::error::Error as WsError; + +pub type Result<T, E = Error> = result::Result<T, E>; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Protocol parse error")] + ProtoParseError, + #[error("Connection error")] + ConnError, + #[error("Client write error")] + ConnWriteError, + #[error("Event parse failed")] + EventParseFailed, + #[error("Event validation failed")] + EventInvalid, + #[error("JSON parsing failed")] + JsonParseFailed(serde_json::Error), + #[error("WebSocket proto error")] + WebsocketError(WsError), + #[error("Command unknown")] + CommandUnknownError, +} + +impl From<serde_json::Error> for Error { + fn from(r: serde_json::Error) -> Self { + Error::JsonParseFailed(r) + } +} + +impl From<WsError> for Error { + fn from(r: WsError) -> Self { + Error::WebsocketError(r) + } +} diff --git a/src/event.rs b/src/event.rs @@ -0,0 +1,240 @@ +use crate::error::Error::*; +use crate::error::Result; +use bitcoin_hashes::{sha256, Hash}; +use log::info; +use secp256k1::{schnorrsig, Secp256k1}; +use serde::{Deserialize, Deserializer, Serialize}; +use serde_json::value::Value; +use serde_json::Number; +use std::str::FromStr; + +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct EventCmd { + cmd: String, // expecting static "EVENT" + event: Event, +} + +#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] +pub struct Event { + pub(crate) id: String, + pub(crate) pubkey: String, + pub(crate) created_at: u64, + pub(crate) kind: u64, + #[serde(deserialize_with = "tag_from_string")] + // TODO: array-of-arrays may need to be more general than a string container + pub(crate) tags: Vec<Vec<String>>, + pub(crate) content: String, + pub(crate) sig: String, +} + +type Tag = Vec<Vec<String>>; + +// handle a default value (empty vec) for null tags +fn tag_from_string<'de, D>(deserializer: D) -> Result<Tag, D::Error> +where + D: Deserializer<'de>, +{ + let opt = Option::deserialize(deserializer)?; + Ok(opt.unwrap_or_else(|| vec![])) +} + +impl From<EventCmd> for Result<Event> { + fn from(ec: EventCmd) -> Result<Event> { + // ensure command is correct + if ec.cmd != "EVENT" { + return Err(CommandUnknownError); + } else if ec.event.is_valid() { + return Ok(ec.event); + } else { + return Err(EventInvalid); + } + } +} + +impl Event { + // check if this event is valid (should be propagated, stored) based on signature. + fn is_valid(&self) -> bool { + // validation is performed by: + // * parsing JSON string into event fields + // * create an array: + // ** [0, pubkey-hex-string, created-at-num, kind-num, tags-array-of-arrays, content-string] + // * serialize with no spaces/newlines + let c_opt = self.to_canonical(); + if c_opt.is_none() { + info!("event could not be canonicalized"); + return false; + } + let c = c_opt.unwrap(); + // * compute the sha256sum. + let digest: sha256::Hash = sha256::Hash::hash(&c.as_bytes()); + let hex_digest = format!("{:x}", digest); + // * ensure the id matches the computed sha256sum. + if self.id != hex_digest { + return false; + } + // * validate the message digest (sig) using the pubkey & computed sha256 message hash. + let secp = Secp256k1::new(); + let sig = schnorrsig::Signature::from_str(&self.sig).unwrap(); + let message = secp256k1::Message::from(digest); + let pubkey = schnorrsig::PublicKey::from_str(&self.pubkey).unwrap(); + let verify = secp.schnorrsig_verify(&sig, &message, &pubkey); + match verify { + Ok(()) => { + info!("verified event"); + true + } + _ => false, + } + } + + // convert event to canonical representation for signing + fn to_canonical(&self) -> Option<String> { + // create a JsonValue for each event element + let mut c: Vec<Value> = vec![]; + // id must be set to 0 + let id = Number::from(0 as u64); + c.push(serde_json::Value::Number(id)); + // public key + c.push(Value::String(self.pubkey.to_owned())); + // creation time + let created_at = Number::from(self.created_at); + c.push(serde_json::Value::Number(created_at)); + // kind + let kind = Number::from(self.kind); + c.push(serde_json::Value::Number(kind)); + // tags + c.push(self.tags_to_canonical()); + // content + c.push(Value::String(self.content.to_owned())); + serde_json::to_string(&Value::Array(c)).ok() + } + fn tags_to_canonical(&self) -> Value { + let mut tags = Vec::<Value>::new(); + // iterate over self tags, + for t in self.tags.iter() { + // each tag is a vec of strings + let mut a = Vec::<Value>::new(); + for v in t.iter() { + a.push(serde_json::Value::String(v.to_owned())); + } + tags.push(serde_json::Value::Array(a)); + } + serde_json::Value::Array(tags) + } + + // check if given event is referenced in a tag + pub fn event_tag_match(&self, eventid: &str) -> bool { + for t in self.tags.iter() { + if t.len() == 2 { + if t.get(0).unwrap() == "#e" { + if t.get(1).unwrap() == eventid { + return true; + } + } + } + } + return false; + } +} + +#[cfg(test)] +mod tests { + use super::*; + fn simple_event() -> Event { + Event { + id: "0".to_owned(), + pubkey: "0".to_owned(), + created_at: 0, + kind: 0, + tags: vec![], + content: "".to_owned(), + sig: "0".to_owned(), + } + } + + #[test] + fn event_creation() { + // create an event + let event = simple_event(); + assert_eq!(event.id, "0"); + } + + #[test] + fn event_serialize() -> Result<()> { + // serialize an event to JSON string + let event = simple_event(); + let j = serde_json::to_string(&event)?; + assert_eq!(j, "{\"id\":\"0\",\"pubkey\":\"0\",\"created_at\":0,\"kind\":0,\"tags\":[],\"content\":\"\",\"sig\":\"0\"}"); + Ok(()) + } + + #[test] + fn event_tags_serialize() -> Result<()> { + // serialize an event with tags to JSON string + let mut event = simple_event(); + event.tags = vec![ + vec![ + "e".to_owned(), + "xxxx".to_owned(), + "wss://example.com".to_owned(), + ], + vec![ + "p".to_owned(), + "yyyyy".to_owned(), + "wss://example.com:3033".to_owned(), + ], + ]; + let j = serde_json::to_string(&event)?; + assert_eq!(j, "{\"id\":\"0\",\"pubkey\":\"0\",\"created_at\":0,\"kind\":0,\"tags\":[[\"e\",\"xxxx\",\"wss://example.com\"],[\"p\",\"yyyyy\",\"wss://example.com:3033\"]],\"content\":\"\",\"sig\":\"0\"}"); + Ok(()) + } + + #[test] + fn event_deserialize() -> Result<()> { + let raw_json = r#"{"id":"1384757da583e6129ce831c3d7afc775a33a090578f888dd0d010328ad047d0c","pubkey":"bbbd9711d357df4f4e498841fd796535c95c8e751fa35355008a911c41265fca","created_at":1612650459,"kind":1,"tags":null,"content":"hello world","sig":"59d0cc47ab566e81f72fe5f430bcfb9b3c688cb0093d1e6daa49201c00d28ecc3651468b7938642869ed98c0f1b262998e49a05a6ed056c0d92b193f4e93bc21"}"#; + let e: Event = serde_json::from_str(raw_json)?; + assert_eq!(e.kind, 1); + assert_eq!(e.tags.len(), 0); + Ok(()) + } + + #[test] + fn event_canonical() { + let e = Event { + id: "999".to_owned(), + pubkey: "012345".to_owned(), + created_at: 501234, + kind: 1, + tags: vec![], + content: "this is a test".to_owned(), + sig: "abcde".to_owned(), + }; + let c = e.to_canonical(); + let expected = Some(r#"[0,"012345",501234,1,[],"this is a test"]"#.to_owned()); + assert_eq!(c, expected); + } + + #[test] + fn event_canonical_with_tags() { + let e = Event { + id: "999".to_owned(), + pubkey: "012345".to_owned(), + created_at: 501234, + kind: 1, + tags: vec![ + vec!["#e".to_owned(), "aoeu".to_owned()], + vec![ + "#p".to_owned(), + "aaaa".to_owned(), + "ws://example.com".to_owned(), + ], + ], + content: "this is a test".to_owned(), + sig: "abcde".to_owned(), + }; + let c = e.to_canonical(); + let expected_json = r###"[0,"012345",501234,1,[["#e","aoeu"],["#p","aaaa","ws://example.com"]],"this is a test"]"###; + let expected = Some(expected_json.to_owned()); + assert_eq!(c, expected); + } +} diff --git a/src/lib.rs b/src/lib.rs @@ -0,0 +1,4 @@ +pub mod conn; +pub mod error; +pub mod event; +pub mod protostream; diff --git a/src/main.rs b/src/main.rs @@ -1,3 +1,103 @@ -fn main() { - println!("Hello, world!"); +use futures::StreamExt; +use log::*; +use nostr_rs_relay::conn; +use nostr_rs_relay::error::{Error, Result}; +use nostr_rs_relay::event::Event; +use nostr_rs_relay::protostream; +use nostr_rs_relay::protostream::NostrMessage::*; +use rusqlite::Result as SQLResult; +use std::env; +use tokio::net::{TcpListener, TcpStream}; +use tokio::runtime::Builder; +use tokio::sync::broadcast; +use tokio::sync::broadcast::Sender; +use tokio::sync::mpsc; + +/// Start running a Nostr relay server. +fn main() -> Result<(), Error> { + // setup logger + let _ = env_logger::try_init(); + let addr = env::args() + .nth(1) + .unwrap_or_else(|| "0.0.0.0:8888".to_string()); + // configure tokio runtime + let rt = Builder::new_multi_thread() + .enable_all() + .thread_name("tokio-ws") + .build() + .unwrap(); + // start tokio + rt.block_on(async { + let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); + info!("Listening on: {}", addr); + // Establish global broadcast channel. This is where all + // accepted events will be distributed for other connected clients. + let (bcast_tx, _) = broadcast::channel::<Event>(64); + // Establish database writer channel. This needs to be + // accessible from sync code, which is why the broadcast + // cannot be reused. + let (event_tx, _) = mpsc::channel::<Event>(64); + // start the database writer. + // TODO: manage program termination, to close the DB. + //let _db_handle = db_writer(event_rx).await; + while let Ok((stream, _)) = listener.accept().await { + tokio::spawn(nostr_server(stream, bcast_tx.clone(), event_tx.clone())); + } + }); + Ok(()) +} + +async fn _db_writer(_event_rx: tokio::sync::mpsc::Receiver<Event>) -> SQLResult<()> { + unimplemented!(); +} + +async fn nostr_server( + stream: TcpStream, + broadcast: Sender<Event>, + _event_tx: tokio::sync::mpsc::Sender<Event>, +) { + // get a broadcast channel for clients to communicate on + // wrap the TCP stream in a websocket. + let mut _bcast_rx = broadcast.subscribe(); + let conn = tokio_tungstenite::accept_async(stream).await; + let ws_stream = conn.expect("websocket handshake error"); + // a stream & sink of Nostr protocol messages + let mut nostr_stream = protostream::wrap_ws_in_nostr(ws_stream); + //let task_queue = mpsc::channel::<NostrMessage>(16); + // track connection state so we can break when it fails + // Track internal client state + let _conn = conn::ClientConn::new(); + let mut conn_good = true; + loop { + tokio::select! { + proto_next = nostr_stream.next() => { + match proto_next { + Some(Ok(EventMsg(e))) => { + // handle each type of message + let _x : Result<Event> = Result::<Event>::from(e); + }, + Some(Ok(SubMsg)) => {}, + Some(Ok(CloseMsg)) => {}, + None => { + info!("stream ended"); + //conn_good = true; + }, + Some(Err(Error::ConnError)) => { + info!("got connection error, disconnecting"); + conn_good = false; + if conn_good { + info!("Lint bug?, https://github.com/rust-lang/rust/pull/57302"); + } + return + } + Some(Err(e)) => { + info!("got error, continuing: {:?}", e); + }, + } + } + } + if conn_good == false { + break; + } + } } diff --git a/src/protostream.rs b/src/protostream.rs @@ -0,0 +1,99 @@ +use crate::error::{Error, Result}; +use crate::event::EventCmd; +use core::pin::Pin; +use futures::sink::Sink; +use futures::stream::Stream; +use futures::task::Context; +use futures::task::Poll; +use log::*; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpStream; +use tokio_tungstenite::WebSocketStream; +use tungstenite::error::Error as WsError; +use tungstenite::protocol::Message; + +// A Nostr message is either event, subscription, or close. +#[derive(Deserialize, Serialize, Clone, PartialEq, Debug)] +#[serde(untagged)] +pub enum NostrMessage { + EventMsg(EventCmd), + SubMsg, + CloseMsg, +} + +// Either an event w/ subscription, or a notice +#[derive(Deserialize, Serialize, Clone, PartialEq, Debug)] +enum NostrResponse { + Notice(String), +} + +// A Nostr protocol stream is layered on top of a Websocket stream. +pub struct NostrStream { + ws_stream: WebSocketStream<TcpStream>, +} + +// given a websocket, return a protocol stream +//impl Stream<Item = Result<BasicMessage, BasicError>> + Sink<BasicResponse> +pub fn wrap_ws_in_nostr(ws: WebSocketStream<TcpStream>) -> NostrStream { + return NostrStream { ws_stream: ws }; +} + +impl Stream for NostrStream { + type Item = Result<NostrMessage>; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + // convert Message to NostrMessage + fn convert(msg: String) -> Result<NostrMessage> { + debug!("Input message: {}", &msg); + let parsed_res: Result<NostrMessage> = serde_json::from_str(&msg).map_err(|e| e.into()); + match parsed_res { + Ok(m) => Ok(m), + Err(e) => { + debug!("Proto parse error: {:?}", e); + Err(Error::ProtoParseError) + } + } + } + + match Pin::new(&mut self.ws_stream).poll_next(cx) { + Poll::Pending => Poll::Pending, // not ready + Poll::Ready(None) => Poll::Ready(None), // done + Poll::Ready(Some(v)) => match v { + Ok(Message::Text(vs)) => Poll::Ready(Some(convert(vs))), // convert message->basicmessage + Ok(Message::Binary(_)) => Poll::Ready(Some(Err(Error::ProtoParseError))), + Ok(Message::Pong(_)) | Ok(Message::Ping(_)) => Poll::Pending, + Ok(Message::Close(_)) => Poll::Ready(None), + Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => Poll::Ready(None), // done + Err(_) => Poll::Ready(Some(Err(Error::ConnError))), + }, + } + } +} + +impl Sink<NostrResponse> for NostrStream { + type Error = Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + // map the error type + match Pin::new(&mut self.ws_stream).poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(_)) => Poll::Ready(Err(Error::ConnWriteError)), + Poll::Pending => Poll::Pending, + } + } + + fn start_send(mut self: Pin<&mut Self>, item: NostrResponse) -> Result<(), Self::Error> { + let res_message = serde_json::to_string(&item).expect("Could convert message to string"); + match Pin::new(&mut self.ws_stream).start_send(Message::Text(res_message)) { + Ok(()) => Ok(()), + Err(_) => Err(Error::ConnWriteError), + } + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + Poll::Ready(Ok(())) + } +}