notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

tools.rs (18090B)


      1 use async_openai::types::*;
      2 use chrono::DateTime;
      3 use enostr::{NoteId, Pubkey};
      4 use nostrdb::{Ndb, Note, NoteKey, Transaction};
      5 use serde::{Deserialize, Serialize};
      6 use serde_json::{json, Value};
      7 use std::{collections::HashMap, fmt};
      8 
      9 /// A tool
     10 #[derive(Debug, Clone, Serialize, Deserialize)]
     11 pub struct ToolCall {
     12     id: String,
     13     typ: ToolCalls,
     14 }
     15 
     16 impl ToolCall {
     17     pub fn id(&self) -> &str {
     18         &self.id
     19     }
     20 
     21     pub fn invalid(
     22         id: String,
     23         name: Option<String>,
     24         arguments: Option<String>,
     25         error: String,
     26     ) -> Self {
     27         Self {
     28             id,
     29             typ: ToolCalls::Invalid(InvalidToolCall {
     30                 name,
     31                 arguments,
     32                 error,
     33             }),
     34         }
     35     }
     36 
     37     pub fn calls(&self) -> &ToolCalls {
     38         &self.typ
     39     }
     40 
     41     pub fn to_api(&self) -> ChatCompletionMessageToolCall {
     42         ChatCompletionMessageToolCall {
     43             id: self.id.clone(),
     44             r#type: ChatCompletionToolType::Function,
     45             function: self.typ.to_api(),
     46         }
     47     }
     48 }
     49 
/// On streaming APIs, tool calls are incremental. We use this
/// to represent tool calls that are in the process of returning.
/// These eventually just become [`ToolCall`]'s
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PartialToolCall {
    /// The tool call id, if known so far.
    pub id: Option<String>,
    /// The tool name, if known so far.
    pub name: Option<String>,
    /// The raw argument string, if known so far.
    pub arguments: Option<String>,
}

impl PartialToolCall {
    /// Borrow the id, if present.
    pub fn id(&self) -> Option<&str> {
        self.id.as_deref()
    }

    /// Mutable access to the id slot, so callers can set or extend it.
    pub fn id_mut(&mut self) -> &mut Option<String> {
        &mut self.id
    }

    /// Borrow the tool name, if present.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Mutable access to the name slot, so callers can set or extend it.
    pub fn name_mut(&mut self) -> &mut Option<String> {
        &mut self.name
    }

    /// Borrow the raw argument string, if present.
    pub fn arguments(&self) -> Option<&str> {
        self.arguments.as_deref()
    }

    /// Mutable access to the arguments slot, so callers can set or extend it.
    pub fn arguments_mut(&mut self) -> &mut Option<String> {
        &mut self.arguments
    }
}
     85 
/// The query response from nostrdb for a given context
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
    /// Keys of the matching notes as raw `u64`s. `QueryCall::execute` fills
    /// this from each result's `note_key.as_u64()`.
    notes: Vec<u64>,
}
     91 
/// The outcome of running a tool call, one variant per kind of result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses {
    /// An error message; presented to the model as `error: <message>`.
    Error(String),
    /// The result of a note query.
    Query(QueryResponse),
    /// A count, presented to the model as `<n> notes presented to the user`.
    PresentNotes(i32),
}
     98 
     99 #[derive(Debug, Clone)]
    100 pub struct UnknownToolCall {
    101     id: String,
    102     name: String,
    103     arguments: String,
    104 }
    105 
    106 impl UnknownToolCall {
    107     pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
    108         let Some(tool) = tools.get(&self.name) else {
    109             return Err(ToolCallError::NotFound(self.name.to_owned()));
    110         };
    111 
    112         let parsed_args = (tool.parse_call)(&self.arguments)?;
    113         Ok(ToolCall {
    114             id: self.id.clone(),
    115             typ: parsed_args,
    116         })
    117     }
    118 }
    119 
    120 impl PartialToolCall {
    121     pub fn complete(&self) -> Option<UnknownToolCall> {
    122         Some(UnknownToolCall {
    123             id: self.id.clone()?,
    124             name: self.name.clone()?,
    125             arguments: self.arguments.clone()?,
    126         })
    127     }
    128 }
    129 
