notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

tools.rs (18685B)


      1 use async_openai::types::*;
      2 use chrono::DateTime;
      3 use enostr::{NoteId, Pubkey};
      4 use nostrdb::{Ndb, Note, NoteKey, Transaction};
      5 use serde::{Deserialize, Serialize};
      6 use serde_json::{json, Value};
      7 use std::{collections::HashMap, fmt};
      8 
      9 /// A tool
     10 #[derive(Debug, Clone, Serialize, Deserialize)]
     11 pub struct ToolCall {
     12     id: String,
     13     typ: ToolCalls,
     14 }
     15 
     16 impl ToolCall {
     17     pub fn new(id: String, typ: ToolCalls) -> Self {
     18         Self { id, typ }
     19     }
     20 
     21     pub fn id(&self) -> &str {
     22         &self.id
     23     }
     24 
     25     pub fn invalid(
     26         id: String,
     27         name: Option<String>,
     28         arguments: Option<String>,
     29         error: String,
     30     ) -> Self {
     31         Self {
     32             id,
     33             typ: ToolCalls::Invalid(InvalidToolCall {
     34                 name,
     35                 arguments,
     36                 error,
     37             }),
     38         }
     39     }
     40 
     41     pub fn calls(&self) -> &ToolCalls {
     42         &self.typ
     43     }
     44 
     45     pub fn to_api(&self) -> ChatCompletionMessageToolCall {
     46         ChatCompletionMessageToolCall {
     47             id: self.id.clone(),
     48             r#type: ChatCompletionToolType::Function,
     49             function: self.typ.to_api(),
     50         }
     51     }
     52 }
     53 
     54 /// On streaming APIs, tool calls are incremental. We use this
     55 /// to represent tool calls that are in the process of returning.
     56 /// These eventually just become [`ToolCall`]'s
     57 #[derive(Default, Debug, Clone, Serialize, Deserialize)]
     58 pub struct PartialToolCall {
     59     pub id: Option<String>,
     60     pub name: Option<String>,
     61     pub arguments: Option<String>,
     62 }
     63 
     64 impl PartialToolCall {
     65     pub fn id(&self) -> Option<&str> {
     66         self.id.as_deref()
     67     }
     68 
     69     pub fn id_mut(&mut self) -> &mut Option<String> {
     70         &mut self.id
     71     }
     72 
     73     pub fn name(&self) -> Option<&str> {
     74         self.name.as_deref()
     75     }
     76 
     77     pub fn name_mut(&mut self) -> &mut Option<String> {
     78         &mut self.name
     79     }
     80 
     81     pub fn arguments(&self) -> Option<&str> {
     82         self.arguments.as_deref()
     83     }
     84 
     85     pub fn arguments_mut(&mut self) -> &mut Option<String> {
     86         &mut self.arguments
     87     }
     88 }
     89 
/// The query response from nostrdb for a given context.
///
/// `notes` holds the nostrdb note keys of the matching notes, as produced by
/// `NoteKey::as_u64` in `QueryCall::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
    pub notes: Vec<u64>,
}

