nostrdb-rs

nostrdb in rust!
git clone git://jb55.com/nostrdb-rs
Log | Files | Refs | Submodules | README | LICENSE

commit 0a3edf49f48c7151b9eedba44d777171ad334a3c
parent 42866db225207cb120ef3507da719b69f99649e4
Author: William Casarin <jb55@jb55.com>
Date:   Tue, 13 Aug 2024 11:28:11 -0700

Add filter iteration

It's ergonomic, zero copy, and typesafe!

Signed-off-by: William Casarin <jb55@jb55.com>

Diffstat:
Msrc/filter.rs | 527+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 527 insertions(+), 0 deletions(-)

diff --git a/src/filter.rs b/src/filter.rs @@ -80,6 +80,30 @@ impl Filter { } } + pub fn field(&self, index: i32) -> Option<FilterField<'_>> { + let ptr = unsafe { bindings::ndb_filter_get_elements(self.as_ptr(), index) }; + + if ptr.is_null() { + return None; + } + + Some(FilterElements::new(self, ptr).field()) + } + + pub fn elements(&self, index: i32) -> Option<FilterElements<'_>> { + let ptr = unsafe { bindings::ndb_filter_get_elements(self.as_ptr(), index) }; + + if ptr.is_null() { + return None; + } + + Some(FilterElements::new(self, ptr)) + } + + pub fn num_elements(&self) -> i32 { + unsafe { &*(self.as_ptr()) }.num_elements + } + pub fn as_ptr(&self) -> *const bindings::ndb_filter { self.data.as_ptr() } @@ -307,3 +331,506 @@ impl Drop for Filter { unsafe { bindings::ndb_filter_destroy(self.as_mut_ptr()) }; } } + +#[derive(Debug, Copy, Clone)] +pub struct FilterIter<'a> { + filter: &'a Filter, + index: i32, +} + +/// Filter element: `authors`, `limit`, etc +#[derive(Copy, Clone, Debug)] +pub struct FilterElements<'a> { + filter: &'a Filter, + elements: *const bindings::ndb_filter_elements, +} + +#[derive(Copy, Clone, Debug)] +pub struct FilterIdElements<'a> { + filter: &'a Filter, + elements: *const bindings::ndb_filter_elements, +} + +#[derive(Copy, Clone, Debug)] +pub struct FilterIntElements<'a> { + _filter: &'a Filter, + elements: *const bindings::ndb_filter_elements, +} + +pub struct FilterIdElemIter<'a> { + ids: FilterIdElements<'a>, + index: i32, +} + +pub struct FilterIntElemIter<'a> { + ints: FilterIntElements<'a>, + index: i32, +} + +impl<'a> FilterIdElemIter<'a> { + pub(crate) fn new(ids: FilterIdElements<'a>) -> Self { + let index = 0; + Self { ids, index } + } + + pub fn done(&self) -> bool { + self.index >= self.ids.count() + } +} + +impl<'a> FilterIntElemIter<'a> { + pub(crate) fn new(ints: FilterIntElements<'a>) -> Self { + let index = 0; + Self { ints, index } + } + + pub fn done(&self) -> bool { + self.index >= self.ints.count() + } +} + +impl<'a> FilterIdElements<'a> { + pub(crate) fn new(filter: &'a Filter, elements: *const bindings::ndb_filter_elements) -> Self { + Self { filter, elements } + } + + pub fn count(&self) -> i32 { + unsafe { &*self.elements }.count + } + + /// Field element type. In the case of ids, it would be FieldElemType::Id, etc + fn elemtype(&self) -> FieldElemType { + FieldElemType::new(unsafe { &*self.elements }.field.elem_type) + .expect("expected valid filter element type") + } + + pub fn get(self, index: i32) -> Option<&'a [u8; 32]> { + assert!(self.elemtype() == FieldElemType::Id); + + let id = unsafe { + bindings::ndb_filter_get_id_element(self.filter.as_ptr(), self.elements, index) + as *const [u8; 32] + }; + + if id.is_null() { + return None; + } + + Some(unsafe { &*id }) + } +} + +impl<'a> FilterIntElements<'a> { + pub(crate) fn new(filter: &'a Filter, elements: *const bindings::ndb_filter_elements) -> Self { + Self { + _filter: filter, + elements, + } + } + + pub fn count(&self) -> i32 { + unsafe { &*self.elements }.count + } + + /// Field element type. In the case of ids, it would be FieldElemType::Id, etc + fn elemtype(&self) -> FieldElemType { + FieldElemType::new(unsafe { &*self.elements }.field.elem_type) + .expect("expected valid filter element type") + } + + pub fn get(self, index: i32) -> Option<u64> { + if index >= self.count() { + return None; + } + assert!(self.elemtype() == FieldElemType::Int); + Some(unsafe { bindings::ndb_filter_get_int_element(self.elements, index) }) + } +} + +pub enum FilterField<'a> { + Ids(FilterIdElements<'a>), + Authors(FilterIdElements<'a>), + Kinds(FilterIntElements<'a>), + Tags(char, FilterElements<'a>), + Since(u64), + Until(u64), + Limit(u64), +} + +impl<'a> FilterField<'a> { + pub fn new(elements: FilterElements<'a>) -> Self { + match elements.fieldtype() { + FilterFieldType::Ids => { + FilterField::Ids(FilterIdElements::new(elements.filter(), elements.as_ptr())) + } + + FilterFieldType::Authors => { + FilterField::Authors(FilterIdElements::new(elements.filter(), elements.as_ptr())) + } + + FilterFieldType::Kinds => { + FilterField::Kinds(FilterIntElements::new(elements.filter(), elements.as_ptr())) + } + + FilterFieldType::Tags => FilterField::Tags(elements.tag(), elements), + + FilterFieldType::Since => FilterField::Since( + FilterIntElements::new(elements.filter(), elements.as_ptr()) + .into_iter() + .next() + .expect("expected since in filter"), + ), + + FilterFieldType::Until => FilterField::Until( + FilterIntElements::new(elements.filter(), elements.as_ptr()) + .into_iter() + .next() + .expect("expected until in filter"), + ), + + FilterFieldType::Limit => FilterField::Limit( + FilterIntElements::new(elements.filter(), elements.as_ptr()) + .into_iter() + .next() + .expect("expected limit in filter"), + ), + } + } +} + +impl<'a> FilterElements<'a> { + pub(crate) fn new(filter: &'a Filter, elements: *const bindings::ndb_filter_elements) -> Self { + FilterElements { filter, elements } + } + + pub fn filter(self) -> &'a Filter { + self.filter + } + + pub fn as_ptr(self) -> *const bindings::ndb_filter_elements { + self.elements + } + + pub fn count(&self) -> i32 { + unsafe { &*self.elements }.count + } + + pub fn field(self) -> FilterField<'a> { + FilterField::new(self) + } + + pub fn get(self, index: i32) -> Option<FilterElement<'a>> { + if index >= self.count() { + return None; + } + + match self.elemtype() { + FieldElemType::Id => { + let id = unsafe { + bindings::ndb_filter_get_id_element(self.filter.as_ptr(), self.elements, index) + as *const [u8; 32] + }; + if id.is_null() { + return None; + } + Some(FilterElement::Id(unsafe { &*id })) + } + + FieldElemType::Str => { + let cstr = unsafe { + bindings::ndb_filter_get_string_element( + self.filter.as_ptr(), + self.elements, + index, + ) + }; + if cstr.is_null() { + return None; + } + let str = unsafe { + let byte_slice = + std::slice::from_raw_parts(cstr as *const u8, libc::strlen(cstr)); + std::str::from_utf8_unchecked(byte_slice) + }; + Some(FilterElement::Str(str)) + } + + FieldElemType::Int => { + let num = unsafe { bindings::ndb_filter_get_int_element(self.elements, index) }; + Some(FilterElement::Int(num)) + } + } + } + + /// Field element type. In the case of ids, it would be FieldElemType::Id, etc + pub fn elemtype(&self) -> FieldElemType { + FieldElemType::new(unsafe { &*self.elements }.field.elem_type) + .expect("expected valid filter element type") + } + + /// Field element type. In the case of ids, it would be FieldElemType::Id, etc + pub fn tag(&self) -> char { + (unsafe { &*self.elements }.field.tag as u8) as char + } + + pub fn fieldtype(self) -> FilterFieldType { + FilterFieldType::new(unsafe { &*self.elements }.field.type_) + .expect("expected valid fieldtype") + } +} + +impl<'a> FilterIter<'a> { + pub fn new(filter: &'a Filter) -> Self { + let index = 0; + FilterIter { filter, index } + } + + pub fn done(&self) -> bool { + self.index >= self.filter.num_elements() + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum FilterFieldType { + Ids, + Authors, + Kinds, + Tags, + Since, + Until, + Limit, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum FieldElemType { + Str, + Id, + Int, +} + +impl FieldElemType { + pub(crate) fn new(val: bindings::ndb_generic_element_type) -> Option<Self> { + if val == bindings::ndb_generic_element_type_NDB_ELEMENT_UNKNOWN { + None + } else if val == bindings::ndb_generic_element_type_NDB_ELEMENT_STRING { + Some(FieldElemType::Str) + } else if val == bindings::ndb_generic_element_type_NDB_ELEMENT_ID { + Some(FieldElemType::Id) + } else if val == bindings::ndb_generic_element_type_NDB_ELEMENT_INT { + Some(FieldElemType::Int) + } else { + None + } + } +} + +impl FilterFieldType { + pub(crate) fn new(val: bindings::ndb_filter_fieldtype) -> Option<Self> { + if val == bindings::ndb_filter_fieldtype_NDB_FILTER_IDS { + Some(FilterFieldType::Ids) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_AUTHORS { + Some(FilterFieldType::Authors) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_KINDS { + Some(FilterFieldType::Kinds) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_TAGS { + Some(FilterFieldType::Tags) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_SINCE { + Some(FilterFieldType::Since) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_UNTIL { + Some(FilterFieldType::Until) + } else if val == bindings::ndb_filter_fieldtype_NDB_FILTER_LIMIT { + Some(FilterFieldType::Limit) + } else { + None + } + } +} + +impl<'a> IntoIterator for &'a Filter { + type Item = FilterField<'a>; + type IntoIter = FilterIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FilterIter::new(self) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum FilterElement<'a> { + Str(&'a str), + Id(&'a [u8; 32]), + Int(u64), +} + +impl<'a> Iterator for FilterIter<'a> { + type Item = FilterField<'a>; + + fn next(&mut self) -> Option<FilterField<'a>> { + if self.done() { + return None; + } + + let ind = self.index; + self.index += 1; + + self.filter.field(ind) + } +} + +impl<'a> IntoIterator for FilterIdElements<'a> { + type Item = &'a [u8; 32]; + type IntoIter = FilterIdElemIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FilterIdElemIter::new(self) + } +} + +impl<'a> IntoIterator for FilterIntElements<'a> { + type Item = u64; + type IntoIter = FilterIntElemIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FilterIntElemIter::new(self) + } +} + +impl<'a> Iterator for FilterIntElemIter<'a> { + type Item = u64; + + fn next(&mut self) -> Option<u64> { + if self.done() { + return None; + } + + let ind = self.index; + self.index += 1; + + self.ints.get(ind) + } +} + +impl<'a> Iterator for FilterIdElemIter<'a> { + type Item = &'a [u8; 32]; + + fn next(&mut self) -> Option<&'a [u8; 32]> { + if self.done() { + return None; + } + + let ind = self.index; + self.index += 1; + + self.ids.get(ind) + } +} + +impl<'a> IntoIterator for FilterElements<'a> { + type Item = FilterElement<'a>; + type IntoIter = FilterElemIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FilterElemIter::new(self) + } +} + +impl<'a> Iterator for FilterElemIter<'a> { + type Item = FilterElement<'a>; + + fn next(&mut self) -> Option<FilterElement<'a>> { + let element = self.elements.get(self.index); + if element.is_some() { + self.index += 1; + element + } else { + None + } + } +} + +#[derive(Copy, Clone, Debug)] +pub struct FilterElemIter<'a> { + elements: FilterElements<'a>, + index: i32, +} + +impl<'a> FilterElemIter<'a> { + pub(crate) fn new(elements: FilterElements<'a>) -> Self { + let index = 0; + FilterElemIter { elements, index } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn filter_limit_iter_works() { + let filter = Filter::new().limit(42).build(); + let mut hit = 0; + for element in &filter { + if let FilterField::Limit(42) = element { + hit += 1; + } + } + assert!(hit == 1); + } + + #[test] + fn filter_id_iter_works() { + let id: [u8; 32] = [ + 0xfb, 0x16, 0x5b, 0xe2, 0x2c, 0x7b, 0x25, 0x18, 0xb7, 0x49, 0xaa, 0xbb, 0x71, 0x40, + 0xc7, 0x3f, 0x08, 0x87, 0xfe, 0x84, 0x47, 0x5c, 0x82, 0x78, 0x57, 0x00, 0x66, 0x3b, + 0xe8, 0x5b, 0xa8, 0x59, + ]; + + let filter = Filter::new().ids(vec![id, id, id]).build(); + let mut hit = 0; + for element in &filter { + if let FilterField::Ids(ids) = element { + for same_id in ids { + hit += 1; + assert!(same_id == &id); + } + } + } + assert!(hit == 3); + } + + #[test] + fn filter_int_iter_works() { + let filter = Filter::new().kinds(vec![1, 2, 3]).build(); + let mut hit = 0; + for element in &filter { + if let FilterField::Kinds(ks) = element { + hit += 1; + assert!(vec![1, 2, 3] == ks.into_iter().collect::<Vec<u64>>()); + } + } + assert!(hit == 1); + } + + #[test] + fn filter_multiple_field_iter_works() { + let id: [u8; 32] = [ + 0xfb, 0x16, 0x5b, 0xe2, 0x2c, 0x7b, 0x25, 0x18, 0xb7, 0x49, 0xaa, 0xbb, 0x71, 0x40, + 0xc7, 0x3f, 0x08, 0x87, 0xfe, 0x84, 0x47, 0x5c, 0x82, 0x78, 0x57, 0x00, 0x66, 0x3b, + 0xe8, 0x5b, 0xa8, 0x59, + ]; + let filter = Filter::new().event(&id).kinds(vec![1, 2, 3]).build(); + let mut hit = 0; + for element in &filter { + if let FilterField::Kinds(ks) = element { + hit += 1; + assert!(vec![1, 2, 3] == ks.into_iter().collect::<Vec<u64>>()); + } else if let FilterField::Tags('e', ids) = element { + for i in ids { + hit += 1; + assert!(i == FilterElement::Id(&id)); + } + } + } + assert!(hit == 2); + } +}