tools.rs (18685B)
1 use async_openai::types::*; 2 use chrono::DateTime; 3 use enostr::{NoteId, Pubkey}; 4 use nostrdb::{Ndb, Note, NoteKey, Transaction}; 5 use serde::{Deserialize, Serialize}; 6 use serde_json::{json, Value}; 7 use std::{collections::HashMap, fmt}; 8 9 /// A tool 10 #[derive(Debug, Clone, Serialize, Deserialize)] 11 pub struct ToolCall { 12 id: String, 13 typ: ToolCalls, 14 } 15 16 impl ToolCall { 17 pub fn new(id: String, typ: ToolCalls) -> Self { 18 Self { id, typ } 19 } 20 21 pub fn id(&self) -> &str { 22 &self.id 23 } 24 25 pub fn invalid( 26 id: String, 27 name: Option<String>, 28 arguments: Option<String>, 29 error: String, 30 ) -> Self { 31 Self { 32 id, 33 typ: ToolCalls::Invalid(InvalidToolCall { 34 name, 35 arguments, 36 error, 37 }), 38 } 39 } 40 41 pub fn calls(&self) -> &ToolCalls { 42 &self.typ 43 } 44 45 pub fn to_api(&self) -> ChatCompletionMessageToolCall { 46 ChatCompletionMessageToolCall { 47 id: self.id.clone(), 48 r#type: ChatCompletionToolType::Function, 49 function: self.typ.to_api(), 50 } 51 } 52 } 53 54 /// On streaming APIs, tool calls are incremental. We use this 55 /// to represent tool calls that are in the process of returning. 56 /// These eventually just become [`ToolCall`]'s 57 #[derive(Default, Debug, Clone, Serialize, Deserialize)] 58 pub struct PartialToolCall { 59 pub id: Option<String>, 60 pub name: Option<String>, 61 pub arguments: Option<String>, 62 } 63 64 impl PartialToolCall { 65 pub fn id(&self) -> Option<&str> { 66 self.id.as_deref() 67 } 68 69 pub fn id_mut(&mut self) -> &mut Option<String> { 70 &mut self.id 71 } 72 73 pub fn name(&self) -> Option<&str> { 74 self.name.as_deref() 75 } 76 77 pub fn name_mut(&mut self) -> &mut Option<String> { 78 &mut self.name 79 } 80 81 pub fn arguments(&self) -> Option<&str> { 82 self.arguments.as_deref() 83 } 84 85 pub fn arguments_mut(&mut self) -> &mut Option<String> { 86 &mut self.arguments 87 } 88 } 89 90 /// The query response from nostrdb for a given context 91 #[derive(Debug, Clone, Serialize, Deserialize)] 92 pub struct QueryResponse { 93 pub notes: Vec<u64>, 94 } 95 96 #[derive(Debug, Clone, Serialize, Deserialize)] 97 pub enum ToolResponses { 98 Error(String), 99 Query(QueryResponse), 100 PresentNotes(i32), 101 } 102 103 #[derive(Debug, Clone)] 104 pub struct UnknownToolCall { 105 id: String, 106 name: String, 107 arguments: String, 108 } 109 110 impl UnknownToolCall { 111 pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { 112 let Some(tool) = tools.get(&self.name) else { 113 return Err(ToolCallError::NotFound(self.name.to_owned())); 114 }; 115 116 let parsed_args = (tool.parse_call)(&self.arguments)?; 117 Ok(ToolCall { 118 id: self.id.clone(), 119 typ: parsed_args, 120 }) 121 } 122 } 123 124 impl PartialToolCall { 125 pub fn complete(&self) -> Option<UnknownToolCall> { 126 Some(UnknownToolCall { 127 id: self.id.clone()?, 128 name: self.name.clone()?, 129 arguments: self.arguments.clone()?, 130 }) 131 } 132 } 133 134 #[derive(Debug, Clone, Serialize, Deserialize)] 135 pub struct InvalidToolCall { 136 pub error: String, 137 pub name: Option<String>, 138 pub arguments: Option<String>, 139 } 140 141 /// An enumeration of the possible tool calls that 142 /// can be parsed from Dave responses. When adding 143 /// new tools, this needs to be updated so that we can 144 /// handle tool call responses. 145 #[derive(Debug, Clone, Serialize, Deserialize)] 146 pub enum ToolCalls { 147 Query(QueryCall), 148 PresentNotes(PresentNotesCall), 149 Invalid(InvalidToolCall), 150 } 151 152 impl ToolCalls { 153 pub fn to_api(&self) -> FunctionCall { 154 FunctionCall { 155 name: self.name().to_owned(), 156 arguments: self.arguments(), 157 } 158 } 159 160 fn name(&self) -> &'static str { 161 match self { 162 Self::Query(_) => "search", 163 Self::Invalid(_) => "error", 164 Self::PresentNotes(_) => "present", 165 } 166 } 167 168 /// Returns the tool name as defined in the tool registry (for prompt-based tool calls) 169 pub fn tool_name(&self) -> &'static str { 170 match self { 171 Self::Query(_) => "query", 172 Self::Invalid(_) => "invalid", 173 Self::PresentNotes(_) => "present_notes", 174 } 175 } 176 177 pub fn arguments(&self) -> String { 178 match self { 179 Self::Query(search) => serde_json::to_string(search).unwrap(), 180 Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), 181 Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), 182 } 183 } 184 } 185 186 #[derive(Debug)] 187 pub enum ToolCallError { 188 EmptyName, 189 EmptyArgs, 190 NotFound(String), 191 ArgParseFailure(String), 192 } 193 194 impl fmt::Display for ToolCallError { 195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 196 match self { 197 ToolCallError::EmptyName => write!(f, "the tool name was empty"), 198 ToolCallError::EmptyArgs => write!(f, "no arguments were provided"), 199 ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"), 200 ToolCallError::ArgParseFailure(ref msg) => { 201 write!(f, "failed to parse arguments: {msg}") 202 } 203 } 204 } 205 } 206 207 #[derive(Debug, Clone)] 208 enum ArgType { 209 String, 210 Number, 211 212 #[allow(dead_code)] 213 Enum(Vec<&'static str>), 214 } 215 216 impl ArgType { 217 pub fn type_string(&self) -> &'static str { 218 match self { 219 Self::String => "string", 220 Self::Number => "number", 221 Self::Enum(_) => "string", 222 } 223 } 224 } 225 226 #[derive(Debug, Clone)] 227 struct ToolArg { 228 typ: ArgType, 229 name: &'static str, 230 required: bool, 231 description: &'static str, 232 default: Option<Value>, 233 } 234 235 #[derive(Debug, Clone)] 236 pub struct Tool { 237 parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, 238 name: &'static str, 239 description: &'static str, 240 arguments: Vec<ToolArg>, 241 } 242 243 impl Tool { 244 pub fn name(&self) -> &'static str { 245 self.name 246 } 247 248 pub fn description(&self) -> &'static str { 249 self.description 250 } 251 252 pub fn parse_call(&self) -> fn(&str) -> Result<ToolCalls, ToolCallError> { 253 self.parse_call 254 } 255 256 pub fn to_function_object(&self) -> FunctionObject { 257 let required_args = self 258 .arguments 259 .iter() 260 .filter_map(|arg| { 261 if arg.required { 262 Some(Value::String(arg.name.to_owned())) 263 } else { 264 None 265 } 266 }) 267 .collect(); 268 269 let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); 270 parameters.insert("type".to_string(), Value::String("object".to_string())); 271 parameters.insert("required".to_string(), Value::Array(required_args)); 272 parameters.insert("additionalProperties".to_string(), Value::Bool(false)); 273 274 let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); 275 276 for arg in &self.arguments { 277 let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); 278 props.insert( 279 "type".to_string(), 280 Value::String(arg.typ.type_string().to_string()), 281 ); 282 283 let description = if let Some(default) = &arg.default { 284 format!("{} (Default: {default}))", arg.description) 285 } else { 286 arg.description.to_owned() 287 }; 288 289 props.insert("description".to_string(), Value::String(description)); 290 if let ArgType::Enum(enums) = &arg.typ { 291 props.insert( 292 "enum".to_string(), 293 Value::Array( 294 enums 295 .iter() 296 .map(|s| Value::String((*s).to_owned())) 297 .collect(), 298 ), 299 ); 300 } 301 302 properties.insert(arg.name.to_owned(), Value::Object(props)); 303 } 304 305 parameters.insert("properties".to_string(), Value::Object(properties)); 306 307 FunctionObject { 308 name: self.name.to_owned(), 309 description: Some(self.description.to_owned()), 310 strict: Some(false), 311 parameters: Some(Value::Object(parameters)), 312 } 313 } 314 315 pub fn to_api(&self) -> ChatCompletionTool { 316 ChatCompletionTool { 317 r#type: ChatCompletionToolType::Function, 318 function: self.to_function_object(), 319 } 320 } 321 } 322 323 impl ToolResponses { 324 pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String { 325 format_tool_response_for_ai(txn, ndb, self) 326 } 327 } 328 329 #[derive(Debug, Clone, Serialize, Deserialize)] 330 pub struct ToolResponse { 331 id: String, 332 typ: ToolResponses, 333 } 334 335 impl ToolResponse { 336 pub fn new(id: String, responses: ToolResponses) -> Self { 337 Self { id, typ: responses } 338 } 339 340 pub fn error(id: String, msg: String) -> Self { 341 Self { 342 id, 343 typ: ToolResponses::Error(msg), 344 } 345 } 346 347 pub fn responses(&self) -> &ToolResponses { 348 &self.typ 349 } 350 351 pub fn id(&self) -> &str { 352 &self.id 353 } 354 } 355 356 /// Called by dave when he wants to display notes on the screen 357 #[derive(Debug, Deserialize, Serialize, Clone)] 358 pub struct PresentNotesCall { 359 pub note_ids: Vec<NoteId>, 360 } 361 362 impl PresentNotesCall { 363 fn to_simple(&self) -> PresentNotesCallSimple { 364 let note_ids = self 365 .note_ids 366 .iter() 367 .map(|nid| hex::encode(nid.bytes())) 368 .collect::<Vec<_>>() 369 .join(","); 370 371 PresentNotesCallSimple { note_ids } 372 } 373 } 374 375 /// Called by dave when he wants to display notes on the screen 376 #[derive(Debug, Deserialize, Serialize, Clone)] 377 pub struct PresentNotesCallSimple { 378 note_ids: String, 379 } 380 381 impl PresentNotesCall { 382 fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 383 match serde_json::from_str::<PresentNotesCallSimple>(args) { 384 Ok(call) => { 385 let note_ids = call 386 .note_ids 387 .split(",") 388 .filter_map(|n| NoteId::from_hex(n).ok()) 389 .collect(); 390 391 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids })) 392 } 393 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 394 "{args}, error: {e}" 395 ))), 396 } 397 } 398 } 399 400 /// The parsed nostrdb query that dave wants to use to satisfy a request 401 #[derive(Debug, Deserialize, Serialize, Clone)] 402 pub struct QueryCall { 403 pub author: Option<Pubkey>, 404 pub limit: Option<u64>, 405 pub since: Option<u64>, 406 pub kind: Option<u64>, 407 pub until: Option<u64>, 408 pub search: Option<String>, 409 } 410 411 fn is_reply(note: Note) -> bool { 412 for tag in note.tags() { 413 if tag.count() < 4 { 414 continue; 415 } 416 417 let Some("e") = tag.get_str(0) else { 418 continue; 419 }; 420 421 let Some(s) = tag.get_str(3) else { 422 continue; 423 }; 424 425 if s == "root" || s == "reply" { 426 return true; 427 } 428 } 429 430 false 431 } 432 433 impl QueryCall { 434 pub fn to_filter(&self) -> nostrdb::Filter { 435 let mut filter = nostrdb::Filter::new() 436 .limit(self.limit()) 437 .custom(|n| !is_reply(n)) 438 .kinds([self.kind.unwrap_or(1)]); 439 440 if let Some(author) = &self.author { 441 filter = filter.authors([author.bytes()]); 442 } 443 444 if let Some(search) = &self.search { 445 filter = filter.search(search); 446 } 447 448 if let Some(until) = self.until { 449 filter = filter.until(until); 450 } 451 452 if let Some(since) = self.since { 453 filter = filter.since(since); 454 } 455 456 filter.build() 457 } 458 459 fn limit(&self) -> u64 { 460 self.limit.unwrap_or(10) 461 } 462 463 pub fn author(&self) -> Option<&Pubkey> { 464 self.author.as_ref() 465 } 466 467 pub fn since(&self) -> Option<u64> { 468 self.since 469 } 470 471 pub fn until(&self) -> Option<u64> { 472 self.until 473 } 474 475 pub fn search(&self) -> Option<&str> { 476 self.search.as_deref() 477 } 478 479 pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { 480 let notes = { 481 if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { 482 results.into_iter().map(|r| r.note_key.as_u64()).collect() 483 } else { 484 vec![] 485 } 486 }; 487 QueryResponse { notes } 488 } 489 490 pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 491 match serde_json::from_str::<QueryCall>(args) { 492 Ok(call) => Ok(ToolCalls::Query(call)), 493 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 494 "{args}, error: {e}" 495 ))), 496 } 497 } 498 } 499 500 /// A simple note format for use when formatting 501 /// tool responses 502 #[derive(Debug, Serialize)] 503 struct SimpleNote { 504 note_id: String, 505 pubkey: String, 506 name: String, 507 content: String, 508 created_at: String, 509 note_kind: u64, // todo: add replying to 510 } 511 512 /// Take the result of a tool response and present it to the ai so that 513 /// it can interepret it and take further action 514 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { 515 match resp { 516 ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"), 517 ToolResponses::Error(s) => format!("error: {}", &s), 518 519 ToolResponses::Query(search_r) => { 520 let simple_notes: Vec<SimpleNote> = search_r 521 .notes 522 .iter() 523 .filter_map(|nkey| { 524 let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { 525 return None; 526 }; 527 528 let name = ndb 529 .get_profile_by_pubkey(txn, note.pubkey()) 530 .ok() 531 .and_then(|p| p.record().profile()) 532 .and_then(|p| p.name().or_else(|| p.display_name())) 533 .unwrap_or("Anonymous") 534 .to_string(); 535 536 let content = note.content().to_owned(); 537 let pubkey = hex::encode(note.pubkey()); 538 let note_kind = note.kind() as u64; 539 let note_id = hex::encode(note.id()); 540 541 let created_at = { 542 let datetime = 543 DateTime::from_timestamp(note.created_at() as i64, 0).unwrap(); 544 datetime.format("%Y-%m-%d %H:%M:%S").to_string() 545 }; 546 547 Some(SimpleNote { 548 note_id, 549 pubkey, 550 name, 551 content, 552 created_at, 553 note_kind, 554 }) 555 }) 556 .collect(); 557 558 serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() 559 } 560 } 561 } 562 563 fn _note_kind_desc(kind: u64) -> String { 564 match kind { 565 1 => "microblog".to_string(), 566 0 => "profile".to_string(), 567 _ => kind.to_string(), 568 } 569 } 570 571 fn present_tool() -> Tool { 572 Tool { 573 name: "present_notes", 574 parse_call: PresentNotesCall::parse, 575 description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.", 576 arguments: vec![ 577 ToolArg { 578 name: "note_ids", 579 description: "A comma-separated list of hex note ids", 580 typ: ArgType::String, 581 required: true, 582 default: None 583 } 584 ] 585 } 586 } 587 588 fn query_tool() -> Tool { 589 Tool { 590 name: "query", 591 parse_call: QueryCall::parse, 592 description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", 593 arguments: vec![ 594 ToolArg { 595 name: "search", 596 typ: ArgType::String, 597 required: false, 598 default: None, 599 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", 600 }, 601 602 ToolArg { 603 name: "limit", 604 typ: ArgType::Number, 605 required: true, 606 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())), 607 description: "The number of results to return.", 608 }, 609 610 ToolArg { 611 name: "since", 612 typ: ArgType::Number, 613 required: false, 614 default: None, 615 description: "Only pull notes after this unix timestamp", 616 }, 617 618 ToolArg { 619 name: "until", 620 typ: ArgType::Number, 621 required: false, 622 default: None, 623 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).", 624 }, 625 626 ToolArg { 627 name: "author", 628 typ: ArgType::String, 629 required: false, 630 default: None, 631 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u 632 se, you can query for kind 0 profiles with the search argument.", 633 }, 634 635 ToolArg { 636 name: "kind", 637 typ: ArgType::Number, 638 required: false, 639 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), 640 description: r#"The kind of note. Kind list: 641 - 0: profiles 642 - 1: microblogs/\"tweets\"/posts 643 - 6: reposts of kind 1 notes 644 - 7: emoji reactions/likes 645 - 9735: zaps (bitcoin micropayment receipts) 646 - 30023: longform articles, blog posts, etc 647 648 "#, 649 }, 650 651 ] 652 } 653 } 654 655 pub fn dave_tools() -> Vec<Tool> { 656 vec![query_tool(), present_tool()] 657 }