notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

tools.rs (19039B)


      1 use crate::messages::ExecutedTool;
      2 use async_openai::types::*;
      3 use chrono::DateTime;
      4 use enostr::{NoteId, Pubkey};
      5 use nostrdb::{Ndb, Note, NoteKey, Transaction};
      6 use serde::{Deserialize, Serialize};
      7 use serde_json::{json, Value};
      8 use std::{collections::HashMap, fmt};
      9 
     10 /// A tool
     11 #[derive(Debug, Clone, Serialize, Deserialize)]
     12 pub struct ToolCall {
     13     id: String,
     14     typ: ToolCalls,
     15 }
     16 
     17 impl ToolCall {
     18     pub fn new(id: String, typ: ToolCalls) -> Self {
     19         Self { id, typ }
     20     }
     21 
     22     pub fn id(&self) -> &str {
     23         &self.id
     24     }
     25 
     26     pub fn invalid(
     27         id: String,
     28         name: Option<String>,
     29         arguments: Option<String>,
     30         error: String,
     31     ) -> Self {
     32         Self {
     33             id,
     34             typ: ToolCalls::Invalid(InvalidToolCall {
     35                 name,
     36                 arguments,
     37                 error,
     38             }),
     39         }
     40     }
     41 
     42     pub fn calls(&self) -> &ToolCalls {
     43         &self.typ
     44     }
     45 
     46     pub fn to_api(&self) -> ChatCompletionMessageToolCall {
     47         ChatCompletionMessageToolCall {
     48             id: self.id.clone(),
     49             r#type: ChatCompletionToolType::Function,
     50             function: self.typ.to_api(),
     51         }
     52     }
     53 }
     54 
/// On streaming APIs, tool calls are incremental. We use this
/// to represent tool calls that are in the process of returning.
/// These eventually just become [`ToolCall`]'s
///
/// Every field is optional and defaults to `None`; the `*_mut` accessors
/// hand out the underlying `Option` so it can be filled in place.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PartialToolCall {
    pub id: Option<String>,
    pub name: Option<String>,
    pub arguments: Option<String>,
}

impl PartialToolCall {
    /// The call id, if one has been set.
    pub fn id(&self) -> Option<&str> {
        self.id.as_deref()
    }

    /// Mutable access to the optional call id.
    pub fn id_mut(&mut self) -> &mut Option<String> {
        &mut self.id
    }

    /// The tool name, if one has been set.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Mutable access to the optional tool name.
    pub fn name_mut(&mut self) -> &mut Option<String> {
        &mut self.name
    }

    /// The argument string, if one has been set.
    pub fn arguments(&self) -> Option<&str> {
        self.arguments.as_deref()
    }

    /// Mutable access to the optional argument string.
    pub fn arguments_mut(&mut self) -> &mut Option<String> {
        &mut self.arguments
    }
}
     90 
/// The query response from nostrdb for a given context.
///
/// `notes` holds the `u64` note keys collected by [`QueryCall::execute`];
/// they are turned back into notes via `NoteKey::new` when the response is
/// formatted for the AI.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
    pub notes: Vec<u64>,
}
     96 
/// The possible results of running a tool.
///
/// Each variant is rendered to text for the AI by `format_for_dave`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses {
    /// An error message; rendered as `error: <message>`.
    Error(String),
    /// The note keys returned by a query; rendered as a JSON list of notes.
    Query(QueryResponse),
    /// A count; rendered as `<n> notes presented to the user`.
    PresentNotes(i32),
    /// An executed-tool record; rendered as `<tool_name>: <summary>`.
    ExecutedTool(ExecutedTool),
}
    104 
    105 #[derive(Debug, Clone)]
    106 pub struct UnknownToolCall {
    107     id: String,
    108     name: String,
    109     arguments: String,
    110 }
    111 
    112 impl UnknownToolCall {
    113     pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> {
    114         let Some(tool) = tools.get(&self.name) else {
    115             return Err(ToolCallError::NotFound(self.name.to_owned()));
    116         };
    117 
    118         let parsed_args = (tool.parse_call)(&self.arguments)?;
    119         Ok(ToolCall {
    120             id: self.id.clone(),
    121             typ: parsed_args,
    122         })
    123     }
    124 }
    125 
    126 impl PartialToolCall {
    127     pub fn complete(&self) -> Option<UnknownToolCall> {
    128         Some(UnknownToolCall {
    129             id: self.id.clone()?,
    130             name: self.name.clone()?,
    131             arguments: self.arguments.clone()?,
    132         })
    133     }
    134 }
    135 