/// The result of executing a tool call, later rendered as text for the model
/// by `format_tool_response_for_ai`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses {
    /// The tool failed; rendered as `error: <message>`.
    Error(String),
    /// Results of a `query` tool call.
    Query(QueryResponse),
    /// Number of notes presented to the user.
    PresentNotes(i32),
}
    102 
    103 #[derive(Debug, Clone)]
    104 pub struct UnknownToolCall {
    105     id: String,
    106     name: String,
    107     arguments: String,
    108 }
    109 
    110 impl UnknownToolCall {
    111     pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
    112         let Some(tool) = tools.get(&self.name) else {
    113             return Err(ToolCallError::NotFound(self.name.to_owned()));
    114         };
    115 
    116         let parsed_args = (tool.parse_call)(&self.arguments)?;
    117         Ok(ToolCall {
    118             id: self.id.clone(),
    119             typ: parsed_args,
    120         })
    121     }
    122 }
    123 
    124 impl PartialToolCall {
    125     pub fn complete(&self) -> Option<UnknownToolCall> {
    126         Some(UnknownToolCall {
    127             id: self.id.clone()?,
    128             name: self.name.clone()?,
    129             arguments: self.arguments.clone()?,
    130         })
    131     }
    132 }
    133 
    134 #[derive(Debug, Clone, Serialize, Deserialize)]
    135 pub struct InvalidToolCall {
    136     pub error: String,
    137     pub name: Option<String>,
    138     pub arguments: Option<String>,
    139 }
    140 
    141 /// An enumeration of the possible tool calls that
    142 /// can be parsed from Dave responses. When adding
    143 /// new tools, this needs to be updated so that we can
    144 /// handle tool call responses.
    145 #[derive(Debug, Clone, Serialize, Deserialize)]
    146 pub enum ToolCalls {
    147     Query(QueryCall),
    148     PresentNotes(PresentNotesCall),
    149     Invalid(InvalidToolCall),
    150 }
    151 
    152 impl ToolCalls {
    153     pub fn to_api(&self) -> FunctionCall {
    154         FunctionCall {
    155             name: self.name().to_owned(),
    156             arguments: self.arguments(),
    157         }
    158     }
    159 
    160     fn name(&self) -> &'static str {
    161         match self {
    162             Self::Query(_) => "search",
    163             Self::Invalid(_) => "error",
    164             Self::PresentNotes(_) => "present",
    165         }
    166     }
    167 
    168     /// Returns the tool name as defined in the tool registry (for prompt-based tool calls)
    169     pub fn tool_name(&self) -> &'static str {
    170         match self {
    171             Self::Query(_) => "query",
    172             Self::Invalid(_) => "invalid",
    173             Self::PresentNotes(_) => "present_notes",
    174         }
    175     }
    176 
    177     pub fn arguments(&self) -> String {
    178         match self {
    179             Self::Query(search) => serde_json::to_string(search).unwrap(),
    180             Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
    181             Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
    182         }
    183     }
    184 }
    185 
    186 #[derive(Debug)]
    187 pub enum ToolCallError {
    188     EmptyName,
    189     EmptyArgs,
    190     NotFound(String),
    191     ArgParseFailure(String),
    192 }
    193 
    194 impl fmt::Display for ToolCallError {
    195     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    196         match self {
    197             ToolCallError::EmptyName => write!(f, "the tool name was empty"),
    198             ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
    199             ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"),
    200             ToolCallError::ArgParseFailure(ref msg) => {
    201                 write!(f, "failed to parse arguments: {msg}")
    202             }
    203         }
    204     }
    205 }
    206 
    207 #[derive(Debug, Clone)]
    208 enum ArgType {
    209     String,
    210     Number,
    211 
    212     #[allow(dead_code)]
    213     Enum(Vec<&'static str>),
    214 }
    215 
    216 impl ArgType {
    217     pub fn type_string(&self) -> &'static str {
    218         match self {
    219             Self::String => "string",
    220             Self::Number => "number",
    221             Self::Enum(_) => "string",
    222         }
    223     }
    224 }
    225 
    226 #[derive(Debug, Clone)]
    227 struct ToolArg {
    228     typ: ArgType,
    229     name: &'static str,
    230     required: bool,
    231     description: &'static str,
    232     default: Option<Value>,
    233 }
    234 
    235 #[derive(Debug, Clone)]
    236 pub struct Tool {
    237     parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
    238     name: &'static str,
    239     description: &'static str,
    240     arguments: Vec<ToolArg>,
    241 }
    242 
    243 impl Tool {
    244     pub fn name(&self) -> &'static str {
    245         self.name
    246     }
    247 
    248     pub fn description(&self) -> &'static str {
    249         self.description
    250     }
    251 
    252     pub fn parse_call(&self) -> fn(&str) -> Result<ToolCalls, ToolCallError> {
    253         self.parse_call
    254     }
    255 
    256     pub fn to_function_object(&self) -> FunctionObject {
    257         let required_args = self
    258             .arguments
    259             .iter()
    260             .filter_map(|arg| {
    261                 if arg.required {
    262                     Some(Value::String(arg.name.to_owned()))
    263                 } else {
    264                     None
    265                 }
    266             })
    267             .collect();
    268 
    269         let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
    270         parameters.insert("type".to_string(), Value::String("object".to_string()));
    271         parameters.insert("required".to_string(), Value::Array(required_args));
    272         parameters.insert("additionalProperties".to_string(), Value::Bool(false));
    273 
    274         let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
    275 
    276         for arg in &self.arguments {
    277             let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
    278             props.insert(
    279                 "type".to_string(),
    280                 Value::String(arg.typ.type_string().to_string()),
    281             );
    282 
    283             let description = if let Some(default) = &arg.default {
    284                 format!("{} (Default: {default}))", arg.description)
    285             } else {
    286                 arg.description.to_owned()
    287             };
    288 
    289             props.insert("description".to_string(), Value::String(description));
    290             if let ArgType::Enum(enums) = &arg.typ {
    291                 props.insert(
    292                     "enum".to_string(),
    293                     Value::Array(
    294                         enums
    295                             .iter()
    296                             .map(|s| Value::String((*s).to_owned()))
    297                             .collect(),
    298                     ),
    299                 );
    300             }
    301 
    302             properties.insert(arg.name.to_owned(), Value::Object(props));
    303         }
    304 
    305         parameters.insert("properties".to_string(), Value::Object(properties));
    306 
    307         FunctionObject {
    308             name: self.name.to_owned(),
    309             description: Some(self.description.to_owned()),
    310             strict: Some(false),
    311             parameters: Some(Value::Object(parameters)),
    312         }
    313     }
    314 
    315     pub fn to_api(&self) -> ChatCompletionTool {
    316         ChatCompletionTool {
    317             r#type: ChatCompletionToolType::Function,
    318             function: self.to_function_object(),
    319         }
    320     }
    321 }
    322 
    323 impl ToolResponses {
    324     pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String {
    325         format_tool_response_for_ai(txn, ndb, self)
    326     }
    327 }
    328 
    329 #[derive(Debug, Clone, Serialize, Deserialize)]
    330 pub struct ToolResponse {
    331     id: String,
    332     typ: ToolResponses,
    333 }
    334 
    335 impl ToolResponse {
    336     pub fn new(id: String, responses: ToolResponses) -> Self {
    337         Self { id, typ: responses }
    338     }
    339 
    340     pub fn error(id: String, msg: String) -> Self {
    341         Self {
    342             id,
    343             typ: ToolResponses::Error(msg),
    344         }
    345     }
    346 
    347     pub fn responses(&self) -> &ToolResponses {
    348         &self.typ
    349     }
    350 
    351     pub fn id(&self) -> &str {
    352         &self.id
    353     }
    354 }
    355 
    356 /// Called by dave when he wants to display notes on the screen
    357 #[derive(Debug, Deserialize, Serialize, Clone)]
    358 pub struct PresentNotesCall {
    359     pub note_ids: Vec<NoteId>,
    360 }
    361 
    362 impl PresentNotesCall {
    363     fn to_simple(&self) -> PresentNotesCallSimple {
    364         let note_ids = self
    365             .note_ids
    366             .iter()
    367             .map(|nid| hex::encode(nid.bytes()))
    368             .collect::<Vec<_>>()
    369             .join(",");
    370 
    371         PresentNotesCallSimple { note_ids }
    372     }
    373 }
    374 
    375 /// Called by dave when he wants to display notes on the screen
    376 #[derive(Debug, Deserialize, Serialize, Clone)]
    377 pub struct PresentNotesCallSimple {
    378     note_ids: String,
    379 }
    380 
    381 impl PresentNotesCall {
    382     fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    383         match serde_json::from_str::<PresentNotesCallSimple>(args) {
    384             Ok(call) => {
    385                 let note_ids = call
    386                     .note_ids
    387                     .split(",")
    388                     .filter_map(|n| NoteId::from_hex(n).ok())
    389                     .collect();
    390 
    391                 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
    392             }
    393             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    394                 "{args}, error: {e}"
    395             ))),
    396         }
    397     }
    398 }
    399 
    400 /// The parsed nostrdb query that dave wants to use to satisfy a request
    401 #[derive(Debug, Deserialize, Serialize, Clone)]
    402 pub struct QueryCall {
    403     pub author: Option<Pubkey>,
    404     pub limit: Option<u64>,
    405     pub since: Option<u64>,
    406     pub kind: Option<u64>,
    407     pub until: Option<u64>,
    408     pub search: Option<String>,
    409 }
    410 
    411 fn is_reply(note: Note) -> bool {
    412     for tag in note.tags() {
    413         if tag.count() < 4 {
    414             continue;
    415         }
    416 
    417         let Some("e") = tag.get_str(0) else {
    418             continue;
    419         };
    420 
    421         let Some(s) = tag.get_str(3) else {
    422             continue;
    423         };
    424 
    425         if s == "root" || s == "reply" {
    426             return true;
    427         }
    428     }
    429 
    430     false
    431 }
    432 
    433 impl QueryCall {
    434     pub fn to_filter(&self) -> nostrdb::Filter {
    435         let mut filter = nostrdb::Filter::new()
    436             .limit(self.limit())
    437             .custom(|n| !is_reply(n))
    438             .kinds([self.kind.unwrap_or(1)]);
    439 
    440         if let Some(author) = &self.author {
    441             filter = filter.authors([author.bytes()]);
    442         }
    443 
    444         if let Some(search) = &self.search {
    445             filter = filter.search(search);
    446         }
    447 
    448         if let Some(until) = self.until {
    449             filter = filter.until(until);
    450         }
    451 
    452         if let Some(since) = self.since {
    453             filter = filter.since(since);
    454         }
    455 
    456         filter.build()
    457     }
    458 
    459     fn limit(&self) -> u64 {
    460         self.limit.unwrap_or(10)
    461     }
    462 
    463     pub fn author(&self) -> Option<&Pubkey> {
    464         self.author.as_ref()
    465     }
    466 
    467     pub fn since(&self) -> Option<u64> {
    468         self.since
    469     }
    470 
    471     pub fn until(&self) -> Option<u64> {
    472         self.until
    473     }
    474 
    475     pub fn search(&self) -> Option<&str> {
    476         self.search.as_deref()
    477     }
    478 
    479     pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse {
    480         let notes = {
    481             if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) {
    482                 results.into_iter().map(|r| r.note_key.as_u64()).collect()
    483             } else {
    484                 vec![]
    485             }
    486         };
    487         QueryResponse { notes }
    488     }
    489 
    490     pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    491         match serde_json::from_str::<QueryCall>(args) {
    492             Ok(call) => Ok(ToolCalls::Query(call)),
    493             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    494                 "{args}, error: {e}"
    495             ))),
    496         }
    497     }
    498 }
    499 
    500 /// A simple note format for use when formatting
    501 /// tool responses
    502 #[derive(Debug, Serialize)]
    503 struct SimpleNote {
    504     note_id: String,
    505     pubkey: String,
    506     name: String,
    507     content: String,
    508     created_at: String,
    509     note_kind: u64, // todo: add replying to
    510 }
    511 
    512 /// Take the result of a tool response and present it to the ai so that
    513 /// it can interepret it and take further action
    514 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
    515     match resp {
    516         ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"),
    517         ToolResponses::Error(s) => format!("error: {}", &s),
    518 
    519         ToolResponses::Query(search_r) => {
    520             let simple_notes: Vec<SimpleNote> = search_r
    521                 .notes
    522                 .iter()
    523                 .filter_map(|nkey| {
    524                     let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else {
    525                         return None;
    526                     };
    527 
    528                     let name = ndb
    529                         .get_profile_by_pubkey(txn, note.pubkey())
    530                         .ok()
    531                         .and_then(|p| p.record().profile())
    532                         .and_then(|p| p.name().or_else(|| p.display_name()))
    533                         .unwrap_or("Anonymous")
    534                         .to_string();
    535 
    536                     let content = note.content().to_owned();
    537                     let pubkey = hex::encode(note.pubkey());
    538                     let note_kind = note.kind() as u64;
    539                     let note_id = hex::encode(note.id());
    540 
    541                     let created_at = {
    542                         let datetime =
    543                             DateTime::from_timestamp(note.created_at() as i64, 0).unwrap();
    544                         datetime.format("%Y-%m-%d %H:%M:%S").to_string()
    545                     };
    546 
    547                     Some(SimpleNote {
    548                         note_id,
    549                         pubkey,
    550                         name,
    551                         content,
    552                         created_at,
    553                         note_kind,
    554                     })
    555                 })
    556                 .collect();
    557 
    558             serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
    559         }
    560     }
    561 }
    562 
    563 fn _note_kind_desc(kind: u64) -> String {
    564     match kind {
    565         1 => "microblog".to_string(),
    566         0 => "profile".to_string(),
    567         _ => kind.to_string(),
    568     }
    569 }
    570 
    571 fn present_tool() -> Tool {
    572     Tool {
    573         name: "present_notes",
    574         parse_call: PresentNotesCall::parse,
    575         description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.",
    576         arguments: vec![
    577             ToolArg {
    578                 name: "note_ids",
    579                 description: "A comma-separated list of hex note ids",
    580                 typ: ArgType::String,
    581                 required: true,
    582                 default: None
    583             }
    584         ]
    585     }
    586 }
    587 
    588 fn query_tool() -> Tool {
    589     Tool {
    590         name: "query",
    591         parse_call: QueryCall::parse,
    592         description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.",
    593         arguments: vec![
    594             ToolArg {
    595                 name: "search",
    596                 typ: ArgType::String,
    597                 required: false,
    598                 default: None,
    599                 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc",
    600             },
    601 
    602             ToolArg {
    603                 name: "limit",
    604                 typ: ArgType::Number,
    605                 required: true,
    606                 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())),
    607                 description: "The number of results to return.",
    608             },
    609 
    610             ToolArg {
    611                 name: "since",
    612                 typ: ArgType::Number,
    613                 required: false,
    614                 default: None,
    615                 description: "Only pull notes after this unix timestamp",
    616             },
    617 
    618             ToolArg {
    619                 name: "until",
    620                 typ: ArgType::Number,
    621                 required: false,
    622                 default: None,
    623                 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).",
    624             },
    625 
    626             ToolArg {
    627                 name: "author",
    628                 typ: ArgType::String,
    629                 required: false,
    630                 default: None,
    631                 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u
    632 se, you can query for kind 0 profiles with the search argument.",
    633             },
    634 
    635             ToolArg {
    636                 name: "kind",
    637                 typ: ArgType::Number,
    638                 required: false,
    639                 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())),
    640                 description: r#"The kind of note. Kind list:
    641                 - 0: profiles
    642                 - 1: microblogs/\"tweets\"/posts
    643                 - 6: reposts of kind 1 notes
    644                 - 7: emoji reactions/likes
    645                 - 9735: zaps (bitcoin micropayment receipts)
    646                 - 30023: longform articles, blog posts, etc
    647 
    648                 "#,
    649             },
    650 
    651         ]
    652     }
    653 }
    654 
    655 pub fn dave_tools() -> Vec<Tool> {
    656     vec![query_tool(), present_tool()]
    657 }