notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

tools.rs (18179B)


      1 use async_openai::types::*;
      2 use chrono::DateTime;
      3 use enostr::{NoteId, Pubkey};
      4 use nostrdb::{Ndb, Note, NoteKey, Transaction};
      5 use serde::{Deserialize, Serialize};
      6 use serde_json::{json, Value};
      7 use std::{collections::HashMap, fmt};
      8 
      9 /// A tool
     10 #[derive(Debug, Clone, Serialize, Deserialize)]
     11 pub struct ToolCall {
     12     id: String,
     13     typ: ToolCalls,
     14 }
     15 
     16 impl ToolCall {
     17     pub fn id(&self) -> &str {
     18         &self.id
     19     }
     20 
     21     pub fn invalid(
     22         id: String,
     23         name: Option<String>,
     24         arguments: Option<String>,
     25         error: String,
     26     ) -> Self {
     27         Self {
     28             id,
     29             typ: ToolCalls::Invalid(InvalidToolCall {
     30                 name,
     31                 arguments,
     32                 error,
     33             }),
     34         }
     35     }
     36 
     37     pub fn calls(&self) -> &ToolCalls {
     38         &self.typ
     39     }
     40 
     41     pub fn to_api(&self) -> ChatCompletionMessageToolCall {
     42         ChatCompletionMessageToolCall {
     43             id: self.id.clone(),
     44             r#type: ChatCompletionToolType::Function,
     45             function: self.typ.to_api(),
     46         }
     47     }
     48 }
     49 
     50 /// On streaming APIs, tool calls are incremental. We use this
     51 /// to represent tool calls that are in the process of returning.
     52 /// These eventually just become [`ToolCall`]'s
     53 #[derive(Default, Debug, Clone, Serialize, Deserialize)]
     54 pub struct PartialToolCall {
     55     pub id: Option<String>,
     56     pub name: Option<String>,
     57     pub arguments: Option<String>,
     58 }
     59 
     60 impl PartialToolCall {
     61     pub fn id(&self) -> Option<&str> {
     62         self.id.as_deref()
     63     }
     64 
     65     pub fn id_mut(&mut self) -> &mut Option<String> {
     66         &mut self.id
     67     }
     68 
     69     pub fn name(&self) -> Option<&str> {
     70         self.name.as_deref()
     71     }
     72 
     73     pub fn name_mut(&mut self) -> &mut Option<String> {
     74         &mut self.name
     75     }
     76 
     77     pub fn arguments(&self) -> Option<&str> {
     78         self.arguments.as_deref()
     79     }
     80 
     81     pub fn arguments_mut(&mut self) -> &mut Option<String> {
     82         &mut self.arguments
     83     }
     84 }
     85 
/// The query response from nostrdb for a given context
///
/// Produced by `QueryCall::execute`; holds the keys of the matching notes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
    /// nostrdb note keys (`NoteKey::as_u64`) of the matching notes, in the
    /// order the query returned them.
    notes: Vec<u64>,
}

/// The result of running a tool, before it is formatted for the model
/// (see `format_tool_response_for_ai`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses {
    /// The tool failed; rendered to the model as `error: <message>`.
    Error(String),
    /// Results of a `query` tool call.
    Query(QueryResponse),
    /// Number of notes presented to the user; rendered as
    /// `<n> notes presented to the user`.
    PresentNotes(i32),
}
     98 
     99 #[derive(Debug, Clone)]
    100 pub struct UnknownToolCall {
    101     id: String,
    102     name: String,
    103     arguments: String,
    104 }
    105 
    106 impl UnknownToolCall {
    107     pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
    108         let Some(tool) = tools.get(&self.name) else {
    109             return Err(ToolCallError::NotFound(self.name.to_owned()));
    110         };
    111 
    112         let parsed_args = (tool.parse_call)(&self.arguments)?;
    113         Ok(ToolCall {
    114             id: self.id.clone(),
    115             typ: parsed_args,
    116         })
    117     }
    118 }
    119 
    120 impl PartialToolCall {
    121     pub fn complete(&self) -> Option<UnknownToolCall> {
    122         Some(UnknownToolCall {
    123             id: self.id.clone()?,
    124             name: self.name.clone()?,
    125             arguments: self.arguments.clone()?,
    126         })
    127     }
    128 }
    129 