/// A tool call that could not be turned into a valid call.
///
/// Keeps the error text along with the tool name and raw arguments when
/// they were available.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidToolCall {
    pub error: String,
    pub name: Option<String>,
    pub arguments: Option<String>,
}
    142 
    143 /// An enumeration of the possible tool calls that
    144 /// can be parsed from Dave responses. When adding
    145 /// new tools, this needs to be updated so that we can
    146 /// handle tool call responses.
    147 #[derive(Debug, Clone, Serialize, Deserialize)]
    148 pub enum ToolCalls {
    149     Query(QueryCall),
    150     PresentNotes(PresentNotesCall),
    151     Invalid(InvalidToolCall),
    152 }
    153 
    154 impl ToolCalls {
    155     pub fn to_api(&self) -> FunctionCall {
    156         FunctionCall {
    157             name: self.name().to_owned(),
    158             arguments: self.arguments(),
    159         }
    160     }
    161 
    162     fn name(&self) -> &'static str {
    163         match self {
    164             Self::Query(_) => "search",
    165             Self::Invalid(_) => "error",
    166             Self::PresentNotes(_) => "present",
    167         }
    168     }
    169 
    170     /// Returns the tool name as defined in the tool registry (for prompt-based tool calls)
    171     pub fn tool_name(&self) -> &'static str {
    172         match self {
    173             Self::Query(_) => "query",
    174             Self::Invalid(_) => "invalid",
    175             Self::PresentNotes(_) => "present_notes",
    176         }
    177     }
    178 
    179     pub fn arguments(&self) -> String {
    180         match self {
    181             Self::Query(search) => serde_json::to_string(search).unwrap(),
    182             Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
    183             Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
    184         }
    185     }
    186 }
    187 
    188 #[derive(Debug)]
    189 pub enum ToolCallError {
    190     EmptyName,
    191     EmptyArgs,
    192     NotFound(String),
    193     ArgParseFailure(String),
    194 }
    195 
    196 impl fmt::Display for ToolCallError {
    197     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    198         match self {
    199             ToolCallError::EmptyName => write!(f, "the tool name was empty"),
    200             ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
    201             ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"),
    202             ToolCallError::ArgParseFailure(ref msg) => {
    203                 write!(f, "failed to parse arguments: {msg}")
    204             }
    205         }
    206     }
    207 }
    208 
    209 #[derive(Debug, Clone)]
    210 enum ArgType {
    211     String,
    212     Number,
    213 
    214     #[allow(dead_code)]
    215     Enum(Vec<&'static str>),
    216 }
    217 
    218 impl ArgType {
    219     pub fn type_string(&self) -> &'static str {
    220         match self {
    221             Self::String => "string",
    222             Self::Number => "number",
    223             Self::Enum(_) => "string",
    224         }
    225     }
    226 }
    227 
