commit 56f5151739b910ee50d92729c7e5d1a481348c45
parent 9692b6b9fecb79a08f0c5d3a2d7abc823a02dbc5
Author: William Casarin <jb55@jb55.com>
Date: Tue, 22 Apr 2025 16:04:54 -0700
dave: return tool errors back to the ai
So that it can correct itself
Diffstat:
4 files changed, 90 insertions(+), 9 deletions(-)
diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs
@@ -9,6 +9,7 @@ use futures::StreamExt;
use nostrdb::Transaction;
use notedeck::AppContext;
use std::collections::HashMap;
+use std::string::ToString;
use std::sync::mpsc::{self, Receiver};
use std::sync::Arc;
@@ -137,6 +138,15 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
should_send = true;
}
+ ToolCalls::Invalid(invalid) => {
+ should_send = true;
+
+ self.chat.push(Message::tool_error(
+ call.id().to_string(),
+ invalid.error.clone(),
+ ));
+ }
+
ToolCalls::Query(search_call) => {
should_send = true;
@@ -270,12 +280,23 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr
parsed_tool_calls.push(tool_call);
}
Err(err) => {
+ // TODO: we should be
tracing::error!(
- "failed to parse tool call {:?}: {:?}",
+ "failed to parse tool call {:?}: {}",
unknown_tool_call,
err,
);
- // TODO: return error to user
+
+ if let Some(id) = partial.id() {
+ // we have an id, so we can communicate the error
+ // back to the ai
+ parsed_tool_calls.push(ToolCall::invalid(
+ id.to_string(),
+ partial.name,
+ partial.arguments,
+ err.to_string(),
+ ));
+ }
}
};
}
diff --git a/crates/notedeck_dave/src/messages.rs b/crates/notedeck_dave/src/messages.rs
@@ -19,6 +19,10 @@ pub enum DaveApiResponse {
}
impl Message {
+ pub fn tool_error(id: String, msg: String) -> Self {
+ Self::ToolResponse(ToolResponse::error(id, msg))
+ }
+
pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage {
match self {
Message::User(msg) => {
diff --git a/crates/notedeck_dave/src/tools.rs b/crates/notedeck_dave/src/tools.rs
@@ -4,7 +4,7 @@ use enostr::{NoteId, Pubkey};
use nostrdb::{Ndb, Note, NoteKey, Transaction};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
-use std::collections::HashMap;
+use std::{collections::HashMap, fmt};
/// A tool
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -18,6 +18,22 @@ impl ToolCall {
&self.id
}
+ pub fn invalid(
+ id: String,
+ name: Option<String>,
+ arguments: Option<String>,
+ error: String,
+ ) -> Self {
+ Self {
+ id,
+ typ: ToolCalls::Invalid(InvalidToolCall {
+ name,
+ arguments,
+ error,
+ }),
+ }
+ }
+
pub fn calls(&self) -> &ToolCalls {
&self.typ
}
@@ -34,11 +50,11 @@ impl ToolCall {
/// On streaming APIs, tool calls are incremental. We use this
/// to represent tool calls that are in the process of returning.
/// These eventually just become [`ToolCall`]'s
-#[derive(Default, Debug, Clone)]
+#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct PartialToolCall {
- id: Option<String>,
- name: Option<String>,
- arguments: Option<String>,
+ pub id: Option<String>,
+ pub name: Option<String>,
+ pub arguments: Option<String>,
}
impl PartialToolCall {
@@ -75,6 +91,7 @@ pub struct QueryResponse {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolResponses {
+ Error(String),
Query(QueryResponse),
PresentNotes,
}
@@ -110,6 +127,13 @@ impl PartialToolCall {
}
}
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct InvalidToolCall {
+ pub error: String,
+ pub name: Option<String>,
+ pub arguments: Option<String>,
+}
+
/// An enumeration of the possible tool calls that
/// can be parsed from Dave responses. When adding
/// new tools, this needs to be updated so that we can
@@ -118,6 +142,12 @@ impl PartialToolCall {
pub enum ToolCalls {
Query(QueryCall),
PresentNotes(PresentNotesCall),
+ Invalid(InvalidToolCall),
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct ErrorCall {
+ error: String,
}
impl ToolCalls {
@@ -131,6 +161,7 @@ impl ToolCalls {
fn name(&self) -> &'static str {
match self {
Self::Query(_) => "search",
+ Self::Invalid(_) => "error",
Self::PresentNotes(_) => "present",
}
}
@@ -138,6 +169,7 @@ impl ToolCalls {
fn arguments(&self) -> String {
match self {
Self::Query(search) => serde_json::to_string(search).unwrap(),
+ Self::Invalid(partial) => serde_json::to_string(partial).unwrap(),
Self::PresentNotes(call) => serde_json::to_string(&call.to_simple()).unwrap(),
}
}
@@ -151,6 +183,19 @@ pub enum ToolCallError {
ArgParseFailure(String),
}
+impl fmt::Display for ToolCallError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ ToolCallError::EmptyName => write!(f, "the tool name was empty"),
+ ToolCallError::EmptyArgs => write!(f, "no arguments were provided"),
+ ToolCallError::NotFound(ref name) => write!(f, "tool '{}' not found", name),
+ ToolCallError::ArgParseFailure(ref msg) => {
+ write!(f, "failed to parse arguments: {}", msg)
+ }
+ }
+ }
+}
+
#[derive(Debug, Clone)]
enum ArgType {
String,
@@ -276,6 +321,13 @@ impl ToolResponse {
Self { id, typ: responses }
}
+ pub fn error(id: String, msg: String) -> Self {
+ Self {
+ id,
+ typ: ToolResponses::Error(msg),
+ }
+ }
+
pub fn responses(&self) -> &ToolResponses {
&self.typ
}
@@ -323,7 +375,7 @@ impl PresentNotesCall {
Ok(ToolCalls::PresentNotes(PresentNotesCall { note_ids }))
}
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
- "Failed to parse args: '{}', error: {}",
+ "{}, error: {}",
args, e
))),
}
@@ -424,7 +476,7 @@ impl QueryCall {
match serde_json::from_str::<QueryCall>(args) {
Ok(call) => Ok(ToolCalls::Query(call)),
Err(e) => Err(ToolCallError::ArgParseFailure(format!(
- "Failed to parse args: '{}', error: {}",
+ "{}, error: {}",
args, e
))),
}
@@ -448,6 +500,7 @@ struct SimpleNote {
fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String {
match resp {
ToolResponses::PresentNotes => "".to_string(),
+ ToolResponses::Error(s) => format!("error: {}", &s),
ToolResponses::Query(search_r) => {
let simple_notes: Vec<SimpleNote> = search_r
diff --git a/crates/notedeck_dave/src/ui/dave.rs b/crates/notedeck_dave/src/ui/dave.rs
@@ -204,6 +204,9 @@ impl<'a> DaveUi<'a> {
for call in toolcalls {
match call.calls() {
ToolCalls::PresentNotes(call) => Self::present_notes_ui(ctx, call, ui),
+ ToolCalls::Invalid(err) => {
+ ui.label(format!("invalid tool call: {:?}", err));
+ }
ToolCalls::Query(search_call) => {
ui.allocate_ui_with_layout(
egui::vec2(ui.available_size().x, 32.0),