notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

commit 56f5151739b910ee50d92729c7e5d1a481348c45
parent 9692b6b9fecb79a08f0c5d3a2d7abc823a02dbc5
Author: William Casarin <jb55@jb55.com>
Date:   Tue, 22 Apr 2025 16:04:54 -0700

dave: return tool errors back to the ai

So that it can correct itself

Diffstat:
Mcrates/notedeck_dave/src/lib.rs | 25+++++++++++++++++++++++--
Mcrates/notedeck_dave/src/messages.rs | 4++++
Mcrates/notedeck_dave/src/tools.rs | 67++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
Mcrates/notedeck_dave/src/ui/dave.rs | 3+++
4 files changed, 90 insertions(+), 9 deletions(-)

diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs @@ -9,6 +9,7 @@ use futures::StreamExt; use nostrdb::Transaction; use notedeck::AppContext; use std::collections::HashMap; +use std::string::ToString; use std::sync::mpsc::{self, Receiver}; use std::sync::Arc; @@ -137,6 +138,15 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr should_send = true; } + ToolCalls::Invalid(invalid) => { + should_send = true; + + self.chat.push(Message::tool_error( + call.id().to_string(), + invalid.error.clone(), + )); + } + ToolCalls::Query(search_call) => { should_send = true; @@ -270,12 +280,23 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr parsed_tool_calls.push(tool_call); } Err(err) => { + // TODO: we should be tracing::error!( - "failed to parse tool call {:?}: {:?}", + "failed to parse tool call {:?}: {}", unknown_tool_call, err, ); - // TODO: return error to user + + if let Some(id) = partial.id() { + // we have an id, so we can communicate the error + // back to the ai + parsed_tool_calls.push(ToolCall::invalid( + id.to_string(), + partial.name, + partial.arguments, + err.to_string(), + )); + } } }; } diff --git a/crates/notedeck_dave/src/messages.rs b/crates/notedeck_dave/src/messages.rs @@ -19,6 +19,10 @@ pub enum DaveApiResponse { } impl Message { + pub fn tool_error(id: String, msg: String) -> Self { + Self::ToolResponse(ToolResponse::error(id, msg)) + } + pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage { match self { Message::User(msg) => { diff --git a/crates/notedeck_dave/src/tools.rs b/crates/notedeck_dave/src/tools.rs @@ -4,7 +4,7 @@ use enostr::{NoteId, Pubkey}; use nostrdb::{Ndb, Note, NoteKey, Transaction}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use std::collections::HashMap; +use std::{collections::HashMap, fmt}; /// A tool #[derive(Debug, Clone, Serialize, Deserialize)] @@ -18,6 +18,22 @@ impl ToolCall { &self.id } + pub fn invalid( + id: String, + name: Option<String>, + arguments: Option<String>, + error: String, + ) -> Self { + Self { + id, + typ: ToolCalls::Invalid(InvalidToolCall { + name, + arguments, + error, + }), + } + } + pub fn calls(&self) -> &ToolCalls { &self.typ } @@ -34,11 +50,11 @@ impl ToolCall { /// On streaming APIs, tool calls are incremental. We use this /// to represent tool calls that are in the process of returning. /// These eventually just become [`ToolCall`]'s -#[derive(Default, Debug, Clone)] +#[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct PartialToolCall { - id: Option<String>, - name: Option<String>, - arguments: Option<String>, + pub id: Option<String>, + pub name: Option<String>, + pub arguments: Option<String>, } impl PartialToolCall { @@ -75,6 +91,7 @@ pub struct QueryResponse { #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ToolResponses { + Error(String), Query(QueryResponse), PresentNotes, } @@ -110,6 +127,13 @@ impl PartialToolCall { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InvalidToolCall { + pub error: String, + pub name: Option<String>, + pub arguments: Option<String>, +} + /// An enumeration of the possible tool calls that /// can be parsed from Dave responses. When adding /// new tools, this needs to be updated so that we can @@ -118,6 +142,12 @@ impl PartialToolCall { pub enum ToolCalls { Query(QueryCall), PresentNotes(PresentNotesCall), + Invalid(InvalidToolCall), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ErrorCall { + error: String, } impl ToolCalls { @@ -131,6 +161,7 @@ impl ToolCalls { fn name(&self) -> &'static str { match self { Self::Query(_) => "search", + Self::Invalid(_) => "error", Self::PresentNotes(_) => "present", } } @@ -138,6 +169,7 @@ impl ToolCalls { fn arguments(&self) -> String { match self { Self::Query(search) => serde_json::to_string(search).unwrap(), + Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(), } } @@ -151,6 +183,19 @@ pub enum ToolCallError { ArgParseFailure(String), } +impl fmt::Display for ToolCallError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ToolCallError::EmptyName => write!(f, "the tool name was empty"), + ToolCallError::EmptyArgs => write!(f, "no arguments were provided"), + ToolCallError::NotFound(ref name) => write!(f, "tool '{}' not found", name), + ToolCallError::ArgParseFailure(ref msg) => { + write!(f, "failed to parse arguments: {}", msg) + } + } + } +} + #[derive(Debug, Clone)] enum ArgType { String, @@ -276,6 +321,13 @@ impl ToolResponse { Self { id, typ: responses } } + pub fn error(id: String, msg: String) -> Self { + Self { + id, + typ: ToolResponses::Error(msg), + } + } + pub fn responses(&self) -> &ToolResponses { &self.typ } @@ -323,7 +375,7 @@ impl PresentNotesCall { Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids })) } Err(e) => Err(ToolCallError::ArgParseFailure(format!( - "Failed to parse args: '{}', error: {}", + "{}, error: {}", args, e ))), } @@ -424,7 +476,7 @@ impl QueryCall { match serde_json::from_str::<QueryCall>(args) { Ok(call) => Ok(ToolCalls::Query(call)), Err(e) => Err(ToolCallError::ArgParseFailure(format!( - "Failed to parse args: '{}', error: {}", + "{}, error: {}", args, e ))), } @@ -448,6 +500,7 @@ struct SimpleNote { fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { match resp { ToolResponses::PresentNotes => "".to_string(), + ToolResponses::Error(s) => format!("error: {}", &s), ToolResponses::Query(search_r) => { let simple_notes: Vec<SimpleNote> = search_r diff --git a/crates/notedeck_dave/src/ui/dave.rs b/crates/notedeck_dave/src/ui/dave.rs @@ -204,6 +204,9 @@ impl<'a> DaveUi<'a> { for call in toolcalls { match call.calls() { ToolCalls::PresentNotes(call) => Self::present_notes_ui(ctx, call, ui), + ToolCalls::Invalid(err) => { + ui.label(format!("invalid tool call: {:?}", err)); + } ToolCalls::Query(search_call) => { ui.allocate_ui_with_layout( egui::vec2(ui.available_size().x, 32.0),