/// A tool call that could not be parsed.
///
/// `name` and `arguments` hold whatever was received (if anything) and
/// `error` describes why parsing failed. `ToolCalls::to_api` serializes this
/// struct as the arguments of an `error` function call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidToolCall {
    pub error: String,
    pub name: Option<String>,
    pub arguments: Option<String>,
}

/// An enumeration of the possible tool calls that
/// can be parsed from Dave responses. When adding
/// new tools, this needs to be updated so that we can
/// handle tool call responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolCalls {
    /// A `query` tool call (produced by `QueryCall::parse`).
    Query(QueryCall),
    /// A `present_notes` tool call (produced by `PresentNotesCall::parse`).
    PresentNotes(PresentNotesCall),
    /// A call that failed to parse (built by `ToolCall::invalid`).
    Invalid(InvalidToolCall),
}

// NOTE(review): `ErrorCall` is not referenced anywhere in this file; it may be
// dead code — confirm against the rest of the crate before removing.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ErrorCall {
    error: String,
}
    152 
    153 impl ToolCalls {
    154     pub fn to_api(&self) -> FunctionCall {
    155         FunctionCall {
    156             name: self.name().to_owned(),
    157             arguments: self.arguments(),
    158         }
    159     }
    160 
    161     fn name(&self) -> &'static str {
    162         match self {
    163             Self::Query(_) => "search",
    164             Self::Invalid(_) => "error",
    165             Self::PresentNotes(_) => "present",
    166         }
    167     }
    168 
    169     fn arguments(&self) -> String {
    170         match self {
    171             Self::Query(search) => serde_json::to_string(search).unwrap(),
    172             Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
    173             Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
    174         }
    175     }
    176 }
    177 
    178 #[derive(Debug)]
    179 pub enum ToolCallError {
    180     EmptyName,
    181     EmptyArgs,
    182     NotFound(String),
    183     ArgParseFailure(String),
    184 }
    185 
    186 impl fmt::Display for ToolCallError {
    187     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    188         match self {
    189             ToolCallError::EmptyName => write!(f, "the tool name was empty"),
    190             ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
    191             ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"),
    192             ToolCallError::ArgParseFailure(ref msg) => {
    193                 write!(f, "failed to parse arguments: {msg}")
    194             }
    195         }
    196     }
    197 }
    198 
    199 #[derive(Debug, Clone)]
    200 enum ArgType {
    201     String,
    202     Number,
    203 
    204     #[allow(dead_code)]
    205     Enum(Vec<&'static str>),
    206 }
    207 
    208 impl ArgType {
    209     pub fn type_string(&self) -> &'static str {
    210         match self {
    211             Self::String => "string",
    212             Self::Number => "number",
    213             Self::Enum(_) => "string",
    214         }
    215     }
    216 }
    217 
    218 #[derive(Debug, Clone)]
    219 struct ToolArg {
    220     typ: ArgType,
    221     name: &'static str,
    222     required: bool,
    223     description: &'static str,
    224     default: Option<Value>,
    225 }
    226 
    227 #[derive(Debug, Clone)]
    228 pub struct Tool {
    229     parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
    230     name: &'static str,
    231     description: &'static str,
    232     arguments: Vec<ToolArg>,
    233 }
    234 
    235 impl Tool {
    236     pub fn name(&self) -> &'static str {
    237         self.name
    238     }
    239 
    240     pub fn to_function_object(&self) -> FunctionObject {
    241         let required_args = self
    242             .arguments
    243             .iter()
    244             .filter_map(|arg| {
    245                 if arg.required {
    246                     Some(Value::String(arg.name.to_owned()))
    247                 } else {
    248                     None
    249                 }
    250             })
    251             .collect();
    252 
    253         let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
    254         parameters.insert("type".to_string(), Value::String("object".to_string()));
    255         parameters.insert("required".to_string(), Value::Array(required_args));
    256         parameters.insert("additionalProperties".to_string(), Value::Bool(false));
    257 
    258         let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
    259 
    260         for arg in &self.arguments {
    261             let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
    262             props.insert(
    263                 "type".to_string(),
    264                 Value::String(arg.typ.type_string().to_string()),
    265             );
    266 
    267             let description = if let Some(default) = &arg.default {
    268                 format!("{} (Default: {default}))", arg.description)
    269             } else {
    270                 arg.description.to_owned()
    271             };
    272 
    273             props.insert("description".to_string(), Value::String(description));
    274             if let ArgType::Enum(enums) = &arg.typ {
    275                 props.insert(
    276                     "enum".to_string(),
    277                     Value::Array(
    278                         enums
    279                             .iter()
    280                             .map(|s| Value::String((*s).to_owned()))
    281                             .collect(),
    282                     ),
    283                 );
    284             }
    285 
    286             properties.insert(arg.name.to_owned(), Value::Object(props));
    287         }
    288 
    289         parameters.insert("properties".to_string(), Value::Object(properties));
    290 
    291         FunctionObject {
    292             name: self.name.to_owned(),
    293             description: Some(self.description.to_owned()),
    294             strict: Some(false),
    295             parameters: Some(Value::Object(parameters)),
    296         }
    297     }
    298 
    299     pub fn to_api(&self) -> ChatCompletionTool {
    300         ChatCompletionTool {
    301             r#type: ChatCompletionToolType::Function,
    302             function: self.to_function_object(),
    303         }
    304     }
    305 }
    306 
    307 impl ToolResponses {
    308     pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String {
    309         format_tool_response_for_ai(txn, ndb, self)
    310     }
    311 }
    312 
    313 #[derive(Debug, Clone, Serialize, Deserialize)]
    314 pub struct ToolResponse {
    315     id: String,
    316     typ: ToolResponses,
    317 }
    318 
    319 impl ToolResponse {
    320     pub fn new(id: String, responses: ToolResponses) -> Self {
    321         Self { id, typ: responses }
    322     }
    323 
    324     pub fn error(id: String, msg: String) -> Self {
    325         Self {
    326             id,
    327             typ: ToolResponses::Error(msg),
    328         }
    329     }
    330 
    331     pub fn responses(&self) -> &ToolResponses {
    332         &self.typ
    333     }
    334 
    335     pub fn id(&self) -> &str {
    336         &self.id
    337     }
    338 }
    339 
    340 /// Called by dave when he wants to display notes on the screen
    341 #[derive(Debug, Deserialize, Serialize, Clone)]
    342 pub struct PresentNotesCall {
    343     pub note_ids: Vec<NoteId>,
    344 }
    345 
    346 impl PresentNotesCall {
    347     fn to_simple(&self) -> PresentNotesCallSimple {
    348         let note_ids = self
    349             .note_ids
    350             .iter()
    351             .map(|nid| hex::encode(nid.bytes()))
    352             .collect::<Vec<_>>()
    353             .join(",");
    354 
    355         PresentNotesCallSimple { note_ids }
    356     }
    357 }
    358 
    359 /// Called by dave when he wants to display notes on the screen
    360 #[derive(Debug, Deserialize, Serialize, Clone)]
    361 pub struct PresentNotesCallSimple {
    362     note_ids: String,
    363 }
    364 
    365 impl PresentNotesCall {
    366     fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    367         match serde_json::from_str::<PresentNotesCallSimple>(args) {
    368             Ok(call) => {
    369                 let note_ids = call
    370                     .note_ids
    371                     .split(",")
    372                     .filter_map(|n| NoteId::from_hex(n).ok())
    373                     .collect();
    374 
    375                 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
    376             }
    377             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    378                 "{args}, error: {e}"
    379             ))),
    380         }
    381     }
    382 }
    383 
