// tools.rs (18090B)
1 use async_openai::types::*; 2 use chrono::DateTime; 3 use enostr::{NoteId, Pubkey}; 4 use nostrdb::{Ndb, Note, NoteKey, Transaction}; 5 use serde::{Deserialize, Serialize}; 6 use serde_json::{json, Value}; 7 use std::{collections::HashMap, fmt}; 8 9 /// A tool 10 #[derive(Debug, Clone, Serialize, Deserialize)] 11 pub struct ToolCall { 12 id: String, 13 typ: ToolCalls, 14 } 15 16 impl ToolCall { 17 pub fn id(&self) -> &str { 18 &self.id 19 } 20 21 pub fn invalid( 22 id: String, 23 name: Option<String>, 24 arguments: Option<String>, 25 error: String, 26 ) -> Self { 27 Self { 28 id, 29 typ: ToolCalls::Invalid(InvalidToolCall { 30 name, 31 arguments, 32 error, 33 }), 34 } 35 } 36 37 pub fn calls(&self) -> &ToolCalls { 38 &self.typ 39 } 40 41 pub fn to_api(&self) -> ChatCompletionMessageToolCall { 42 ChatCompletionMessageToolCall { 43 id: self.id.clone(), 44 r#type: ChatCompletionToolType::Function, 45 function: self.typ.to_api(), 46 } 47 } 48 } 49 50 /// On streaming APIs, tool calls are incremental. We use this 51 /// to represent tool calls that are in the process of returning. 
52 /// These eventually just become [`ToolCall`]'s 53 #[derive(Default, Debug, Clone, Serialize, Deserialize)] 54 pub struct PartialToolCall { 55 pub id: Option<String>, 56 pub name: Option<String>, 57 pub arguments: Option<String>, 58 } 59 60 impl PartialToolCall { 61 pub fn id(&self) -> Option<&str> { 62 self.id.as_deref() 63 } 64 65 pub fn id_mut(&mut self) -> &mut Option<String> { 66 &mut self.id 67 } 68 69 pub fn name(&self) -> Option<&str> { 70 self.name.as_deref() 71 } 72 73 pub fn name_mut(&mut self) -> &mut Option<String> { 74 &mut self.name 75 } 76 77 pub fn arguments(&self) -> Option<&str> { 78 self.arguments.as_deref() 79 } 80 81 pub fn arguments_mut(&mut self) -> &mut Option<String> { 82 &mut self.arguments 83 } 84 } 85 86 /// The query response from nostrdb for a given context 87 #[derive(Debug, Clone, Serialize, Deserialize)] 88 pub struct QueryResponse { 89 notes: Vec<u64>, 90 } 91 92 #[derive(Debug, Clone, Serialize, Deserialize)] 93 pub enum ToolResponses { 94 Error(String), 95 Query(QueryResponse), 96 PresentNotes(i32), 97 } 98 99 #[derive(Debug, Clone)] 100 pub struct UnknownToolCall { 101 id: String, 102 name: String, 103 arguments: String, 104 } 105 106 impl UnknownToolCall { 107 pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { 108 let Some(tool) = tools.get(&self.name) else { 109 return Err(ToolCallError::NotFound(self.name.to_owned())); 110 }; 111 112 let parsed_args = (tool.parse_call)(&self.arguments)?; 113 Ok(ToolCall { 114 id: self.id.clone(), 115 typ: parsed_args, 116 }) 117 } 118 } 119 120 impl PartialToolCall { 121 pub fn complete(&self) -> Option<UnknownToolCall> { 122 Some(UnknownToolCall { 123 id: self.id.clone()?, 124 name: self.name.clone()?, 125 arguments: self.arguments.clone()?, 126 }) 127 } 128 } 129 130 #[derive(Debug, Clone, Serialize, Deserialize)] 131 pub struct InvalidToolCall { 132 pub error: String, 133 pub name: Option<String>, 134 pub arguments: Option<String>, 135 } 136 137 /// 
An enumeration of the possible tool calls that 138 /// can be parsed from Dave responses. When adding 139 /// new tools, this needs to be updated so that we can 140 /// handle tool call responses. 141 #[derive(Debug, Clone, Serialize, Deserialize)] 142 pub enum ToolCalls { 143 Query(QueryCall), 144 PresentNotes(PresentNotesCall), 145 Invalid(InvalidToolCall), 146 } 147 148 impl ToolCalls { 149 pub fn to_api(&self) -> FunctionCall { 150 FunctionCall { 151 name: self.name().to_owned(), 152 arguments: self.arguments(), 153 } 154 } 155 156 fn name(&self) -> &'static str { 157 match self { 158 Self::Query(_) => "search", 159 Self::Invalid(_) => "error", 160 Self::PresentNotes(_) => "present", 161 } 162 } 163 164 fn arguments(&self) -> String { 165 match self { 166 Self::Query(search) => serde_json::to_string(search).unwrap(), 167 Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), 168 Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), 169 } 170 } 171 } 172 173 #[derive(Debug)] 174 pub enum ToolCallError { 175 EmptyName, 176 EmptyArgs, 177 NotFound(String), 178 ArgParseFailure(String), 179 } 180 181 impl fmt::Display for ToolCallError { 182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 183 match self { 184 ToolCallError::EmptyName => write!(f, "the tool name was empty"), 185 ToolCallError::EmptyArgs => write!(f, "no arguments were provided"), 186 ToolCallError::NotFound(ref name) => write!(f, "tool '{name}' not found"), 187 ToolCallError::ArgParseFailure(ref msg) => { 188 write!(f, "failed to parse arguments: {msg}") 189 } 190 } 191 } 192 } 193 194 #[derive(Debug, Clone)] 195 enum ArgType { 196 String, 197 Number, 198 199 #[allow(dead_code)] 200 Enum(Vec<&'static str>), 201 } 202 203 impl ArgType { 204 pub fn type_string(&self) -> &'static str { 205 match self { 206 Self::String => "string", 207 Self::Number => "number", 208 Self::Enum(_) => "string", 209 } 210 } 211 } 212 213 #[derive(Debug, Clone)] 214 struct 
ToolArg { 215 typ: ArgType, 216 name: &'static str, 217 required: bool, 218 description: &'static str, 219 default: Option<Value>, 220 } 221 222 #[derive(Debug, Clone)] 223 pub struct Tool { 224 parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, 225 name: &'static str, 226 description: &'static str, 227 arguments: Vec<ToolArg>, 228 } 229 230 impl Tool { 231 pub fn name(&self) -> &'static str { 232 self.name 233 } 234 235 pub fn to_function_object(&self) -> FunctionObject { 236 let required_args = self 237 .arguments 238 .iter() 239 .filter_map(|arg| { 240 if arg.required { 241 Some(Value::String(arg.name.to_owned())) 242 } else { 243 None 244 } 245 }) 246 .collect(); 247 248 let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); 249 parameters.insert("type".to_string(), Value::String("object".to_string())); 250 parameters.insert("required".to_string(), Value::Array(required_args)); 251 parameters.insert("additionalProperties".to_string(), Value::Bool(false)); 252 253 let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); 254 255 for arg in &self.arguments { 256 let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); 257 props.insert( 258 "type".to_string(), 259 Value::String(arg.typ.type_string().to_string()), 260 ); 261 262 let description = if let Some(default) = &arg.default { 263 format!("{} (Default: {default}))", arg.description) 264 } else { 265 arg.description.to_owned() 266 }; 267 268 props.insert("description".to_string(), Value::String(description)); 269 if let ArgType::Enum(enums) = &arg.typ { 270 props.insert( 271 "enum".to_string(), 272 Value::Array( 273 enums 274 .iter() 275 .map(|s| Value::String((*s).to_owned())) 276 .collect(), 277 ), 278 ); 279 } 280 281 properties.insert(arg.name.to_owned(), Value::Object(props)); 282 } 283 284 parameters.insert("properties".to_string(), Value::Object(properties)); 285 286 FunctionObject { 287 name: self.name.to_owned(), 288 description: 
Some(self.description.to_owned()), 289 strict: Some(false), 290 parameters: Some(Value::Object(parameters)), 291 } 292 } 293 294 pub fn to_api(&self) -> ChatCompletionTool { 295 ChatCompletionTool { 296 r#type: ChatCompletionToolType::Function, 297 function: self.to_function_object(), 298 } 299 } 300 } 301 302 impl ToolResponses { 303 pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String { 304 format_tool_response_for_ai(txn, ndb, self) 305 } 306 } 307 308 #[derive(Debug, Clone, Serialize, Deserialize)] 309 pub struct ToolResponse { 310 id: String, 311 typ: ToolResponses, 312 } 313 314 impl ToolResponse { 315 pub fn new(id: String, responses: ToolResponses) -> Self { 316 Self { id, typ: responses } 317 } 318 319 pub fn error(id: String, msg: String) -> Self { 320 Self { 321 id, 322 typ: ToolResponses::Error(msg), 323 } 324 } 325 326 pub fn responses(&self) -> &ToolResponses { 327 &self.typ 328 } 329 330 pub fn id(&self) -> &str { 331 &self.id 332 } 333 } 334 335 /// Called by dave when he wants to display notes on the screen 336 #[derive(Debug, Deserialize, Serialize, Clone)] 337 pub struct PresentNotesCall { 338 pub note_ids: Vec<NoteId>, 339 } 340 341 impl PresentNotesCall { 342 fn to_simple(&self) -> PresentNotesCallSimple { 343 let note_ids = self 344 .note_ids 345 .iter() 346 .map(|nid| hex::encode(nid.bytes())) 347 .collect::<Vec<_>>() 348 .join(","); 349 350 PresentNotesCallSimple { note_ids } 351 } 352 } 353 354 /// Called by dave when he wants to display notes on the screen 355 #[derive(Debug, Deserialize, Serialize, Clone)] 356 pub struct PresentNotesCallSimple { 357 note_ids: String, 358 } 359 360 impl PresentNotesCall { 361 fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 362 match serde_json::from_str::<PresentNotesCallSimple>(args) { 363 Ok(call) => { 364 let note_ids = call 365 .note_ids 366 .split(",") 367 .filter_map(|n| NoteId::from_hex(n).ok()) 368 .collect(); 369 370 Ok(ToolCalls::PresentNotes(PresentNotesCall { 
note_ids })) 371 } 372 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 373 "{args}, error: {e}" 374 ))), 375 } 376 } 377 } 378 379 /// The parsed nostrdb query that dave wants to use to satisfy a request 380 #[derive(Debug, Deserialize, Serialize, Clone)] 381 pub struct QueryCall { 382 pub author: Option<Pubkey>, 383 pub limit: Option<u64>, 384 pub since: Option<u64>, 385 pub kind: Option<u64>, 386 pub until: Option<u64>, 387 pub search: Option<String>, 388 } 389 390 fn is_reply(note: Note) -> bool { 391 for tag in note.tags() { 392 if tag.count() < 4 { 393 continue; 394 } 395 396 let Some("e") = tag.get_str(0) else { 397 continue; 398 }; 399 400 let Some(s) = tag.get_str(3) else { 401 continue; 402 }; 403 404 if s == "root" || s == "reply" { 405 return true; 406 } 407 } 408 409 false 410 } 411 412 impl QueryCall { 413 pub fn to_filter(&self) -> nostrdb::Filter { 414 let mut filter = nostrdb::Filter::new() 415 .limit(self.limit()) 416 .custom(|n| !is_reply(n)) 417 .kinds([self.kind.unwrap_or(1)]); 418 419 if let Some(author) = &self.author { 420 filter = filter.authors([author.bytes()]); 421 } 422 423 if let Some(search) = &self.search { 424 filter = filter.search(search); 425 } 426 427 if let Some(until) = self.until { 428 filter = filter.until(until); 429 } 430 431 if let Some(since) = self.since { 432 filter = filter.since(since); 433 } 434 435 filter.build() 436 } 437 438 fn limit(&self) -> u64 { 439 self.limit.unwrap_or(10) 440 } 441 442 pub fn author(&self) -> Option<&Pubkey> { 443 self.author.as_ref() 444 } 445 446 pub fn since(&self) -> Option<u64> { 447 self.since 448 } 449 450 pub fn until(&self) -> Option<u64> { 451 self.until 452 } 453 454 pub fn search(&self) -> Option<&str> { 455 self.search.as_deref() 456 } 457 458 pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { 459 let notes = { 460 if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { 461 results.into_iter().map(|r| 
r.note_key.as_u64()).collect() 462 } else { 463 vec![] 464 } 465 }; 466 QueryResponse { notes } 467 } 468 469 pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { 470 match serde_json::from_str::<QueryCall>(args) { 471 Ok(call) => Ok(ToolCalls::Query(call)), 472 Err(e) => Err(ToolCallError::ArgParseFailure(format!( 473 "{args}, error: {e}" 474 ))), 475 } 476 } 477 } 478 479 /// A simple note format for use when formatting 480 /// tool responses 481 #[derive(Debug, Serialize)] 482 struct SimpleNote { 483 note_id: String, 484 pubkey: String, 485 name: String, 486 content: String, 487 created_at: String, 488 note_kind: u64, // todo: add replying to 489 } 490 491 /// Take the result of a tool response and present it to the ai so that 492 /// it can interepret it and take further action 493 fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { 494 match resp { 495 ToolResponses::PresentNotes(n) => format!("{n} notes presented to the user"), 496 ToolResponses::Error(s) => format!("error: {}", &s), 497 498 ToolResponses::Query(search_r) => { 499 let simple_notes: Vec<SimpleNote> = search_r 500 .notes 501 .iter() 502 .filter_map(|nkey| { 503 let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { 504 return None; 505 }; 506 507 let name = ndb 508 .get_profile_by_pubkey(txn, note.pubkey()) 509 .ok() 510 .and_then(|p| p.record().profile()) 511 .and_then(|p| p.name().or_else(|| p.display_name())) 512 .unwrap_or("Anonymous") 513 .to_string(); 514 515 let content = note.content().to_owned(); 516 let pubkey = hex::encode(note.pubkey()); 517 let note_kind = note.kind() as u64; 518 let note_id = hex::encode(note.id()); 519 520 let created_at = { 521 let datetime = 522 DateTime::from_timestamp(note.created_at() as i64, 0).unwrap(); 523 datetime.format("%Y-%m-%d %H:%M:%S").to_string() 524 }; 525 526 Some(SimpleNote { 527 note_id, 528 pubkey, 529 name, 530 content, 531 created_at, 532 note_kind, 533 }) 534 }) 535 
.collect(); 536 537 serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() 538 } 539 } 540 } 541 542 fn _note_kind_desc(kind: u64) -> String { 543 match kind { 544 1 => "microblog".to_string(), 545 0 => "profile".to_string(), 546 _ => kind.to_string(), 547 } 548 } 549 550 fn present_tool() -> Tool { 551 Tool { 552 name: "present_notes", 553 parse_call: PresentNotesCall::parse, 554 description: "A tool for presenting notes to the user for display. Should be called at the end of a response so that the UI can present the notes referred to in the previous message.", 555 arguments: vec![ 556 ToolArg { 557 name: "note_ids", 558 description: "A comma-separated list of hex note ids", 559 typ: ArgType::String, 560 required: true, 561 default: None 562 } 563 ] 564 } 565 } 566 567 fn query_tool() -> Tool { 568 Tool { 569 name: "query", 570 parse_call: QueryCall::parse, 571 description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", 572 arguments: vec![ 573 ToolArg { 574 name: "search", 575 typ: ArgType::String, 576 required: false, 577 default: None, 578 description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", 579 }, 580 581 ToolArg { 582 name: "limit", 583 typ: ArgType::Number, 584 required: true, 585 default: Some(Value::Number(serde_json::Number::from_i128(50).unwrap())), 586 description: "The number of results to return.", 587 }, 588 589 ToolArg { 590 name: "since", 591 typ: ArgType::Number, 592 required: false, 593 default: None, 594 description: "Only pull notes after this unix timestamp", 595 }, 596 597 ToolArg { 598 name: "until", 599 typ: ArgType::Number, 600 required: false, 601 default: None, 602 description: "Only pull notes up until this unix timestamp. 
Always include this when searching notes within some date range (yesterday, last week, etc).", 603 }, 604 605 ToolArg { 606 name: "author", 607 typ: ArgType::String, 608 required: false, 609 default: None, 610 description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to u 611 se, you can query for kind 0 profiles with the search argument.", 612 }, 613 614 ToolArg { 615 name: "kind", 616 typ: ArgType::Number, 617 required: false, 618 default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), 619 description: r#"The kind of note. Kind list: 620 - 0: profiles 621 - 1: microblogs/\"tweets\"/posts 622 - 6: reposts of kind 1 notes 623 - 7: emoji reactions/likes 624 - 9735: zaps (bitcoin micropayment receipts) 625 - 30023: longform articles, blog posts, etc 626 627 "#, 628 }, 629 630 ] 631 } 632 } 633 634 pub fn dave_tools() -> Vec<Tool> { 635 vec![query_tool(), present_tool()] 636 }