tools.rs (18179B)
1 use async_openai::types::*; 2 use chrono::DateTime; 3 use enostr::{NoteId, Pubkey}; 4 use nostrdb::{Ndb, Note, NoteKey, Transaction}; 5 use serde::{Deserialize, Serialize}; 6 use serde_json::{json, Value}; 7 use std::{collections::HashMap, fmt}; 8 9 /// A tool 10 #[derive(Debug, Clone, Serialize, Deserialize)] 11 pub struct ToolCall { 12 id: String, 13 typ: ToolCalls, 14 } 15 16 impl ToolCall { 17 pub fn id(&self) -> &str { 18 &self.id 19 } 20 21 pub fn invalid( 22 id: String, 23 name: Option<String>, 24 arguments: Option<String>, 25 error: String, 26 ) -> Self { 27 Self { 28 id, 29 typ: ToolCalls::Invalid(InvalidToolCall { 30 name, 31 arguments, 32 error, 33 }), 34 } 35 } 36 37 pub fn calls(&self) -> &ToolCalls { 38 &self.typ 39 } 40 41 pub fn to_api(&self) -> ChatCompletionMessageToolCall { 42 ChatCompletionMessageToolCall { 43 id: self.id.clone(), 44 r#type: ChatCompletionToolType::Function, 45 function: self.typ.to_api(), 46 } 47 } 48 } 49 50 /// On streaming APIs, tool calls are incremental. We use this 51 /// to represent tool calls that are in the process of returning. 52 /// These eventually just become [`ToolCall`]'s 53 #[derive(Default, Debug, Clone, Serialize, Deserialize)] 54 pub struct PartialToolCall { 55 pub id: Option<String>, 56 pub name: Option<String>, 57 pub arguments: Option<String>, 58 } 59 60 impl PartialToolCall { 61 pub fn id(&self) -> Option<&str> { 62 self.id.as_deref() 63 } 64 65 pub fn id_mut(&mut self) -> &mut Option<String> { 66 &mut self.id 67 } 68 69 pub fn name(&self) -> Option<&str> { 70 self.name.as_deref() 71 } 72 73 pub fn name_mut(&mut self) -> &mut Option<String> { 74 &mut self.name 75 } 76 77 pub fn arguments(&self) -> Option<&str> { 78 self.arguments.as_deref() 79 } 80 81 pub fn arguments_mut(&mut self) -> &mut Option<String> { 82 &mut self.arguments 83 } 84 } 85 86 /// The query response from nostrdb for a given context 87 #[derive(Debug, Clone, Serialize, Deserialize)] 88 pub struct QueryResponse { 89 notes: Vec<u64>, 90 } 91 92 #[derive(Debug, Clone, Serialize, Deserialize)] 93 pub enum ToolResponses { 94 Error(String), 95 Query(QueryResponse), 96 PresentNotes(i32), 97 } 98 99 #[derive(Debug, Clone)] 100 pub struct UnknownToolCall { 101 id: String, 102 name: String, 103 arguments: String, 104 } 105 106 impl UnknownToolCall { 107 pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { 108 let Some(tool) = tools.get(&self.name) else { 109 return Err(ToolCallError::NotFound(self.name.to_owned())); 110 }; 111 112 let parsed_args = (tool.parse_call)(&self.arguments)?; 113 Ok(ToolCall { 114 id: self.id.clone(), 115 typ: parsed_args, 116 }) 117 } 118 } 119 120 impl PartialToolCall { 121 pub fn complete(&self) -> Option<UnknownToolCall> { 122 Some(UnknownToolCall { 123 id: self.id.clone()?, 124 name: self.name.clone()?, 125 arguments: self.arguments.clone()?, 126 }) 127 } 128 } 129 130 #[derive(Debug, Clone, Serialize, Deserialize)] 131 pub struct InvalidToolCall { 132 pub error: String, 133 pub name: Option<String>, 134 pub arguments: Option<String>, 135 } 136 137 /// An enumeration of the possible tool calls that 138 /// can be parsed from Dave responses. When adding 139 /// new tools, this needs to be updated so that we can 140 /// handle tool call responses. 141 #[derive(Debug, Clone, Serialize, Deserialize)] 142 pub enum ToolCalls { 143 Query(QueryCall), 144 PresentNotes(PresentNotesCall), 145 Invalid(InvalidToolCall), 146 } 147 148 #[derive(Debug, Clone, Serialize, Deserialize)] 149 struct ErrorCall { 150 error: String, 151 } 152 153 impl ToolCalls { 154 pub fn to_api(&self) -> FunctionCall { 155 FunctionCall { 156 name: self.name().to_owned(), 157 arguments: self.arguments(), 158 } 159 } 160 161 fn name(&self) -> &'static str { 162 match self { 163 Self::Query(_) => "search", 164 Self::Invalid(_) => "error", 165 Self::PresentNotes(_) => "present", 166 } 167 } 168 169 fn arguments(&self) -> String { 170 match self { 171 Self::Query(search) => serde_json::to_string(search).unwrap(), 172 Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), 173 Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), 174 } 175 } 176 } 177 178 #[derive(Debug)] 179 pub enum ToolCallError { 180 EmptyName, 181 EmptyArgs, 182 NotFound(String), 183 ArgParseFailure(String), 184 } 185 186 impl fmt::Display for ToolCallError { 187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 188 match self { 189 ToolCallError::EmptyName => write!(f, "the tool name was empty"), 190 ToolCallError::EmptyArgs => write!(f, "no arguments were provided"), 191 ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"), 192 ToolCallError::ArgParseFailure(ref msg) => { 193 write!(f, "failed to parse arguments: {msg}") 194 } 195 } 196 } 197 } 198 199 #[derive(Debug, Clone)] 200 enum ArgType { 201 String, 202 Number, 203 204 #[allow(dead_code)] 205 Enum(Vec<&'static str>), 206 } 207 208 impl ArgType { 209 pub fn type_string(&self) -> &'static str { 210 match self { 211 Self::String => "string", 212 Self::Number => "number", 213 Self::Enum(_) => "string", 214 } 215 } 216 } 217 218 #[derive(Debug, Clone)] 219 struct ToolArg { 220 typ: ArgType, 221 name: &'static str, 222 required: bool, 223 description: &'static str, 224 default: Option<Value>, 225 } 226 227 #[derive(Debug, Clone)] 228 pub struct Tool { 229 parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, 230 name: &'static str, 231 description: &'static str, 232 arguments: Vec<ToolArg>, 233 } 234 235 impl Tool { 236 pub fn name(&self) -> &'static str { 237 self.name 238 } 239 240 pub fn to_function_object(&self) -> FunctionObject { 241 let required_args = self 242 .arguments 243 .iter() 244 .filter_map(|arg| { 245 if arg.required { 246 Some(Value::String(arg.name.to_owned())) 247 } else { 248 None 249 } 250 }) 251 .collect(); 252 253 let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); 254 parameters.insert("type".to_string(), Value::String("object".to_string())); 255 parameters.insert("required".to_string(), Value::Array(required_args)); 256 parameters.insert("additionalProperties".to_string(), Value::Bool(false)); 257 258 let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); 259 260 for arg in &self.arguments { 261 let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); 262 props.insert( 263 "type".to_string(), 264 Value::String(arg.typ.type_string().to_string()), 265 ); 266 267 let description = if let Some(default) = &arg.default { 268 format!("{} (Default: {default}))", arg.description) 269 } else { 270 arg.description.to_owned() 271 }; 272 273 props.insert("description".to_string(), Value::String(description)); 274 if let ArgType::Enum(enums) = &arg.typ { 275 props.insert( 276 "enum".to_string(), 277 Value::Array( 278 enums 279 .iter() 280 .map(|s| Value::String((*s).to_owned())) 281 .collect(), 282 ), 283 ); 284 } 285 286 properties.insert(arg.name.to_owned(), Value::Object(props)); 287 } 288 289 parameters.insert("properties".to_string(), Value::Object(properties)); 290 291 FunctionObject { 292 name: self.name.to_owned(), 293 description: Some(self.description.to_owned()), 294 strict: Some(false), 295 parameters: Some(Value::Object(parameters)), 296 } 297 } 298 299 pub fn to_api(&self) -> ChatCompletionTool { 300 ChatCompletionTool { 301 r#type: ChatCompletionToolType::Function, 302 function: self.to_function_object(), 303 } 304 } 305 } 306 307 impl ToolResponses { 308 pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String { 309 format_tool_response_for_ai(txn, ndb, self) 310 } 311 } 312 313 #[derive(Debug, Clone, Serialize, Deserialize)] 314 pub struct ToolResponse { 315 id: String, 316 typ: ToolResponses, 317 } 318 319 impl ToolResponse { 320 pub fn new(id: String, responses: ToolResponses) -> Self { 321 Self { id, typ: responses } 322 } 323 324 pub fn error(id: String, msg: String) -> Self { 325 Self { 326 id, 327 typ: ToolResponses::Error(msg), 328 } 329 } 330 331 pub fn responses(&self) -> &ToolResponses { 332 &self.typ 333 } 334 335 pub fn id(&self) -> &str { 336 &self.id 337 } 338 } 339 340 /// Called by dave when he wants to display notes on the screen 341 #[derive(Debug, Deserialize, Serialize, Clone)] 342 pub struct PresentNotesCall { 343 pub note_ids: Vec<NoteId>, 344 } 345 346 impl PresentNotesCall { 347 fn to_simple(&self) -> PresentNotesCallSimple { 348 let note_ids = self 349 .note_ids 350 .iter() 351 .map(|nid| hex::encode(nid.bytes())) 352 .collect::<Vec<_>>() 353 .join(","); 354 355 PresentNotesCallSimple { note_ids } 356 } 357 } 358 359 /// Called by dave when he wants to display notes on the screen 360 #[derive(Debug, Deserialize, Serialize, Clone)] 361 pub struct PresentNotesCallSimple { 362 note_ids: String, 363 } 364 365 impl PresentNotesCall { 366 fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 367 match serde_json::from_str::<PresentNotesCallSimple>(args) { 368 Ok(call) => { 369 let note_ids = call 370 .note_ids 371 .split(",") 372 .filter_map(|n| NoteId::from_hex(n).ok()) 373 .collect(); 374 375 Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids })) 376 } 377 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 378 "{args}, error: {e}" 379 ))), 380 } 381 } 382 } 383 384 /// The parsed nostrdb query that dave wants to use to satisfy a request 385 #[derive(Debug, Deserialize, Serialize, Clone)] 386 pub struct QueryCall { 387 pub author: Option<Pubkey>, 388 pub limit: Option<u64>, 389 pub since: Option<u64>, 390 pub kind: Option<u64>, 391 pub until: Option<u64>, 392 pub search: Option<String>, 393 } 394 395 fn is_reply(note: Note) -> bool { 396 for tag in note.tags() { 397 if tag.count() < 4 { 398 continue; 399 } 400 401 let Some("e") = tag.get_str(0) else { 402 continue; 403 }; 404 405 let Some(s) = tag.get_str(3) else { 406 continue; 407 }; 408 409 if s == "root" || s == "reply" { 410 return true; 411 } 412 } 413 414 false 415 } 416 417 impl QueryCall { 418 pub fn to_filter(&self) -> nostrdb::Filter { 419 let mut filter = nostrdb::Filter::new() 420 .limit(self.limit()) 421 .custom(|n| !is_reply(n)) 422 .kinds([self.kind.unwrap_or(1)]); 423 424 if let Some(author) = &self.author { 425 filter = filter.authors([author.bytes()]); 426 } 427 428 if let Some(search) = &self.search { 429 filter = filter.search(search); 430 } 431 432 if let Some(until) = self.until { 433 filter = filter.until(until); 434 } 435 436 if let Some(since) = self.since { 437 filter = filter.since(since); 438 } 439 440 filter.build() 441 } 442 443 fn limit(&self) -> u64 { 444 self.limit.unwrap_or(10) 445 } 446 447 pub fn author(&self) -> Option<&Pubkey> { 448 self.author.as_ref() 449 } 450 451 pub fn since(&self) -> Option<u64> { 452 self.since 453 } 454 455 pub fn until(&self) -> Option<u64> { 456 self.until 457 } 458 459 pub fn search(&self) -> Option<&str> { 460 self.search.as_deref() 461 } 462 463 pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { 464 let notes = { 465 if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { 466 results.into_iter().map(|r| r.note_key.as_u64()).collect() 467 } else { 468 vec![] 469 } 470 }; 471 QueryResponse { notes } 472 } 473 474 pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 475 match serde_json::from_str::<QueryCall>(args) { 476 Ok(call) => Ok(ToolCalls::Query(call)), 477 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 478 "{args}, error: {e}" 479 ))), 480 } 481 } 482 } 483 484 /// A simple note format for use when formatting 485 /// tool responses 486 #[derive(Debug, Serialize)] 487 struct SimpleNote { 488 note_id: String, 489 pubkey: String, 490 name: String, 491 content: String, 492 created_at: String, 493 note_kind: u64, // todo: add replying to 494 } 495 496 /// Take the result of a tool response and present it to the ai so that 497 /// it can interepret it and take further action 498 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { 499 match resp { 500 ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"), 501 ToolResponses::Error(s) => format!("error: {}", &s), 502 503 ToolResponses::Query(search_r) => { 504 let simple_notes: Vec<SimpleNote> = search_r 505 .notes 506 .iter() 507 .filter_map(|nkey| { 508 let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { 509 return None; 510 }; 511 512 let name = ndb 513 .get_profile_by_pubkey(txn, note.pubkey()) 514 .ok() 515 .and_then(|p| p.record().profile()) 516 .and_then(|p| p.name().or_else(|| p.display_name())) 517 .unwrap_or("Anonymous") 518 .to_string(); 519 520 let content = note.content().to_owned(); 521 let pubkey = hex::encode(note.pubkey()); 522 let note_kind = note.kind() as u64; 523 let note_id = hex::encode(note.id()); 524 525 let created_at = { 526 let datetime = 527 DateTime::from_timestamp(note.created_at() as i64, 0).unwrap(); 528 datetime.format("%Y-%m-%d %H:%M:%S").to_string() 529 }; 530 531 Some(SimpleNote { 532 note_id, 533 pubkey, 534 name, 535 content, 536 created_at, 537 note_kind, 538 }) 539 }) 540 .collect(); 541 542 serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() 543 } 544 } 545 } 546 547 fn _note_kind_desc(kind: u64) -> String { 548 match kind { 549 1 => "microblog".to_string(), 550 0 => "profile".to_string(), 551 _ => kind.to_string(), 552 } 553 } 554 555 fn present_tool() -> Tool { 556 Tool { 557 name: "present_notes", 558 parse_call: PresentNotesCall::parse, 559 description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.", 560 arguments: vec![ 561 ToolArg { 562 name: "note_ids", 563 description: "A comma-separated list of hex note ids", 564 typ: ArgType::String, 565 required: true, 566 default: None 567 } 568 ] 569 } 570 } 571 572 fn query_tool() -> Tool { 573 Tool { 574 name: "query", 575 parse_call: QueryCall::parse, 576 description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", 577 arguments: vec![ 578 ToolArg { 579 name: "search", 580 typ: ArgType::String, 581 required: false, 582 default: None, 583 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", 584 }, 585 586 ToolArg { 587 name: "limit", 588 typ: ArgType::Number, 589 required: true, 590 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())), 591 description: "The number of results to return.", 592 }, 593 594 ToolArg { 595 name: "since", 596 typ: ArgType::Number, 597 required: false, 598 default: None, 599 description: "Only pull notes after this unix timestamp", 600 }, 601 602 ToolArg { 603 name: "until", 604 typ: ArgType::Number, 605 required: false, 606 default: None, 607 description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).", 608 }, 609 610 ToolArg { 611 name: "author", 612 typ: ArgType::String, 613 required: false, 614 default: None, 615 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u 616 se, you can query for kind 0 profiles with the search argument.", 617 }, 618 619 ToolArg { 620 name: "kind", 621 typ: ArgType::Number, 622 required: false, 623 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), 624 description: r#"The kind of note. Kind list: 625 - 0: profiles 626 - 1: microblogs/\"tweets\"/posts 627 - 6: reposts of kind 1 notes 628 - 7: emoji reactions/likes 629 - 9735: zaps (bitcoin micropayment receipts) 630 - 30023: longform articles, blog posts, etc 631 632 "#, 633 }, 634 635 ] 636 } 637 } 638 639 pub fn dave_tools() -> Vec<Tool> { 640 vec![query_tool(), present_tool()] 641 }