damus

nostr ios client
git clone git://jb55.com/damus
Log | Files | Refs | README | LICENSE

commit d541153e4cb398a06279388b0c298c438e3908b1
parent 53fc1b694551333619a48b40d3d4dc439282d07c
Author: William Casarin <jb55@jb55.com>
Date:   Mon, 27 Nov 2023 16:08:42 -0800

nostrdb/Add fulltext search index

Signed-off-by: William Casarin <jb55@jb55.com>

Diffstat:
Mnostrdb/nostrdb.c | 567++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
Mnostrdb/nostrdb.h | 7+++++++
2 files changed, 555 insertions(+), 19 deletions(-)

diff --git a/nostrdb/nostrdb.c b/nostrdb/nostrdb.c @@ -44,6 +44,8 @@ static const int DEFAULT_QUEUE_SIZE = 1000000; #define NDB_PARSED_ALL (NDB_PARSED_ID|NDB_PARSED_PUBKEY|NDB_PARSED_SIG|NDB_PARSED_CREATED_AT|NDB_PARSED_KIND|NDB_PARSED_CONTENT|NDB_PARSED_TAGS) typedef int (*ndb_migrate_fn)(struct ndb *); +typedef int (*ndb_word_parser_fn)(void *, const char *word, int word_len, + int word_index); struct ndb_migration { ndb_migrate_fn fn; @@ -133,6 +135,156 @@ struct ndb_u64_tsid { uint64_t timestamp; }; +// uncompressed form of the actual lmdb key +struct ndb_text_search_key +{ + int str_len; + const char *str; + int word_index; + uint64_t timestamp; +}; + +// ndb_text_search_key +// +// This is compressed when in lmdb: +// +// strlen: varint +// str: cstr +// timestamp: varint +// word_index: varint +static int ndb_make_text_search_key(unsigned char *buf, int bufsize, + int word_index, int word_len, const char *str, + uint64_t timestamp, int *keysize) +{ + struct cursor cur; + int size, pad; + make_cursor(buf, buf + bufsize, &cur); + + // string length + if (!push_varint(&cur, word_len)) + return 0; + + // non-null terminated string + if (!cursor_push(&cur, (unsigned char*)str, word_len)) + return 0; + + // the index of the word in the content so that we can do more accurate + // phrase searches + if (!push_varint(&cur, word_index)) + return 0; + + // TODO: need update this to uint64_t + if (!push_varint(&cur, (int)timestamp)) + return 0; + + size = cur.p - cur.start; + + // pad to 8-byte alignment + pad = ((size + 7) & ~7) - size; + if (pad > 0) { + if (!cursor_memset(&cur, 0, pad)) { + return 0; + } + } + + *keysize = cur.p - cur.start; + assert((*keysize % 8) == 0); + + return 1; +} + +static int ndb_make_text_search_key_low(unsigned char *buf, int bufsize, + int wordlen, const char *word, + int *keysize) +{ + return ndb_make_text_search_key(buf, bufsize, 0, wordlen, word, 0, keysize); +} + +/** From LMDB: Compare two items lexically */ +static int mdb_cmp_memn(const MDB_val *a, const MDB_val *b) { + int diff; + ssize_t len_diff; + unsigned int len; + + len = a->mv_size; + len_diff = (ssize_t) a->mv_size - (ssize_t) b->mv_size; + if (len_diff > 0) { + len = b->mv_size; + len_diff = 1; + } + + diff = memcmp(a->mv_data, b->mv_data, len); + return diff ? diff : len_diff<0 ? -1 : len_diff; +} + +static int ndb_text_search_key_compare(const MDB_val *a, const MDB_val *b) +{ + struct cursor ca, cb; + int sa, sb; + MDB_val a2, b2; + + make_cursor(a->mv_data, a->mv_data + a->mv_size, &ca); + make_cursor(b->mv_data, b->mv_data + b->mv_size, &cb); + + // string size + if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb))) + return 0; + + a2.mv_data = ca.p; + a2.mv_size = sa; + + b2.mv_data = cb.p; + b2.mv_size = sb; + + int cmp = mdb_cmp_memn(&a2, &b2); + if (cmp) return cmp; + + // skip over string + ca.p += sa; + cb.p += sb; + + // timestamp + if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb))) + return 0; + + if (sa < sb) return -1; + else if (sa > sb) return 1; + + // word index + if (unlikely(!pull_varint(&ca, &sa) || !pull_varint(&cb, &sb))) + return 0; + + if (sa < sb) return -1; + else if (sa > sb) return 1; + + return 0; +} + +/* +static int ndb_decompress_text_search_key(unsigned char *p, int len, + struct ndb_text_search_key *key) +{ + struct cursor c; + + make_cursor(p, p + len, &c); + + if (!pull_varint(&c, &key->str_len)) + return 0; + + key->str = cur->p; + + if (!cursor_skip(&c, key->str_len)) + return 0; + + if (!pull_varint(&c, &key->word_index)) + return 0; + + if (!pull_varint(&c, &key->timestamp)) + return 0; + +} +*/ + // Copies only lowercase characters to the destination string and fills the rest with null bytes. // `dst` and `src` are pointers to the destination and source strings, respectively. // `n` is the maximum number of characters to copy. @@ -742,23 +894,6 @@ int ndb_db_version(struct ndb *ndb) return version; } -/** From LMDB: Compare two items lexically */ -static int mdb_cmp_memn(const MDB_val *a, const MDB_val *b) { - int diff; - ssize_t len_diff; - unsigned int len; - - len = a->mv_size; - len_diff = (ssize_t) a->mv_size - (ssize_t) b->mv_size; - if (len_diff > 0) { - len = b->mv_size; - len_diff = 1; - } - - diff = memcmp(a->mv_data, b->mv_data, len); - return diff ? diff : len_diff<0 ? -1 : len_diff; -} - // custom kind+timestamp comparison function. This is used by lmdb to perform // b+ tree searches over the kind+timestamp index static int ndb_u64_tsid_compare(const MDB_val *a, const MDB_val *b) @@ -814,10 +949,10 @@ static inline void ndb_tsid_init(struct ndb_tsid *key, unsigned char *id, key->timestamp = timestamp; } -static inline void ndb_u64_tsid_init(struct ndb_tsid *key, uint64_t integer, +static inline void ndb_u64_tsid_init(struct ndb_u64_tsid *key, uint64_t integer, uint64_t timestamp) { - key->integer = integer; + key->u64 = integer; key->timestamp = timestamp; } @@ -1877,6 +2012,388 @@ static int ndb_write_note_kind_index(struct ndb_txn *txn, struct ndb_note *note, return 1; } +/** + * Checks if a given Unicode code point is a punctuation character + * + * @param codepoint The Unicode code point to check. @return true if the + * code point is a punctuation character, false otherwise. + */ +static inline int is_punctuation(unsigned int codepoint) { + // Check for underscore (underscore is not treated as punctuation) + if (codepoint == '_') + return 0; + + // Check for ASCII punctuation + if (ispunct(codepoint)) + return 1; + + // Check for Unicode punctuation exceptions (punctuation allowed in hashtags) + if (codepoint == 0x301C || codepoint == 0xFF5E) // Japanese Wave Dash / Tilde + return 0; + + // Check for Unicode punctuation + // NOTE: We may need to adjust the codepoint ranges in the future, + // to include/exclude certain types of Unicode characters in hashtags. + // Unicode Blocks Reference: https://www.compart.com/en/unicode/block + return ( + // Latin-1 Supplement No-Break Space (NBSP): U+00A0 + (codepoint == 0x00A0) || + + // Latin-1 Supplement Punctuation: U+00A1 to U+00BF + (codepoint >= 0x00A1 && codepoint <= 0x00BF) || + + // General Punctuation: U+2000 to U+206F + (codepoint >= 0x2000 && codepoint <= 0x206F) || + + // Currency Symbols: U+20A0 to U+20CF + (codepoint >= 0x20A0 && codepoint <= 0x20CF) || + + // Supplemental Punctuation: U+2E00 to U+2E7F + (codepoint >= 0x2E00 && codepoint <= 0x2E7F) || + + // CJK Symbols and Punctuation: U+3000 to U+303F + (codepoint >= 0x3000 && codepoint <= 0x303F) || + + // Ideographic Description Characters: U+2FF0 to U+2FFF + (codepoint >= 0x2FF0 && codepoint <= 0x2FFF) + ); +} + +static inline int is_whitespace(char c) { + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +} + +static inline int is_right_boundary(int c) { + return is_whitespace(c) || is_punctuation(c); +} + +static inline int parse_byte(struct cursor *cursor, unsigned char *c) +{ + if (unlikely(cursor->p >= cursor->end)) + return 0; + + *c = *cursor->p; + + return 1; +} + +static inline int peek_char(struct cursor *cur, int ind) { + if ((cur->p + ind < cur->start) || (cur->p + ind >= cur->end)) + return -1; + + return *(cur->p + ind); +} + +static int parse_utf8_char(struct cursor *cursor, unsigned int *code_point, + unsigned int *utf8_length) +{ + unsigned char first_byte; + if (!parse_byte(cursor, &first_byte)) + return 0; // Not enough data + + // Determine the number of bytes in this UTF-8 character + int remaining_bytes = 0; + if (first_byte < 0x80) { + *code_point = first_byte; + return 1; + } else if ((first_byte & 0xE0) == 0xC0) { + remaining_bytes = 1; + *utf8_length = remaining_bytes + 1; + *code_point = first_byte & 0x1F; + } else if ((first_byte & 0xF0) == 0xE0) { + remaining_bytes = 2; + *utf8_length = remaining_bytes + 1; + *code_point = first_byte & 0x0F; + } else if ((first_byte & 0xF8) == 0xF0) { + remaining_bytes = 3; + *utf8_length = remaining_bytes + 1; + *code_point = first_byte & 0x07; + } else { + remaining_bytes = 0; + *utf8_length = 1; // Assume 1 byte length for unrecognized UTF-8 characters + // TODO: We need to gracefully handle unrecognized UTF-8 characters + //printf("Invalid UTF-8 byte: %x\n", *code_point); + *code_point = ((first_byte & 0xF0) << 6); // Prevent testing as punctuation + return 0; // Invalid first byte + } + + // Peek at remaining bytes + for (int i = 0; i < remaining_bytes; ++i) { + signed char next_byte; + if ((next_byte = peek_char(cursor, i+1)) == -1) { + *utf8_length = 1; + return 0; // Not enough data + } + + if ((next_byte & 0xC0) != 0x80) { + *utf8_length = 1; + return 0; // Invalid byte in sequence + } + + *code_point = (*code_point << 6) | (next_byte & 0x3F); + } + + return 1; +} + + +static inline int is_utf8_byte(unsigned char c) { + return c & 0x80; +} + +static inline int consume_until_boundary(struct cursor *cur) { + unsigned int c; + unsigned int char_length = 1; + unsigned int *utf8_char_length = &char_length; + + while (cur->p < cur->end) { + c = *cur->p; + *utf8_char_length = 1; + + if (is_whitespace(c)) + return 1; + + // Need to check for UTF-8 characters, which can be multiple + // bytes long + if (is_utf8_byte(c)) { + if (!parse_utf8_char(cur, &c, utf8_char_length)) { + if (!is_right_boundary(c)){ + // TODO: We should work towards + // handling all UTF-8 characters. + //printf("Invalid UTF-8 code point: %x\n", c); + } + } + } + + if (is_right_boundary(c)) + return 1; + + // Need to use a variable character byte length for UTF-8 (2-4 bytes) + if (cur->p + *utf8_char_length <= cur->end) + cur->p += *utf8_char_length; + else + cur->p++; + } + + return 1; +} + +static void consume_whitespace_or_punctuation(struct cursor *cur) +{ + while (cur->p < cur->end) { + if (!is_right_boundary(*cur->p)) + return; + cur->p++; + } +} + +static int ndb_write_word_to_index(struct ndb_txn *txn, const char *word, + int word_len, int word_index, + uint64_t timestamp, uint64_t note_id) +{ + // cap to some reasonable key size + unsigned char buffer[1024]; + int keysize, rc; + MDB_val k, v; + MDB_dbi text_db; + + // build our compressed text index key + if (!ndb_make_text_search_key(buffer, sizeof(buffer), word_index, + word_len, word, timestamp, &keysize)) { + // probably too big + + return 0; + } + + k.mv_data = buffer; + k.mv_size = keysize; + + v.mv_data = &note_id; + v.mv_size = sizeof(note_id); + + text_db = txn->lmdb->dbs[NDB_DB_NOTE_TEXT]; + + if ((rc = mdb_put(txn->mdb_txn, text_db, &k, &v, 0))) { + ndb_debug("write note text index to db failed: %s\n", + mdb_strerror(rc)); + return 0; + } + + return 1; +} + + + +static int ndb_parse_words(struct cursor *cur, void *ctx, ndb_word_parser_fn fn) +{ + int word_len, words; + const char *word; + + words = 0; + + while (cur->p < cur->end) { + consume_whitespace_or_punctuation(cur); + if (cur->p >= cur->end) + break; + word = (const char *)cur->p; + + if (!consume_until_boundary(cur)) + break; + + // start of word or end + word_len = cur->p - (unsigned char *)word; + if (word_len == 0 && cur->p >= cur->end) + break; + + if (!fn(ctx, word, word_len, words)) + continue; + + words++; + } + + return 1; +} + +struct ndb_word_writer_ctx +{ + struct ndb_txn *txn; + struct ndb_note *note; + uint64_t note_id; +}; + +static int ndb_fulltext_word_writer(void *ctx, + const char *word, int word_len, int words) +{ + struct ndb_word_writer_ctx *wctx = ctx; + + if (!ndb_write_word_to_index(wctx->txn, word, word_len, words, + wctx->note->created_at, wctx->note_id)) { + // too big to write this one, just skip it + ndb_debug(stderr, "failed to write word '%.*s' to index\n", word_len, word); + + return 0; + } + + //fprintf(stderr, "wrote '%.*s' to note text index\n", word_len, word); + return 1; +} + +static int ndb_write_note_fulltext_index(struct ndb_txn *txn, + struct ndb_note *note, + uint64_t note_id) +{ + struct cursor cur; + unsigned char *content; + struct ndb_str str; + struct ndb_word_writer_ctx ctx; + + str = ndb_note_str(note, &note->content); + // I don't think this should happen? + if (unlikely(str.flag == NDB_PACKED_ID)) + return 0; + + content = (unsigned char *)str.str; + + make_cursor(content, content + note->content_length, &cur); + + ctx.txn = txn; + ctx.note = note; + ctx.note_id = note_id; + + ndb_parse_words(&cur, &ctx, ndb_fulltext_word_writer); + + return 1; +} + +struct ndb_word +{ + const char *word; + int word_len; +}; + +#define MAX_SEARCH_WORDS 16 + +struct ndb_search_words +{ + struct ndb_word words[MAX_SEARCH_WORDS]; + int num_words; +}; + +static int ndb_parse_search_words(void *ctx, const char *word_str, int word_len, int word_index) +{ + struct ndb_search_words *words = ctx; + struct ndb_word *word; + + if (words->num_words + 1 > MAX_SEARCH_WORDS) + return 0; + + word = &words->words[words->num_words++]; + word->word = word_str; + word->word_len = word_len; + + return 1; +} + +int ndb_text_search(struct ndb_txn *txn, const char *query) +{ + unsigned char buffer[1024]; + struct ndb_search_words words; + struct ndb_word *word; + struct cursor cur; + MDB_dbi text_db; + MDB_cursor *cursor; + MDB_val k, v; + int i, rc, keysize; + size_t len; + //uint64_t note_ids[32], note_id; + uint64_t note_id; + struct ndb_note *note; + //int num_note_ids; + + //num_note_ids = 0; + text_db = txn->lmdb->dbs[NDB_DB_NOTE_TEXT]; + make_cursor((unsigned char *)query, (unsigned char *)query + strlen(query), &cur); + words.num_words = 0; + + ndb_parse_words(&cur, &words, ndb_parse_search_words); + + if ((rc = mdb_cursor_open(txn->mdb_txn, text_db, &cursor))) { + fprintf(stderr, "nd_text_search: mdb_cursor_open failed, error %d\n", rc); + return 0; + } + + for (i = 0; i < words.num_words; i++) { + word = &words.words[i]; + fprintf(stderr, "search word %.*s\n", word->word_len, word->word); + + if (!ndb_make_text_search_key_low(buffer, sizeof(buffer), + word->word_len, word->word, + &keysize)) { + // word is too big to fit in 1024-sized key + continue; + } + + k.mv_data = buffer; + k.mv_size = keysize; + + // Position cursor at the next key greater than or equal to the specified key + if (mdb_cursor_get(cursor, &k, &v, MDB_SET_RANGE)) { + continue; + } else { + //note_ids[num_note_ids++] = *((uint64_t*)v.mv_data); + note_id = *((uint64_t*)v.mv_data); + if ((note = ndb_get_note_by_key(txn, note_id, &len))) { + fprintf(stderr, "found note: '%s' for query word '%.*s'\n", + ndb_note_str(note, &note->content).str, + word->word_len, word->word); + } + return 1; + } + } + + return 1; +} + static uint64_t ndb_write_note(struct ndb_txn *txn, struct ndb_writer_note *note) { @@ -1910,6 +2427,12 @@ static uint64_t ndb_write_note(struct ndb_txn *txn, if (!ndb_write_note_kind_index(txn, note->note, note_key)) return 0; + // only do fulltext index on kind1 notes + if (note->note->kind == 1) { + if (!ndb_write_note_fulltext_index(txn, note->note, note_key)) + return 0; + } + if (note->note->kind == 7) { ndb_write_reaction_stats(txn, note->note); } @@ -2282,6 +2805,12 @@ static int ndb_init_lmdb(const char *filename, struct ndb_lmdb *lmdb, size_t map } mdb_set_compare(txn, lmdb->dbs[NDB_DB_NOTE_KIND], ndb_u64_tsid_compare); + if ((rc = mdb_dbi_open(txn, "note_text", tsid_flags, &lmdb->dbs[NDB_DB_NOTE_TEXT]))) { + fprintf(stderr, "mdb_dbi_open id failed: %s\n", mdb_strerror(rc)); + return 0; + } + mdb_set_compare(txn, lmdb->dbs[NDB_DB_NOTE_TEXT], ndb_text_search_key_compare); + // Commit the transaction if ((rc = mdb_txn_commit(txn))) { fprintf(stderr, "mdb_txn_commit failed, error %d\n", rc); diff --git a/nostrdb/nostrdb.h b/nostrdb/nostrdb.h @@ -42,6 +42,7 @@ enum ndb_dbs { NDB_DB_PROFILE_SEARCH, NDB_DB_PROFILE_LAST_FETCH, NDB_DB_NOTE_KIND, // note kind index + NDB_DB_NOTE_TEXT, // note fulltext index NDB_DBS, }; @@ -327,6 +328,10 @@ void ndb_filter_reset(struct ndb_filter *); void ndb_filter_end_field(struct ndb_filter *); void ndb_filter_free(struct ndb_filter *filter); + +// FULLTEXT SEARCH +int ndb_text_search(struct ndb_txn *, const char *query); + // stats int ndb_stat(struct ndb *ndb, struct ndb_stat *stat); void ndb_stat_counts_init(struct ndb_stat_counts *counts); @@ -528,6 +533,8 @@ ndb_db_name(enum ndb_dbs db) return "profile_last_fetch"; case NDB_DB_NOTE_KIND: return "note_kind_index"; + case NDB_DB_NOTE_TEXT: + return "note_fulltext"; case NDB_DBS: return "count"; }