/// The description of a single argument accepted by a [`Tool`].
#[derive(Debug, Clone)]
struct ToolArg {
    /// The argument's type; determines the schema `type` (and `enum` list).
    typ: ArgType,
    /// The argument name, used as the property key in the schema.
    name: &'static str,
    /// Whether the name is listed in the schema's `required` array.
    required: bool,
    /// Human-readable description placed in the schema.
    description: &'static str,
    /// Optional default, appended to the description when present.
    default: Option<Value>,
}
    236 
    237 #[derive(Debug, Clone)]
    238 pub struct Tool {
    239     parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>,
    240     name: &'static str,
    241     description: &'static str,
    242     arguments: Vec<ToolArg>,
    243 }
    244 
    245 impl Tool {
    246     pub fn name(&self) -> &'static str {
    247         self.name
    248     }
    249 
    250     pub fn description(&self) -> &'static str {
    251         self.description
    252     }
    253 
    254     pub fn parse_call(&self) -> fn(&str) -> Result<ToolCalls, ToolCallError> {
    255         self.parse_call
    256     }
    257 
    258     pub fn to_function_object(&self) -> FunctionObject {
    259         let required_args = self
    260             .arguments
    261             .iter()
    262             .filter_map(|arg| {
    263                 if arg.required {
    264                     Some(Value::String(arg.name.to_owned()))
    265                 } else {
    266                     None
    267                 }
    268             })
    269             .collect();
    270 
    271         let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
    272         parameters.insert("type".to_string(), Value::String("object".to_string()));
    273         parameters.insert("required".to_string(), Value::Array(required_args));
    274         parameters.insert("additionalProperties".to_string(), Value::Bool(false));
    275 
    276         let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
    277 
    278         for arg in &self.arguments {
    279             let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
    280             props.insert(
    281                 "type".to_string(),
    282                 Value::String(arg.typ.type_string().to_string()),
    283             );
    284 
    285             let description = if let Some(default) = &arg.default {
    286                 format!("{} (Default: {default}))", arg.description)
    287             } else {
    288                 arg.description.to_owned()
    289             };
    290 
    291             props.insert("description".to_string(), Value::String(description));
    292             if let ArgType::Enum(enums) = &arg.typ {
    293                 props.insert(
    294                     "enum".to_string(),
    295                     Value::Array(
    296                         enums
    297                             .iter()
    298                             .map(|s| Value::String((*s).to_owned()))
    299                             .collect(),
    300                     ),
    301                 );
    302             }
    303 
    304             properties.insert(arg.name.to_owned(), Value::Object(props));
    305         }
    306 
    307         parameters.insert("properties".to_string(), Value::Object(properties));
    308 
    309         FunctionObject {
    310             name: self.name.to_owned(),
    311             description: Some(self.description.to_owned()),
    312             strict: Some(false),
    313             parameters: Some(Value::Object(parameters)),
    314         }
    315     }
    316 
    317     pub fn to_api(&self) -> ChatCompletionTool {
    318         ChatCompletionTool {
    319             r#type: ChatCompletionToolType::Function,
    320             function: self.to_function_object(),
    321         }
    322     }
    323 }
    324 
impl ToolResponses {
    /// Renders this response as the text handed back to the AI.
    ///
    /// Delegates to `format_tool_response_for_ai`, passing `txn` and `ndb`
    /// through.
    pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String {
        format_tool_response_for_ai(txn, ndb, self)
    }
}
    330 
    331 #[derive(Debug, Clone, Serialize, Deserialize)]
    332 pub struct ToolResponse {
    333     id: String,
    334     typ: ToolResponses,
    335 }
    336 
    337 impl ToolResponse {
    338     pub fn new(id: String, responses: ToolResponses) -> Self {
    339         Self { id, typ: responses }
    340     }
    341 
    342     pub fn error(id: String, msg: String) -> Self {
    343         Self {
    344             id,
    345             typ: ToolResponses::Error(msg),
    346         }
    347     }
    348 
    349     pub fn executed_tool(result: ExecutedTool) -> Self {
    350         Self {
    351             id: String::new(),
    352             typ: ToolResponses::ExecutedTool(result),
    353         }
    354     }
    355 
    356     pub fn responses(&self) -> &ToolResponses {
    357         &self.typ
    358     }
    359 
    360     pub fn id(&self) -> &str {
    361         &self.id
    362     }
    363 }
    364 
    365 /// Called by dave when he wants to display notes on the screen
    366 #[derive(Debug, Deserialize, Serialize, Clone)]
    367 pub struct PresentNotesCall {
    368     pub note_ids: Vec<NoteId>,
    369 }
    370 
    371 impl PresentNotesCall {
    372     fn to_simple(&self) -> PresentNotesCallSimple {
    373         let note_ids = self
    374             .note_ids
    375             .iter()
    376             .map(|nid| hex::encode(nid.bytes()))
    377             .collect::<Vec<_>>()
    378             .join(",");
    379 
    380         PresentNotesCallSimple { note_ids }
    381     }
    382 }
    383 
/// The simple form of [`PresentNotesCall`] used when (de)serializing tool
/// arguments: the note ids as one comma-separated string of hex ids.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct PresentNotesCallSimple {
    note_ids: String,
}
    389 
    390 impl PresentNotesCall {
    391     fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    392         match serde_json::from_str::<PresentNotesCallSimple>(args) {
    393             Ok(call) => {
    394                 let note_ids = call
    395                     .note_ids
    396                     .split(",")
    397                     .filter_map(|n| NoteId::from_hex(n).ok())
    398                     .collect();
    399 
    400                 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
    401             }
    402             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    403                 "{args}, error: {e}"
    404             ))),
    405         }
    406     }
    407 }
    408 