/// The parsed nostrdb query that dave wants to use to satisfy a request
///
/// Deserialized directly from the model's `query` tool arguments (see
/// `QueryCall::parse`); every field is optional.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct QueryCall {
    /// Only match notes by this author.
    pub author: Option<Pubkey>,
    /// Maximum number of results; `QueryCall::limit` falls back to 10 when absent.
    pub limit: Option<u64>,
    /// Lower time bound, passed to the filter's `since` (a unix timestamp,
    /// per the tool description).
    pub since: Option<u64>,
    /// Note kind to match; `to_filter` uses 1 when absent.
    pub kind: Option<u64>,
    /// Upper time bound, passed to the filter's `until` (a unix timestamp,
    /// per the tool description).
    pub until: Option<u64>,
    /// Full-text search terms, passed to the filter's `search`.
    pub search: Option<String>,
}
    394 
    395 fn is_reply(note: Note) -> bool {
    396     for tag in note.tags() {
    397         if tag.count() < 4 {
    398             continue;
    399         }
    400 
    401         let Some("e") = tag.get_str(0) else {
    402             continue;
    403         };
    404 
    405         let Some(s) = tag.get_str(3) else {
    406             continue;
    407         };
    408 
    409         if s == "root" || s == "reply" {
    410             return true;
    411         }
    412     }
    413 
    414     false
    415 }
    416 
    417 impl QueryCall {
    418     pub fn to_filter(&self) -> nostrdb::Filter {
    419         let mut filter = nostrdb::Filter::new()
    420             .limit(self.limit())
    421             .custom(|n| !is_reply(n))
    422             .kinds([self.kind.unwrap_or(1)]);
    423 
    424         if let Some(author) = &self.author {
    425             filter = filter.authors([author.bytes()]);
    426         }
    427 
    428         if let Some(search) = &self.search {
    429             filter = filter.search(search);
    430         }
    431 
    432         if let Some(until) = self.until {
    433             filter = filter.until(until);
    434         }
    435 
    436         if let Some(since) = self.since {
    437             filter = filter.since(since);
    438         }
    439 
    440         filter.build()
    441     }
    442 
    443     fn limit(&self) -> u64 {
    444         self.limit.unwrap_or(10)
    445     }
    446 
    447     pub fn author(&self) -> Option<&Pubkey> {
    448         self.author.as_ref()
    449     }
    450 
    451     pub fn since(&self) -> Option<u64> {
    452         self.since
    453     }
    454 
    455     pub fn until(&self) -> Option<u64> {
    456         self.until
    457     }
    458 
    459     pub fn search(&self) -> Option<&str> {
    460         self.search.as_deref()
    461     }
    462 
    463     pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse {
    464         let notes = {
    465             if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) {
    466                 results.into_iter().map(|r| r.note_key.as_u64()).collect()
    467             } else {
    468                 vec![]
    469             }
    470         };
    471         QueryResponse { notes }
    472     }
    473 
    474     pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    475         match serde_json::from_str::<QueryCall>(args) {
    476             Ok(call) => Ok(ToolCalls::Query(call)),
    477             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    478                 "{args}, error: {e}"
    479             ))),
    480         }
    481     }
    482 }
    483 
    484 /// A simple note format for use when formatting
    485 /// tool responses
    486 #[derive(Debug, Serialize)]
    487 struct SimpleNote {
    488     note_id: String,
    489     pubkey: String,
    490     name: String,
    491     content: String,
    492     created_at: String,
    493     note_kind: u64, // todo: add replying to
    494 }
    495 
    496 /// Take the result of a tool response and present it to the ai so that
    497 /// it can interepret it and take further action
    498 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
    499     match resp {
    500         ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"),
    501         ToolResponses::Error(s) => format!("error: {}", &s),
    502 
    503         ToolResponses::Query(search_r) => {
    504             let simple_notes: Vec<SimpleNote> = search_r
    505                 .notes
    506                 .iter()
    507                 .filter_map(|nkey| {
    508                     let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else {
    509                         return None;
    510                     };
    511 
    512                     let name = ndb
    513                         .get_profile_by_pubkey(txn, note.pubkey())
    514                         .ok()
    515                         .and_then(|p| p.record().profile())
    516                         .and_then(|p| p.name().or_else(|| p.display_name()))
    517                         .unwrap_or("Anonymous")
    518                         .to_string();
    519 
    520                     let content = note.content().to_owned();
    521                     let pubkey = hex::encode(note.pubkey());
    522                     let note_kind = note.kind() as u64;
    523                     let note_id = hex::encode(note.id());
    524 
    525                     let created_at = {
    526                         let datetime =
    527                             DateTime::from_timestamp(note.created_at() as i64, 0).unwrap();
    528                         datetime.format("%Y-%m-%d %H:%M:%S").to_string()
    529                     };
    530 
    531                     Some(SimpleNote {
    532                         note_id,
    533                         pubkey,
    534                         name,
    535                         content,
    536                         created_at,
    537                         note_kind,
    538                     })
    539                 })
    540                 .collect();
    541 
    542             serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
    543         }
    544     }
    545 }
    546 
    547 fn _note_kind_desc(kind: u64) -> String {
    548     match kind {
    549         1 => "microblog".to_string(),
    550         0 => "profile".to_string(),
    551         _ => kind.to_string(),
    552     }
    553 }
    554 
    555 fn present_tool() -> Tool {
    556     Tool {
    557         name: "present_notes",
    558         parse_call: PresentNotesCall::parse,
    559         description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.",
    560         arguments: vec![
    561             ToolArg {
    562                 name: "note_ids",
    563                 description: "A comma-separated list of hex note ids",
    564                 typ: ArgType::String,
    565                 required: true,
    566                 default: None
    567             }
    568         ]
    569     }
    570 }
    571 
    572 fn query_tool() -> Tool {
    573     Tool {
    574         name: "query",
    575         parse_call: QueryCall::parse,
    576         description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.",
    577         arguments: vec![
    578             ToolArg {
    579                 name: "search",
    580                 typ: ArgType::String,
    581                 required: false,
    582                 default: None,
    583                 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc",
    584             },
    585 
    586             ToolArg {
    587                 name: "limit",
    588                 typ: ArgType::Number,
    589                 required: true,
    590                 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())),
    591                 description: "The number of results to return.",
    592             },
    593 
    594             ToolArg {
    595                 name: "since",
    596                 typ: ArgType::Number,
    597                 required: false,
    598                 default: None,
    599                 description: "Only pull notes after this unix timestamp",
    600             },
    601 
    602             ToolArg {
    603                 name: "until",
    604                 typ: ArgType::Number,
    605                 required: false,
    606                 default: None,
    607                 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).",
    608             },
    609 
    610             ToolArg {
    611                 name: "author",
    612                 typ: ArgType::String,
    613                 required: false,
    614                 default: None,
    615                 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u
    616 se, you can query for kind 0 profiles with the search argument.",
    617             },
    618 
    619             ToolArg {
    620                 name: "kind",
    621                 typ: ArgType::Number,
    622                 required: false,
    623                 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())),
    624                 description: r#"The kind of note. Kind list:
    625                 - 0: profiles
    626                 - 1: microblogs/\"tweets\"/posts
    627                 - 6: reposts of kind 1 notes
    628                 - 7: emoji reactions/likes
    629                 - 9735: zaps (bitcoin micropayment receipts)
    630                 - 30023: longform articles, blog posts, etc
    631 
    632                 "#,
    633             },
    634 
    635         ]
    636     }
    637 }
    638 
    639 pub fn dave_tools() -> Vec<Tool> {
    640     vec![query_tool(), present_tool()]
    641 }