tools.rs (19039B)
1 use crate::messages::ExecutedTool; 2 use async_openai::types::*; 3 use chrono::DateTime; 4 use enostr::{NoteId, Pubkey}; 5 use nostrdb::{Ndb, Note, NoteKey, Transaction}; 6 use serde::{Deserialize, Serialize}; 7 use serde_json::{json, Value}; 8 use std::{collections::HashMap, fmt}; 9 10 /// A tool 11 #[derive(Debug, Clone, Serialize, Deserialize)] 12 pub struct ToolCall { 13 id: String, 14 typ: ToolCalls, 15 } 16 17 impl ToolCall { 18 pub fn new(id: String, typ: ToolCalls) -> Self { 19 Self { id, typ } 20 } 21 22 pub fn id(&self) -> &str { 23 &self.id 24 } 25 26 pub fn invalid( 27 id: String, 28 name: Option<String>, 29 arguments: Option<String>, 30 error: String, 31 ) -> Self { 32 Self { 33 id, 34 typ: ToolCalls::Invalid(InvalidToolCall { 35 name, 36 arguments, 37 error, 38 }), 39 } 40 } 41 42 pub fn calls(&self) -> &ToolCalls { 43 &self.typ 44 } 45 46 pub fn to_api(&self) -> ChatCompletionMessageToolCall { 47 ChatCompletionMessageToolCall { 48 id: self.id.clone(), 49 r#type: ChatCompletionToolType::Function, 50 function: self.typ.to_api(), 51 } 52 } 53 } 54 55 /// On streaming APIs, tool calls are incremental. We use this 56 /// to represent tool calls that are in the process of returning. 57 /// These eventually just become [`ToolCall`]'s 58 #[derive(Default, Debug, Clone, Serialize, Deserialize)] 59 pub struct PartialToolCall { 60 pub id: Option<String>, 61 pub name: Option<String>, 62 pub arguments: Option<String>, 63 } 64 65 impl PartialToolCall { 66 pub fn id(&self) -> Option<&str> { 67 self.id.as_deref() 68 } 69 70 pub fn id_mut(&mut self) -> &mut Option<String> { 71 &mut self.id 72 } 73 74 pub fn name(&self) -> Option<&str> { 75 self.name.as_deref() 76 } 77 78 pub fn name_mut(&mut self) -> &mut Option<String> { 79 &mut self.name 80 } 81 82 pub fn arguments(&self) -> Option<&str> { 83 self.arguments.as_deref() 84 } 85 86 pub fn arguments_mut(&mut self) -> &mut Option<String> { 87 &mut self.arguments 88 } 89 } 90 91 /// The query response from nostrdb for a given context 92 #[derive(Debug, Clone, Serialize, Deserialize)] 93 pub struct QueryResponse { 94 pub notes: Vec<u64>, 95 } 96 97 #[derive(Debug, Clone, Serialize, Deserialize)] 98 pub enum ToolResponses { 99 Error(String), 100 Query(QueryResponse), 101 PresentNotes(i32), 102 ExecutedTool(ExecutedTool), 103 } 104 105 #[derive(Debug, Clone)] 106 pub struct UnknownToolCall { 107 id: String, 108 name: String, 109 arguments: String, 110 } 111 112 impl UnknownToolCall { 113 pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { 114 let Some(tool) = tools.get(&self.name) else { 115 return Err(ToolCallError::NotFound(self.name.to_owned())); 116 }; 117 118 let parsed_args = (tool.parse_call)(&self.arguments)?; 119 Ok(ToolCall { 120 id: self.id.clone(), 121 typ: parsed_args, 122 }) 123 } 124 } 125 126 impl PartialToolCall { 127 pub fn complete(&self) -> Option<UnknownToolCall> { 128 Some(UnknownToolCall { 129 id: self.id.clone()?, 130 name: self.name.clone()?, 131 arguments: self.arguments.clone()?, 132 }) 133 } 134 } 135 136 #[derive(Debug, Clone, Serialize, Deserialize)] 137 pub struct InvalidToolCall { 138 pub error: String, 139 pub name: Option<String>, 140 pub arguments: Option<String>, 141 } 142 143 /// An enumeration of the possible tool calls that 144 /// can be parsed from Dave responses. When adding 145 /// new tools, this needs to be updated so that we can 146 /// handle tool call responses. 147 #[derive(Debug, Clone, Serialize, Deserialize)] 148 pub enum ToolCalls { 149 Query(QueryCall), 150 PresentNotes(PresentNotesCall), 151 Invalid(InvalidToolCall), 152 } 153 154 impl ToolCalls { 155 pub fn to_api(&self) -> FunctionCall { 156 FunctionCall { 157 name: self.name().to_owned(), 158 arguments: self.arguments(), 159 } 160 } 161 162 fn name(&self) -> &'static str { 163 match self { 164 Self::Query(_) => "search", 165 Self::Invalid(_) => "error", 166 Self::PresentNotes(_) => "present", 167 } 168 } 169 170 /// Returns the tool name as defined in the tool registry (for prompt-based tool calls) 171 pub fn tool_name(&self) -> &'static str { 172 match self { 173 Self::Query(_) => "query", 174 Self::Invalid(_) => "invalid", 175 Self::PresentNotes(_) => "present_notes", 176 } 177 } 178 179 pub fn arguments(&self) -> String { 180 match self { 181 Self::Query(search) => serde_json::to_string(search).unwrap(), 182 Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), 183 Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), 184 } 185 } 186 } 187 188 #[derive(Debug)] 189 pub enum ToolCallError { 190 EmptyName, 191 EmptyArgs, 192 NotFound(String), 193 ArgParseFailure(String), 194 } 195 196 impl fmt::Display for ToolCallError { 197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 198 match self { 199 ToolCallError::EmptyName => write!(f, "the tool name was empty"), 200 ToolCallError::EmptyArgs => write!(f, "no arguments were provided"), 201 ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"), 202 ToolCallError::ArgParseFailure(ref msg) => { 203 write!(f, "failed to parse arguments: {msg}") 204 } 205 } 206 } 207 } 208 209 #[derive(Debug, Clone)] 210 enum ArgType { 211 String, 212 Number, 213 214 #[allow(dead_code)] 215 Enum(Vec<&'static str>), 216 } 217 218 impl ArgType { 219 pub fn type_string(&self) -> &'static str { 220 match self { 221 Self::String => "string", 222 Self::Number => "number", 223 Self::Enum(_) => "string", 224 } 225 } 226 } 227 228 #[derive(Debug, Clone)] 229 struct ToolArg { 230 typ: ArgType, 231 name: &'static str, 232 required: bool, 233 description: &'static str, 234 default: Option<Value>, 235 } 236 237 #[derive(Debug, Clone)] 238 pub struct Tool { 239 parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, 240 name: &'static str, 241 description: &'static str, 242 arguments: Vec<ToolArg>, 243 } 244 245 impl Tool { 246 pub fn name(&self) -> &'static str { 247 self.name 248 } 249 250 pub fn description(&self) -> &'static str { 251 self.description 252 } 253 254 pub fn parse_call(&self) -> fn(&str) -> Result<ToolCalls, ToolCallError> { 255 self.parse_call 256 } 257 258 pub fn to_function_object(&self) -> FunctionObject { 259 let required_args = self 260 .arguments 261 .iter() 262 .filter_map(|arg| { 263 if arg.required { 264 Some(Value::String(arg.name.to_owned())) 265 } else { 266 None 267 } 268 }) 269 .collect(); 270 271 let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); 272 parameters.insert("type".to_string(), Value::String("object".to_string())); 273 parameters.insert("required".to_string(), Value::Array(required_args)); 274 parameters.insert("additionalProperties".to_string(), Value::Bool(false)); 275 276 let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); 277 278 for arg in &self.arguments { 279 let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); 280 props.insert( 281 "type".to_string(), 282 Value::String(arg.typ.type_string().to_string()), 283 ); 284 285 let description = if let Some(default) = &arg.default { 286 format!("{} (Default: {default}))", arg.description) 287 } else { 288 arg.description.to_owned() 289 }; 290 291 props.insert("description".to_string(), Value::String(description)); 292 if let ArgType::Enum(enums) = &arg.typ { 293 props.insert( 294 "enum".to_string(), 295 Value::Array( 296 enums 297 .iter() 298 .map(|s| Value::String((*s).to_owned())) 299 .collect(), 300 ), 301 ); 302 } 303 304 properties.insert(arg.name.to_owned(), Value::Object(props)); 305 } 306 307 parameters.insert("properties".to_string(), Value::Object(properties)); 308 309 FunctionObject { 310 name: self.name.to_owned(), 311 description: Some(self.description.to_owned()), 312 strict: Some(false), 313 parameters: Some(Value::Object(parameters)), 314 } 315 } 316 317 pub fn to_api(&self) -> ChatCompletionTool { 318 ChatCompletionTool { 319 r#type: ChatCompletionToolType::Function, 320 function: self.to_function_object(), 321 } 322 } 323 } 324 325 impl ToolResponses { 326 pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String { 327 format_tool_response_for_ai(txn, ndb, self) 328 } 329 } 330 331 #[derive(Debug, Clone, Serialize, Deserialize)] 332 pub struct ToolResponse { 333 id: String, 334 typ: ToolResponses, 335 } 336 337 impl ToolResponse { 338 pub fn new(id: String, responses: ToolResponses) -> Self { 339 Self { id, typ: responses } 340 } 341 342 pub fn error(id: String, msg: String) -> Self { 343 Self { 344 id, 345 typ: ToolResponses::Error(msg), 346 } 347 } 348 349 pub fn executed_tool(result: ExecutedTool) -> Self { 350 Self { 351 id: String::new(), 352 typ: ToolResponses::ExecutedTool(result), 353 } 354 } 355 356 pub fn responses(&self) -> &ToolResponses { 357 &self.typ 358 } 359 360 pub fn id(&self) -> &str { 361 &self.id 362 } 363 } 364 365 /// Called by dave when he wants to display notes on the screen 366 #[derive(Debug, Deserialize, Serialize, Clone)] 367 pub struct PresentNotesCall { 368 pub note_ids: Vec<NoteId>, 369 } 370 371 impl PresentNotesCall { 372 fn to_simple(&self) -> PresentNotesCallSimple { 373 let note_ids = self 374 .note_ids 375 .iter() 376 .map(|nid| hex::encode(nid.bytes())) 377 .collect::<Vec<_>>() 378 .join(","); 379 380 PresentNotesCallSimple { note_ids } 381 } 382 } 383 384 /// Called by dave when he wants to display notes on the screen 385 #[derive(Debug, Deserialize, Serialize, Clone)] 386 pub struct PresentNotesCallSimple { 387 note_ids: String, 388 } 389 390 impl PresentNotesCall { 391 fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 392 match serde_json::from_str::<PresentNotesCallSimple>(args) { 393 Ok(call) => { 394 let note_ids = call 395 .note_ids 396 .split(",") 397 .filter_map(|n| NoteId::from_hex(n).ok()) 398 .collect(); 399 400 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids })) 401 } 402 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 403 "{args}, error: {e}" 404 ))), 405 } 406 } 407 } 408 409 /// The parsed nostrdb query that dave wants to use to satisfy a request 410 #[derive(Debug, Deserialize, Serialize, Clone)] 411 pub struct QueryCall { 412 pub author: Option<Pubkey>, 413 pub limit: Option<u64>, 414 pub since: Option<u64>, 415 pub kind: Option<u64>, 416 pub until: Option<u64>, 417 pub search: Option<String>, 418 } 419 420 fn is_reply(note: Note) -> bool { 421 for tag in note.tags() { 422 if tag.count() < 4 { 423 continue; 424 } 425 426 let Some("e") = tag.get_str(0) else { 427 continue; 428 }; 429 430 let Some(s) = tag.get_str(3) else { 431 continue; 432 }; 433 434 if s == "root" || s == "reply" { 435 return true; 436 } 437 } 438 439 false 440 } 441 442 impl QueryCall { 443 pub fn to_filter(&self) -> nostrdb::Filter { 444 let mut filter = nostrdb::Filter::new() 445 .limit(self.limit()) 446 .custom(|n| !is_reply(n)) 447 .kinds([self.kind.unwrap_or(1)]); 448 449 if let Some(author) = &self.author { 450 filter = filter.authors([author.bytes()]); 451 } 452 453 if let Some(search) = &self.search { 454 filter = filter.search(search); 455 } 456 457 if let Some(until) = self.until { 458 filter = filter.until(until); 459 } 460 461 if let Some(since) = self.since { 462 filter = filter.since(since); 463 } 464 465 filter.build() 466 } 467 468 fn limit(&self) -> u64 { 469 self.limit.unwrap_or(10) 470 } 471 472 pub fn author(&self) -> Option<&Pubkey> { 473 self.author.as_ref() 474 } 475 476 pub fn since(&self) -> Option<u64> { 477 self.since 478 } 479 480 pub fn until(&self) -> Option<u64> { 481 self.until 482 } 483 484 pub fn search(&self) -> Option<&str> { 485 self.search.as_deref() 486 } 487 488 pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { 489 let notes = { 490 if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { 491 results.into_iter().map(|r| r.note_key.as_u64()).collect() 492 } else { 493 vec![] 494 } 495 }; 496 QueryResponse { notes } 497 } 498 499 pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 500 match serde_json::from_str::<QueryCall>(args) { 501 Ok(call) => Ok(ToolCalls::Query(call)), 502 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 503 "{args}, error: {e}" 504 ))), 505 } 506 } 507 } 508 509 /// A simple note format for use when formatting 510 /// tool responses and thread summaries 511 #[derive(Debug, Serialize)] 512 pub struct SimpleNote { 513 pub note_id: String, 514 pub pubkey: String, 515 pub name: String, 516 pub content: String, 517 pub created_at: String, 518 pub note_kind: u64, 519 #[serde(skip_serializing_if = "Option::is_none")] 520 pub reply_to: Option<String>, 521 } 522 523 /// Convert a note to a SimpleNote for AI consumption. 524 pub fn note_to_simple(txn: &Transaction, ndb: &Ndb, note: &Note<'_>) -> SimpleNote { 525 let name = ndb 526 .get_profile_by_pubkey(txn, note.pubkey()) 527 .ok() 528 .and_then(|p| p.record().profile()) 529 .and_then(|p| p.name().or_else(|| p.display_name())) 530 .unwrap_or("Anonymous") 531 .to_string(); 532 533 let created_at = DateTime::from_timestamp(note.created_at() as i64, 0) 534 .unwrap() 535 .format("%Y-%m-%d %H:%M:%S") 536 .to_string(); 537 538 let reply_to = nostrdb::NoteReply::new(note.tags()) 539 .reply() 540 .map(|r| hex::encode(r.id)); 541 542 SimpleNote { 543 note_id: hex::encode(note.id()), 544 pubkey: hex::encode(note.pubkey()), 545 name, 546 content: note.content().to_owned(), 547 created_at, 548 note_kind: note.kind() as u64, 549 reply_to, 550 } 551 } 552 553 /// Format a list of SimpleNotes as JSON for AI consumption. 554 pub fn format_simple_notes_json(notes: &[SimpleNote]) -> String { 555 serde_json::to_string(&json!({"thread": notes})).unwrap() 556 } 557 558 /// Take the result of a tool response and present it to the ai so that 559 /// it can interepret it and take further action 560 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { 561 match resp { 562 ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"), 563 ToolResponses::Error(s) => format!("error: {}", &s), 564 565 ToolResponses::Query(search_r) => { 566 let simple_notes: Vec<SimpleNote> = search_r 567 .notes 568 .iter() 569 .filter_map(|nkey| { 570 let note = ndb.get_note_by_key(txn, NoteKey::new(*nkey)).ok()?; 571 Some(note_to_simple(txn, ndb, ¬e)) 572 }) 573 .collect(); 574 575 serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() 576 } 577 578 ToolResponses::ExecutedTool(r) => format!("{}: {}", r.tool_name, r.summary), 579 } 580 } 581 582 fn _note_kind_desc(kind: u64) -> String { 583 match kind { 584 1 => "microblog".to_string(), 585 0 => "profile".to_string(), 586 _ => kind.to_string(), 587 } 588 } 589 590 fn present_tool() -> Tool { 591 Tool { 592 name: "present_notes", 593 parse_call: PresentNotesCall::parse, 594 description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.", 595 arguments: vec![ToolArg { 596 name: "note_ids", 597 description: "A comma-separated list of hex note ids", 598 typ: ArgType::String, 599 required: true, 600 default: None, 601 }], 602 } 603 } 604 605 fn query_tool() -> Tool { 606 Tool { 607 name: "query", 608 parse_call: QueryCall::parse, 609 description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", 610 arguments: vec![ 611 ToolArg { 612 name: "search", 613 typ: ArgType::String, 614 required: false, 615 default: None, 616 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", 617 }, 618 619 ToolArg { 620 name: "limit", 621 typ: ArgType::Number, 622 required: true, 623 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())), 624 description: "The number of results to return.", 625 }, 626 627 ToolArg { 628 name: "since", 629 typ: ArgType::Number, 630 required: false, 631 default: None, 632 description: "Only pull notes after this unix timestamp", 633 }, 634 635 ToolArg { 636 name: "until", 637 typ: ArgType::Number, 638 required: false, 639 default: None, 640 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).", 641 }, 642 643 ToolArg { 644 name: "author", 645 typ: ArgType::String, 646 required: false, 647 default: None, 648 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u 649 se, you can query for kind 0 profiles with the search argument.", 650 }, 651 652 ToolArg { 653 name: "kind", 654 typ: ArgType::Number, 655 required: false, 656 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), 657 description: r#"The kind of note. Kind list: 658 - 0: profiles 659 - 1: microblogs/\"tweets\"/posts 660 - 6: reposts of kind 1 notes 661 - 7: emoji reactions/likes 662 - 9735: zaps (bitcoin micropayment receipts) 663 - 30023: longform articles, blog posts, etc 664 665 "#, 666 }, 667 668 ] 669 } 670 } 671 672 pub fn dave_tools() -> Vec<Tool> { 673 vec![query_tool(), present_tool()] 674 }