/// A tool call that could not be turned into a valid call. Keeps whatever
/// pieces of the call were available alongside the error text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidToolCall {
    /// Description of what went wrong.
    pub error: String,
    /// The tool name, if one was available.
    pub name: Option<String>,
    /// The raw argument string, if one was available.
    pub arguments: Option<String>,
}
    136 
    137 /// An enumeration of the possible tool calls that
    138 /// can be parsed from Dave responses. When adding
    139 /// new tools, this needs to be updated so that we can
    140 /// handle tool call responses.
    141 #[derive(Debug, Clone, Serialize, Deserialize)]
    142 pub enum ToolCalls {
    143     Query(QueryCall),
    144     PresentNotes(PresentNotesCall),
    145     Invalid(InvalidToolCall),
    146 }
    147 
    148 impl ToolCalls {
    149     pub fn to_api(&self) -> FunctionCall {
    150         FunctionCall {
    151             name: self.name().to_owned(),
    152             arguments: self.arguments(),
    153         }
    154     }
    155 
    156     fn name(&self) -> &'static str {
    157         match self {
    158             Self::Query(_) => "search",
    159             Self::Invalid(_) => "error",
    160             Self::PresentNotes(_) => "present",
    161         }
    162     }
    163 
    164     fn arguments(&self) -> String {
    165         match self {
    166             Self::Query(search) => serde_json::to_string(search).unwrap(),
    167             Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
    168             Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
    169         }
    170     }
    171 }
    172 
    173 #[derive(Debug)]
    174 pub enum ToolCallError {
    175     EmptyName,
    176     EmptyArgs,
    177     NotFound(String),
    178     ArgParseFailure(String),
    179 }
    180 
    181 impl fmt::Display for ToolCallError {
    182     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    183         match self {
    184             ToolCallError::EmptyName => write!(f, "the tool name was empty"),
    185             ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
    186             ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"),
    187             ToolCallError::ArgParseFailure(ref msg) => {
    188                 write!(f, "failed to parse arguments: {msg}")
    189             }
    190         }
    191     }
    192 }
    193 
    194 #[derive(Debug, Clone)]
    195 enum ArgType {
    196     String,
    197     Number,
    198 
    199     #[allow(dead_code)]
    200     Enum(Vec<&'static str>),
    201 }
    202 
    203 impl ArgType {
    204     pub fn type_string(&self) -> &'static str {
    205         match self {
    206             Self::String => "string",
    207             Self::Number => "number",
    208             Self::Enum(_) => "string",
    209         }
    210     }
    211 }
    212 
