commit 6df138a2c634f5dbdb5a131add00bb1bb5118756
parent 5e14ec7e2308622fd73a57c02f2daf9c37a888df
Author: William Casarin <jb55@jb55.com>
Date: Mon, 9 Dec 2024 16:39:37 -0800
async: adding efficient, poll-based stream support
This is a much more efficient, polling-based stream implementation that
doesn't rely on horrible things like spawning threads just to do async.
Changelog-Added: Add async stream support
Signed-off-by: William Casarin <jb55@jb55.com>
Diffstat:
6 files changed, 189 insertions(+), 52 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
@@ -21,6 +21,7 @@ bindgen = []
flatbuffers = "23.5.26"
libc = "0.2.151"
thiserror = "2.0.7"
+futures = "0.3.31"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
diff --git a/src/config.rs b/src/config.rs
@@ -3,8 +3,6 @@ use crate::bindings;
#[derive(Copy, Clone)]
pub struct Config {
pub config: bindings::ndb_config,
- // We add a flag to know if we've installed a Rust closure so we can clean it up in Drop.
- is_rust_closure: bool,
}
impl Default for Config {
@@ -29,11 +27,7 @@ impl Config {
bindings::ndb_default_config(&mut config);
}
- let is_rust_closure = false;
- Config {
- config,
- is_rust_closure,
- }
+ Config { config }
}
//
@@ -54,7 +48,8 @@ impl Config {
self
}
- /// Set a callback for when we have
+ /// Set a callback to be notified on updated subscriptions. The function
+ /// will be called with the corresponding subscription id.
pub fn set_sub_callback<F>(mut self, closure: F) -> Self
where
F: FnMut(u64) + 'static,
@@ -67,7 +62,6 @@ impl Config {
self.config.sub_cb = Some(sub_callback_trampoline);
self.config.sub_cb_ctx = ctx_ptr;
- self.is_rust_closure = true;
self
}
diff --git a/src/future.rs b/src/future.rs
@@ -0,0 +1,87 @@
+use crate::{Ndb, NoteKey, Subscription};
+
+use std::{
+ pin::Pin,
+ task::{Context, Poll},
+};
+
+use futures::Stream;
+
+/// Used to track query futures
+#[derive(Debug, Clone)]
+pub(crate) struct SubscriptionState {
+ pub ready: bool,
+ pub done: bool,
+ pub waker: Option<std::task::Waker>,
+}
+
+/// A subscription that you can .await on. This enables very clean
+/// integration into Rust's async state machinery.
+pub struct SubscriptionStream {
+ // Handle to the database; its shared subscription map is what this
+ // stream polls and what the subscription callback marks as ready.
+ ndb: Ndb,
+ sub_id: Subscription,
+ max_notes: u32,
+}
+
+impl SubscriptionStream {
+ pub fn new(ndb: Ndb, sub_id: Subscription) -> Self {
+ // Most of the time we only want to fetch a few things. If expecting
+ // lots of data, raise the limit with `notes_per_await`
+ let max_notes = 32;
+ SubscriptionStream {
+ ndb,
+ sub_id,
+ max_notes,
+ }
+ }
+
+ pub fn notes_per_await(mut self, max_notes: u32) -> Self {
+ self.max_notes = max_notes;
+ self
+ }
+
+ pub fn sub_id(&self) -> Subscription {
+ self.sub_id
+ }
+}
+
+impl Drop for SubscriptionStream {
+ fn drop(&mut self) {
+ // Remove this subscription's state from the map shared by all clones of the Ndb handle
+ let mut map = self.ndb.subs.lock().unwrap();
+ map.remove(&self.sub_id);
+ }
+}
+
+impl Stream for SubscriptionStream {
+ type Item = Vec<NoteKey>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ let pinned = std::pin::pin!(self);
+ let me = pinned.as_ref().get_ref();
+ let mut map = me.ndb.subs.lock().unwrap();
+ let sub_state = map.entry(me.sub_id).or_insert(SubscriptionState {
+ ready: false,
+ done: false,
+ waker: None,
+ });
+
+ // we've unsubscribed
+ if sub_state.done {
+ return Poll::Ready(None);
+ }
+
+ if sub_state.ready {
+ // Reset ready, fetch notes
+ sub_state.ready = false;
+ let notes = me.ndb.poll_for_notes(me.sub_id, me.max_notes);
+ return Poll::Ready(Some(notes));
+ }
+
+ // Not ready yet, store waker
+ sub_state.waker = Some(cx.waker().clone());
+ std::task::Poll::Pending
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
@@ -12,6 +12,9 @@ mod bindings;
mod ndb_profile;
mod block;
+
+mod future;
+
mod config;
mod error;
mod filter;
@@ -30,6 +33,8 @@ pub use block::{Block, BlockType, Blocks, Mention};
pub use config::Config;
pub use error::{Error, FilterError};
pub use filter::{Filter, FilterBuilder};
+pub(crate) use future::SubscriptionState;
+pub use future::SubscriptionStream;
pub use ndb::Ndb;
pub use ndb_profile::{NdbProfile, NdbProfileRecord};
pub use ndb_str::{NdbStr, NdbStrVariant};
diff --git a/src/ndb.rs b/src/ndb.rs
@@ -3,22 +3,20 @@ use std::ptr;
use crate::{
bindings, Blocks, Config, Error, Filter, Note, NoteKey, ProfileKey, ProfileRecord, QueryResult,
- Result, Subscription, Transaction,
+ Result, Subscription, SubscriptionState, SubscriptionStream, Transaction,
};
+use futures::StreamExt;
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
use std::fs;
use std::os::raw::c_int;
use std::path::Path;
-use std::sync::Arc;
-use tokio::task; // Make sure to import the task module
+use std::sync::{Arc, Mutex};
use tracing::debug;
#[derive(Debug)]
struct NdbRef {
ndb: *mut bindings::ndb,
-
- /// Have we configured a rust closure for our callback? If so we need
- /// to clean that up when this is dropped
- has_rust_closure: bool,
rust_cb_ctx: *mut ::std::os::raw::c_void,
}
@@ -34,7 +32,7 @@ impl Drop for NdbRef {
unsafe {
bindings::ndb_destroy(self.ndb);
- if self.has_rust_closure && !self.rust_cb_ctx.is_null() {
+ if !self.rust_cb_ctx.is_null() {
// Rebuild the Box from the raw pointer and drop it.
let _ = Box::from_raw(self.rust_cb_ctx as *mut Box<dyn FnMut()>);
}
@@ -42,10 +40,15 @@ impl Drop for NdbRef {
}
}
+type SubMap = HashMap<Subscription, SubscriptionState>;
+
/// A nostrdb context. Construct one of these with [Ndb::new].
#[derive(Debug, Clone)]
pub struct Ndb {
refs: Arc<NdbRef>,
+
+ /// Track query future states
+ pub(crate) subs: Arc<Mutex<SubMap>>,
}
impl Ndb {
@@ -65,7 +68,30 @@ impl Ndb {
let min_mapsize = 1024 * 1024 * 512;
let mut mapsize = config.config.mapsize;
- let mut config = *config;
+ let config = *config;
+
+ let prev_callback = config.config.sub_cb;
+ let prev_callback_ctx = config.config.sub_cb_ctx;
+ let subs = Arc::new(Mutex::new(SubMap::default()));
+ let subs_clone = subs.clone();
+
+ // We need to register our own callback so that we can wake
+ // query futures
+ let mut config = config.set_sub_callback(move |sub_id: u64| {
+ let mut map = subs_clone.lock().unwrap();
+ if let Some(s) = map.get_mut(&Subscription::new(sub_id)) {
+ s.ready = true;
+ if let Some(w) = s.waker.take() {
+ w.wake();
+ }
+ }
+
+ if let Some(pcb) = prev_callback {
+ unsafe {
+ pcb(prev_callback_ctx, sub_id);
+ };
+ }
+ });
let result = loop {
let result =
@@ -90,15 +116,10 @@ impl Ndb {
return Err(Error::DbOpenFailed);
}
- let has_rust_closure = !config.config.sub_cb_ctx.is_null();
let rust_cb_ctx = config.config.sub_cb_ctx;
- let refs = Arc::new(NdbRef {
- ndb,
- has_rust_closure,
- rust_cb_ctx,
- });
+ let refs = Arc::new(NdbRef { ndb, rust_cb_ctx });
- Ok(Ndb { refs })
+ Ok(Ndb { refs, subs })
}
/// Ingest a relay-sent event in the form `["EVENT","subid", {"id:"...}]`
@@ -155,9 +176,17 @@ impl Ndb {
unsafe { bindings::ndb_num_subscriptions(self.as_ptr()) as u32 }
}
- pub fn unsubscribe(&self, sub: Subscription) -> Result<()> {
+ pub fn unsubscribe(&mut self, sub: Subscription) -> Result<()> {
let r = unsafe { bindings::ndb_unsubscribe(self.as_ptr(), sub.id()) };
+ // mark the subscription as done if it exists in our stream map
+ {
+ let mut map = self.subs.lock().unwrap();
+ if let Entry::Occupied(mut entry) = map.entry(sub) {
+ entry.get_mut().done = true;
+ }
+ }
+
if r == 0 {
Err(Error::SubscriptionError)
} else {
@@ -204,32 +233,11 @@ impl Ndb {
sub_id: Subscription,
max_notes: u32,
) -> Result<Vec<NoteKey>> {
- let ndb = self.clone();
- let handle = task::spawn_blocking(move || {
- let mut vec: Vec<u64> = vec![];
- vec.reserve_exact(max_notes as usize);
- let res = unsafe {
- bindings::ndb_wait_for_notes(
- ndb.as_ptr(),
- sub_id.id(),
- vec.as_mut_ptr(),
- max_notes as c_int,
- )
- };
- if res == 0 {
- Err(Error::SubscriptionError)
- } else {
- unsafe {
- vec.set_len(res as usize);
- };
- Ok(vec)
- }
- });
+ let mut stream = SubscriptionStream::new(self.clone(), sub_id).notes_per_await(max_notes);
- match handle.await {
- Ok(Ok(res)) => Ok(res.into_iter().map(NoteKey::new).collect()),
- Ok(Err(err)) => Err(err),
- Err(_) => Err(Error::SubscriptionError),
+ match stream.next().await {
+ Some(res) => Ok(res),
+ None => Err(Error::SubscriptionError),
}
}
@@ -527,4 +535,40 @@ mod tests {
// we should definitely clean this up... especially on windows
test_util::cleanup_db(&db);
}
+
+ #[tokio::test]
+ async fn test_stream() {
+ let db = "target/testdbs/test_callback";
+ test_util::cleanup_db(&db);
+
+ {
+ let mut ndb = Ndb::new(db, &Config::new()).expect("ndb");
+ let sub_id = {
+ let filter = Filter::new().kinds(vec![1]).build();
+ let filters = vec![filter];
+
+ let sub_id = ndb.subscribe(&filters).expect("sub_id");
+ let mut sub = sub_id.stream(&ndb).notes_per_await(1);
+
+ let res = sub.next();
+
+ ndb.process_event(r#"["EVENT","b",{"id": "702555e52e82cc24ad517ba78c21879f6e47a7c0692b9b20df147916ae8731a3","pubkey": "32bf915904bfde2d136ba45dde32c88f4aca863783999faea2e847a8fafd2f15","created_at": 1702675561,"kind": 1,"tags": [],"content": "hello, world","sig": "2275c5f5417abfd644b7bc74f0388d70feb5d08b6f90fa18655dda5c95d013bfbc5258ea77c05b7e40e0ee51d8a2efa931dc7a0ec1db4c0a94519762c6625675"}]"#).expect("process ok");
+
+ let res = res.await.expect("await ok");
+ assert_eq!(res, vec![NoteKey::new(1)]);
+
+ // ensure that unsubscribing kills the stream
+ assert!(ndb.unsubscribe(sub_id).is_ok());
+ assert!(sub.next().await.is_none());
+
+ assert!(ndb.subs.lock().unwrap().contains_key(&sub_id));
+ sub_id
+ };
+
+ // ensure subscription state is removed after stream is dropped
+ assert!(!ndb.subs.lock().unwrap().contains_key(&sub_id));
+ }
+
+ test_util::cleanup_db(&db);
+ }
}
diff --git a/src/subscription.rs b/src/subscription.rs
@@ -1,4 +1,6 @@
-#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+use crate::{Ndb, SubscriptionStream};
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub struct Subscription(u64);
impl Subscription {
@@ -8,4 +10,8 @@ impl Subscription {
pub fn id(self) -> u64 {
self.0
}
+
+ pub fn stream(&self, ndb: &Ndb) -> SubscriptionStream {
+ SubscriptionStream::new(ndb.clone(), *self)
+ }
}