/// The parsed nostrdb query that dave wants to use to satisfy a request
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct QueryCall {
    /// Restricts the query to notes by this author when set.
    pub author: Option<Pubkey>,
    /// Maximum number of results; 10 is used when unset.
    pub limit: Option<u64>,
    /// Passed to the filter's `since` bound when set.
    pub since: Option<u64>,
    /// Note kind to query; kind 1 is used when unset.
    pub kind: Option<u64>,
    /// Passed to the filter's `until` bound when set.
    pub until: Option<u64>,
    /// Search text passed to the filter when set.
    pub search: Option<String>,
}
    419 
    420 fn is_reply(note: Note) -> bool {
    421     for tag in note.tags() {
    422         if tag.count() < 4 {
    423             continue;
    424         }
    425 
    426         let Some("e") = tag.get_str(0) else {
    427             continue;
    428         };
    429 
    430         let Some(s) = tag.get_str(3) else {
    431             continue;
    432         };
    433 
    434         if s == "root" || s == "reply" {
    435             return true;
    436         }
    437     }
    438 
    439     false
    440 }
    441 
    442 impl QueryCall {
    443     pub fn to_filter(&self) -> nostrdb::Filter {
    444         let mut filter = nostrdb::Filter::new()
    445             .limit(self.limit())
    446             .custom(|n| !is_reply(n))
    447             .kinds([self.kind.unwrap_or(1)]);
    448 
    449         if let Some(author) = &self.author {
    450             filter = filter.authors([author.bytes()]);
    451         }
    452 
    453         if let Some(search) = &self.search {
    454             filter = filter.search(search);
    455         }
    456 
    457         if let Some(until) = self.until {
    458             filter = filter.until(until);
    459         }
    460 
    461         if let Some(since) = self.since {
    462             filter = filter.since(since);
    463         }
    464 
    465         filter.build()
    466     }
    467 
    468     fn limit(&self) -> u64 {
    469         self.limit.unwrap_or(10)
    470     }
    471 
    472     pub fn author(&self) -> Option<&Pubkey> {
    473         self.author.as_ref()
    474     }
    475 
    476     pub fn since(&self) -> Option<u64> {
    477         self.since
    478     }
    479 
    480     pub fn until(&self) -> Option<u64> {
    481         self.until
    482     }
    483 
    484     pub fn search(&self) -> Option<&str> {
    485         self.search.as_deref()
    486     }
    487 
    488     pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse {
    489         let notes = {
    490             if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) {
    491                 results.into_iter().map(|r| r.note_key.as_u64()).collect()
    492             } else {
    493                 vec![]
    494             }
    495         };
    496         QueryResponse { notes }
    497     }
    498 
    499     pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> {
    500         match serde_json::from_str::<QueryCall>(args) {
    501             Ok(call) => Ok(ToolCalls::Query(call)),
    502             Err(e) => Err(ToolCallError::ArgParseFailure(format!(
    503                 "{args}, error: {e}"
    504             ))),
    505         }
    506     }
    507 }
    508 