/// Description of a single argument accepted by a [`Tool`].
#[derive(Debug, Clone)]
struct ToolArg {
    /// The kind of value the argument accepts.
    typ: ArgType,
    /// The argument name; used as the property key in the generated schema.
    name: &'static str,
    /// Whether the argument is listed under the schema's `required` array.
    required: bool,
    /// Human-readable description placed in the schema.
    description: &'static str,
    /// Optional default; when set it is appended to the description text.
    default: Option<Value>,
}
    221 
    222 #[derive(Debug, Clone)]
    223 pub struct Tool {
    224     parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
    225     name: &'static str,
    226     description: &'static str,
    227     arguments: Vec<ToolArg>,
    228 }
    229 
    230 impl Tool {
    231     pub fn name(&self) -> &'static str {
    232         self.name
    233     }
    234 
    235     pub fn to_function_object(&self) -> FunctionObject {
    236         let required_args = self
    237             .arguments
    238             .iter()
    239             .filter_map(|arg| {
    240                 if arg.required {
    241                     Some(Value::String(arg.name.to_owned()))
    242                 } else {
    243                     None
    244                 }
    245             })
    246             .collect();
    247 
    248         let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
    249         parameters.insert("type".to_string(), Value::String("object".to_string()));
    250         parameters.insert("required".to_string(), Value::Array(required_args));
    251         parameters.insert("additionalProperties".to_string(), Value::Bool(false));
    252 
    253         let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
    254 
    255         for arg in &self.arguments {
    256             let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
    257             props.insert(
    258                 "type".to_string(),
    259                 Value::String(arg.typ.type_string().to_string()),
    260             );
    261 
    262             let description = if let Some(default) = &arg.default {
    263                 format!("{} (Default: {default}))", arg.description)
    264             } else {
    265                 arg.description.to_owned()
    266             };
    267 
    268             props.insert("description".to_string(), Value::String(description));
    269             if let ArgType::Enum(enums) = &arg.typ {
    270                 props.insert(
    271                     "enum".to_string(),
    272                     Value::Array(
    273                         enums
    274                             .iter()
    275                             .map(|s| Value::String((*s).to_owned()))
    276                             .collect(),
    277                     ),
    278                 );
    279             }
    280 
    281             properties.insert(arg.name.to_owned(), Value::Object(props));
    282         }
    283 
    284         parameters.insert("properties".to_string(), Value::Object(properties));
    285 
    286         FunctionObject {
    287             name: self.name.to_owned(),
    288             description: Some(self.description.to_owned()),
    289             strict: Some(false),
    290             parameters: Some(Value::Object(parameters)),
    291         }
    292     }
    293 
    294     pub fn to_api(&self) -> ChatCompletionTool {
    295         ChatCompletionTool {
    296             r#type: ChatCompletionToolType::Function,
    297             function: self.to_function_object(),
    298         }
    299     }
    300 }
    301 
impl ToolResponses {
    /// Render this response as the text handed back to the model.
    /// Delegates to `format_tool_response_for_ai`, which uses `txn` and
    /// `ndb` to look up the notes referenced by query results.
    pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String {
        format_tool_response_for_ai(txn, ndb, self)
    }
}
    307 
    308 #[derive(Debug, Clone, Serialize, Deserialize)]
    309 pub struct ToolResponse {
    310     id: String,
    311     typ: ToolResponses,
    312 }
    313 
    314 impl ToolResponse {
    315     pub fn new(id: String, responses: ToolResponses) -> Self {
    316         Self { id, typ: responses }
    317     }
    318 
    319     pub fn error(id: String, msg: String) -> Self {
    320         Self {
    321             id,
    322             typ: ToolResponses::Error(msg),
    323         }
    324     }
    325 
    326     pub fn responses(&self) -> &ToolResponses {
    327         &self.typ
    328     }
    329 
    330     pub fn id(&self) -> &str {
    331         &self.id
    332     }
    333 }
    334 
    335 /// Called by dave when he wants to display notes on the screen
    336 #[derive(Debug, Deserialize, Serialize, Clone)]
    337 pub struct PresentNotesCall {
    338     pub note_ids: Vec<NoteId>,
    339 }
    340 
    341 impl PresentNotesCall {
    342     fn to_simple(&self) -> PresentNotesCallSimple {
    343         let note_ids = self
    344             .note_ids
    345             .iter()
    346             .map(|nid| hex::encode(nid.bytes()))
    347             .collect::<Vec<_>>()
    348             .join(",");
    349 
    350         PresentNotesCallSimple { note_ids }
    351     }
    352 }
    353 
/// The argument form of a present-notes call as it is exchanged as JSON:
/// all note ids packed into one string.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PresentNotesCallSimple {
    /// Comma-separated hex note ids.
    note_ids: String,
}
    359 
    360 impl PresentNotesCall {
    361     fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    362         match serde_json::from_str::<PresentNotesCallSimple>(args) {
    363             Ok(call) => {
    364                 let note_ids = call
    365                     .note_ids
    366                     .split(",")
    367                     .filter_map(|n| NoteId::from_hex(n).ok())
    368                     .collect();
    369 
    370                 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
    371             }
    372             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    373                 "{args}, error: {e}"
    374             ))),
    375         }
    376     }
    377 }
    378 
/// The parsed nostrdb query that dave wants to use to satisfy a request
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct QueryCall {
    /// Restrict results to notes by this author.
    pub author: Option<Pubkey>,
    /// Maximum number of results; `QueryCall::limit` falls back to 10.
    pub limit: Option<u64>,
    /// Passed to the filter's `since` bound when set.
    pub since: Option<u64>,
    /// Note kind to query; `to_filter` falls back to kind 1.
    pub kind: Option<u64>,
    /// Passed to the filter's `until` bound when set.
    pub until: Option<u64>,
    /// Search text passed to the filter's `search` when set.
    pub search: Option<String>,
}
    389 
    390 fn is_reply(note: Note) -> bool {
    391     for tag in note.tags() {
    392         if tag.count() < 4 {
    393             continue;
    394         }
    395 
    396         let Some("e") = tag.get_str(0) else {
    397             continue;
    398         };
    399 
    400         let Some(s) = tag.get_str(3) else {
    401             continue;
    402         };
    403 
    404         if s == "root" || s == "reply" {
    405             return true;
    406         }
    407     }
    408 
    409     false
    410 }
    411 
    412 impl QueryCall {
    413     pub fn to_filter(&self) -> nostrdb::Filter {
    414         let mut filter = nostrdb::Filter::new()
    415             .limit(self.limit())
    416             .custom(|n| !is_reply(n))
    417             .kinds([self.kind.unwrap_or(1)]);
    418 
    419         if let Some(author) = &self.author {
    420             filter = filter.authors([author.bytes()]);
    421         }
    422 
    423         if let Some(search) = &self.search {
    424             filter = filter.search(search);
    425         }
    426 
    427         if let Some(until) = self.until {
    428             filter = filter.until(until);
    429         }
    430 
    431         if let Some(since) = self.since {
    432             filter = filter.since(since);
    433         }
    434 
    435         filter.build()
    436     }
    437 
    438     fn limit(&self) -> u64 {
    439         self.limit.unwrap_or(10)
    440     }
    441 
    442     pub fn author(&self) -> Option<&Pubkey> {
    443         self.author.as_ref()
    444     }
    445 
    446     pub fn since(&self) -> Option<u64> {
    447         self.since
    448     }
    449 
    450     pub fn until(&self) -> Option<u64> {
    451         self.until
    452     }
    453 
    454     pub fn search(&self) -> Option<&str> {
    455         self.search.as_deref()
    456     }
    457 
    458     pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse {
    459         let notes = {
    460             if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) {
    461                 results.into_iter().map(|r| r.note_key.as_u64()).collect()
    462             } else {
    463                 vec![]
    464             }
    465         };
    466         QueryResponse { notes }
    467     }
    468 
    469     pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    470         match serde_json::from_str::<QueryCall>(args) {
    471             Ok(call) => Ok(ToolCalls::Query(call)),
    472             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    473                 "{args}, error: {e}"
    474             ))),
    475         }
    476     }
    477 }
    478 