/// A simple note format for use when formatting
/// tool responses and thread summaries
#[derive(Debug, Serialize)]
pub struct SimpleNote {
    /// Hex-encoded note id.
    pub note_id: String,
    /// Hex-encoded author pubkey.
    pub pubkey: String,
    /// Author's profile name or display name; "Anonymous" when unavailable.
    pub name: String,
    /// The note's content.
    pub content: String,
    /// Creation time formatted as `%Y-%m-%d %H:%M:%S`.
    pub created_at: String,
    /// The note's kind.
    pub note_kind: u64,
    /// Hex id of the note this one replies to; omitted from the serialized
    /// output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply_to: Option<String>,
}
    522 
    523 /// Convert a note to a SimpleNote for AI consumption.
    524 pub fn note_to_simple(txn: &Transaction, ndb: &Ndb, note: &Note<'_>) -> SimpleNote {
    525     let name = ndb
    526         .get_profile_by_pubkey(txn, note.pubkey())
    527         .ok()
    528         .and_then(|p| p.record().profile())
    529         .and_then(|p| p.name().or_else(|| p.display_name()))
    530         .unwrap_or("Anonymous")
    531         .to_string();
    532 
    533     let created_at = DateTime::from_timestamp(note.created_at() as i64, 0)
    534         .unwrap()
    535         .format("%Y-%m-%d %H:%M:%S")
    536         .to_string();
    537 
    538     let reply_to = nostrdb::NoteReply::new(note.tags())
    539         .reply()
    540         .map(|r| hex::encode(r.id));
    541 
    542     SimpleNote {
    543         note_id: hex::encode(note.id()),
    544         pubkey: hex::encode(note.pubkey()),
    545         name,
    546         content: note.content().to_owned(),
    547         created_at,
    548         note_kind: note.kind() as u64,
    549         reply_to,
    550     }
    551 }
    552 
    553 /// Format a list of SimpleNotes as JSON for AI consumption.
    554 pub fn format_simple_notes_json(notes: &[SimpleNote]) -> String {
    555     serde_json::to_string(&json!({"thread": notes})).unwrap()
    556 }
    557 
    558 /// Take the result of a tool response and present it to the ai so that
    559 /// it can interepret it and take further action
    560 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
    561     match resp {
    562         ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"),
    563         ToolResponses::Error(s) => format!("error: {}", &s),
    564 
    565         ToolResponses::Query(search_r) => {
    566             let simple_notes: Vec<SimpleNote> = search_r
    567                 .notes
    568                 .iter()
    569                 .filter_map(|nkey| {
    570                     let note = ndb.get_note_by_key(txn, NoteKey::new(*nkey)).ok()?;
    571                     Some(note_to_simple(txn, ndb, &note))
    572                 })
    573                 .collect();
    574 
    575             serde_json::to_string(&json!({"search_results": simple_notes})).unwrap()
    576         }
    577 
    578         ToolResponses::ExecutedTool(r) => format!("{}: {}", r.tool_name, r.summary),
    579     }
    580 }
    581 
    582 fn _note_kind_desc(kind: u64) -> String {
    583     match kind {
    584         1 => "microblog".to_string(),
    585         0 => "profile".to_string(),
    586         _ => kind.to_string(),
    587     }
    588 }
    589 
    590 fn present_tool() -> Tool {
    591     Tool {
    592         name: "present_notes",
    593         parse_call: PresentNotesCall::parse,
    594         description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.",
    595         arguments: vec![ToolArg {
    596             name: "note_ids",
    597             description: "A comma-separated list of hex note ids",
    598             typ: ArgType::String,
    599             required: true,
    600             default: None,
    601         }],
    602     }
    603 }
    604 
    605 fn query_tool() -> Tool {
    606     Tool {
    607         name: "query",
    608         parse_call: QueryCall::parse,
    609         description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.",
    610         arguments: vec![
    611             ToolArg {
    612                 name: "search",
    613                 typ: ArgType::String,
    614                 required: false,
    615                 default: None,
    616                 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc",
    617             },
    618 
    619             ToolArg {
    620                 name: "limit",
    621                 typ: ArgType::Number,
    622                 required: true,
    623                 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())),
    624                 description: "The number of results to return.",
    625             },
    626 
    627             ToolArg {
    628                 name: "since",
    629                 typ: ArgType::Number,
    630                 required: false,
    631                 default: None,
    632                 description: "Only pull notes after this unix timestamp",
    633             },
    634 
    635             ToolArg {
    636                 name: "until",
    637                 typ: ArgType::Number,
    638                 required: false,
    639                 default: None,
    640                 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).",
    641             },
    642 
    643             ToolArg {
    644                 name: "author",
    645                 typ: ArgType::String,
    646                 required: false,
    647                 default: None,
    648                 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u
    649 se, you can query for kind 0 profiles with the search argument.",
    650             },
    651 
    652             ToolArg {
    653                 name: "kind",
    654                 typ: ArgType::Number,
    655                 required: false,
    656                 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())),
    657                 description: r#"The kind of note. Kind list:
    658                 - 0: profiles
    659                 - 1: microblogs/\"tweets\"/posts
    660                 - 6: reposts of kind 1 notes
    661                 - 7: emoji reactions/likes
    662                 - 9735: zaps (bitcoin micropayment receipts)
    663                 - 30023: longform articles, blog posts, etc
    664 
    665                 "#,
    666             },
    667 
    668         ]
    669     }
    670 }
    671 
    672 pub fn dave_tools() -> Vec<Tool> {
    673     vec![query_tool(), present_tool()]
    674 }