/// A simple note format for use when formatting
/// tool responses
#[derive(Debug, Serialize)]
struct SimpleNote {
    /// Hex-encoded note id.
    note_id: String,
    /// Hex-encoded author pubkey.
    pubkey: String,
    /// The author's profile name or display name, or "Anonymous" when
    /// neither is available.
    name: String,
    /// The note's content.
    content: String,
    /// Creation time as text (normally `%Y-%m-%d %H:%M:%S`).
    created_at: String,
    /// The note's kind number.
    note_kind: u64, // todo: add replying to
}
    490 
    491 /// Take the result of a tool response and present it to the ai so that
    492 /// it can interepret it and take further action
    493 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
    494     match resp {
    495         ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"),
    496         ToolResponses::Error(s) => format!("error: {}", &s),
    497 
    498         ToolResponses::Query(search_r) => {
    499             let simple_notes: Vec<SimpleNote> = search_r
    500                 .notes
    501                 .iter()
    502                 .filter_map(|nkey| {
    503                     let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else {
    504                         return None;
    505                     };
    506 
    507                     let name = ndb
    508                         .get_profile_by_pubkey(txn, note.pubkey())
    509                         .ok()
    510                         .and_then(|p| p.record().profile())
    511                         .and_then(|p| p.name().or_else(|| p.display_name()))
    512                         .unwrap_or("Anonymous")
    513                         .to_string();
    514 
    515                     let content = note.content().to_owned();
    516                     let pubkey = hex::encode(note.pubkey());
    517                     let note_kind = note.kind() as u64;
    518                     let note_id = hex::encode(note.id());
    519 
    520                     let created_at = {
    521                         let datetime =
    522                             DateTime::from_timestamp(note.created_at() as i64, 0).unwrap();
    523                         datetime.format("%Y-%m-%d %H:%M:%S").to_string()
    524                     };
    525 
    526                     Some(SimpleNote {
    527                         note_id,
    528                         pubkey,
    529                         name,
    530                         content,
    531                         created_at,
    532                         note_kind,
    533                     })
    534                 })
    535                 .collect();
    536 
    537             serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
    538         }
    539     }
    540 }
    541 
    542 fn _note_kind_desc(kind: u64) -> String {
    543     match kind {
    544         1 => "microblog".to_string(),
    545         0 => "profile".to_string(),
    546         _ => kind.to_string(),
    547     }
    548 }
    549 
    550 fn present_tool() -> Tool {
    551     Tool {
    552         name: "present_notes",
    553         parse_call: PresentNotesCall::parse,
    554         description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.",
    555         arguments: vec![
    556             ToolArg {
    557                 name: "note_ids",
    558                 description: "A comma-separated list of hex note ids",
    559                 typ: ArgType::String,
    560                 required: true,
    561                 default: None
    562             }
    563         ]
    564     }
    565 }
    566 
    567 fn query_tool() -> Tool {
    568     Tool {
    569         name: "query",
    570         parse_call: QueryCall::parse,
    571         description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.",
    572         arguments: vec![
    573             ToolArg {
    574                 name: "search",
    575                 typ: ArgType::String,
    576                 required: false,
    577                 default: None,
    578                 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc",
    579             },
    580 
    581             ToolArg {
    582                 name: "limit",
    583                 typ: ArgType::Number,
    584                 required: true,
    585                 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())),
    586                 description: "The number of results to return.",
    587             },
    588 
    589             ToolArg {
    590                 name: "since",
    591                 typ: ArgType::Number,
    592                 required: false,
    593                 default: None,
    594                 description: "Only pull notes after this unix timestamp",
    595             },
    596 
    597             ToolArg {
    598                 name: "until",
    599                 typ: ArgType::Number,
    600                 required: false,
    601                 default: None,
    602                 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).",
    603             },
    604 
    605             ToolArg {
    606                 name: "author",
    607                 typ: ArgType::String,
    608                 required: false,
    609                 default: None,
    610                 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u
    611 se, you can query for kind 0 profiles with the search argument.",
    612             },
    613 
    614             ToolArg {
    615                 name: "kind",
    616                 typ: ArgType::Number,
    617                 required: false,
    618                 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())),
    619                 description: r#"The kind of note. Kind list:
    620                 - 0: profiles
    621                 - 1: microblogs/\"tweets\"/posts
    622                 - 6: reposts of kind 1 notes
    623                 - 7: emoji reactions/likes
    624                 - 9735: zaps (bitcoin micropayment receipts)
    625                 - 30023: longform articles, blog posts, etc
    626 
    627                 "#,
    628             },
    629 
    630         ]
    631     }
    632 }
    633 
    634 pub fn dave_tools() -> Vec<Tool> {
    635     vec![query_tool(), present_tool()]
    636 }