commit 4dfb013d6a9ac159fb74128965e8b364da6872d3
parent 6e2c4cb695ea8f68a07502d5d4f80b26b0ed9aa7
Author: William Casarin <jb55@jb55.com>
Date: Tue, 25 Mar 2025 16:45:22 -0700
dave: toolcall parsing
Signed-off-by: William Casarin <jb55@jb55.com>
Diffstat:
3 files changed, 267 insertions(+), 18 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
@@ -3304,6 +3304,8 @@ dependencies = [
"futures",
"notedeck",
"reqwest",
+ "serde",
+ "serde_json",
"tokio",
"tracing",
]
diff --git a/crates/notedeck_dave/Cargo.toml b/crates/notedeck_dave/Cargo.toml
@@ -11,6 +11,8 @@ eframe = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
egui-wgpu = { workspace = true }
+serde_json = { workspace = true }
+serde = { workspace = true }
bytemuck = "1.22.0"
futures = "0.3.31"
reqwest = "0.12.15"
diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs
@@ -4,13 +4,18 @@ use async_openai::{
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent,
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage,
- ChatCompletionRequestUserMessageContent, CreateChatCompletionRequest,
+ ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType,
+ CreateChatCompletionRequest, FunctionObject,
},
Client,
};
use futures::StreamExt;
use notedeck::AppContext;
+use serde::Deserialize;
+use serde_json::Value;
+use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver};
+use std::sync::Arc;
use avatar::DaveAvatar;
use egui::{Rect, Vec2};
@@ -54,19 +59,56 @@ impl Message {
}
}
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum SearchContext {
+ Home,
+ Profile,
+ Any,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct SearchCall {
+ context: SearchContext,
+ query: String,
+}
+
+impl SearchCall {
+ pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> {
+ match serde_json::from_str::<SearchCall>(args) {
+ Ok(call) => Ok(ToolCall::Search(call)),
+ Err(e) => Err(ToolCallError::ArgParseFailure(format!(
+ "Failed to parse args: '{}', error: {}",
+ args, e
+ ))),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum ToolCall {
+ Search(SearchCall),
+}
+
+pub enum DaveResponse {
+ ToolCall(ToolCall),
+ Token(String),
+}
+
pub struct Dave {
chat: Vec<Message>,
/// A 3d representation of dave.
avatar: Option<DaveAvatar>,
input: String,
pubkey: String,
+ tools: Arc<HashMap<String, Tool>>,
client: async_openai::Client<OpenAIConfig>,
- incoming_tokens: Option<Receiver<String>>,
+ incoming_tokens: Option<Receiver<DaveResponse>>,
}
impl Dave {
pub fn new(render_state: Option<&RenderState>) -> Self {
- let mut config = OpenAIConfig::new().with_api_base("http://ollama.jb55.com/v1");
+ let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
config = config.with_api_key(api_key);
}
@@ -76,12 +118,17 @@ impl Dave {
let input = "".to_string();
let pubkey = "test_pubkey".to_string();
let avatar = render_state.map(DaveAvatar::new);
+ let mut tools: HashMap<String, Tool> = HashMap::new();
+ for tool in dave_tools() {
+ tools.insert(tool.name.to_string(), tool);
+ }
Dave {
client,
pubkey,
avatar,
incoming_tokens: None,
+ tools: Arc::new(tools),
input,
chat: vec![
Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()),
@@ -91,11 +138,17 @@ impl Dave {
fn render(&mut self, ui: &mut egui::Ui) {
if let Some(recvr) = &self.incoming_tokens {
- if let Ok(token) = recvr.try_recv() {
- match self.chat.last_mut() {
- Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
- Some(_) => self.chat.push(Message::Assistant(token)),
- None => {}
+ while let Ok(res) = recvr.try_recv() {
+ match res {
+ DaveResponse::Token(token) => match self.chat.last_mut() {
+ Some(Message::Assistant(msg)) => *msg = msg.clone() + &token,
+ Some(_) => self.chat.push(Message::Assistant(token)),
+ None => {}
+ },
+
+ DaveResponse::ToolCall(tool) => {
+ tracing::info!("got tool call: {:?}", tool);
+ }
}
}
}
@@ -171,14 +224,16 @@ impl Dave {
self.incoming_tokens = Some(rx);
let ctx = ctx.clone();
let client = self.client.clone();
+ let tools = self.tools.clone();
tokio::spawn(async move {
let mut token_stream = match client
.chat()
.create_stream(CreateChatCompletionRequest {
+ model: "gpt-4o".to_string(),
//model: "gpt-4o".to_string(),
- model: "llama3.1:latest".to_string(),
stream: Some(true),
messages,
+ tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
user: Some(pubkey),
..Default::default()
})
@@ -192,7 +247,8 @@ impl Dave {
Ok(stream) => stream,
};
- tracing::info!("got stream!");
+ let mut tool_call_name: Option<String> = None;
+ let mut tool_call_chunks: Vec<String> = vec![];
while let Some(token) = token_stream.next().await {
let token = match token {
@@ -202,16 +258,61 @@ impl Dave {
return;
}
};
- let Some(choice) = token.choices.first() else {
- return;
- };
- let Some(content) = &choice.delta.content else {
- return;
- };
- tx.send(content.to_owned()).unwrap();
- ctx.request_repaint();
+ for choice in &token.choices {
+ let resp = &choice.delta;
+
+ // if we have tool call arg chunks, collect them here
+ if let Some(tool_calls) = &resp.tool_calls {
+ for tool in tool_calls {
+ let Some(fcall) = &tool.function else {
+ continue;
+ };
+
+ if let Some(name) = &fcall.name {
+ tool_call_name = Some(name.clone());
+ }
+
+ let Some(argchunk) = &fcall.arguments else {
+ continue;
+ };
+
+ tool_call_chunks.push(argchunk.clone());
+ }
+ }
+
+ if let Some(content) = &resp.content {
+ tx.send(DaveResponse::Token(content.to_owned())).unwrap();
+ ctx.request_repaint();
+ }
+ }
+ }
+
+ if let Some(tool_name) = tool_call_name {
+ if !tool_call_chunks.is_empty() {
+ let args = tool_call_chunks.join("");
+ match parse_tool_call(&tools, &tool_name, &args) {
+ Ok(tool_call) => {
+ tx.send(DaveResponse::ToolCall(tool_call)).unwrap();
+ ctx.request_repaint();
+ }
+ Err(err) => {
+ tracing::error!(
+ "failed to parse tool call err({:?}): name({:?}) args({:?})",
+ err,
+ tool_name,
+ args,
+ );
+ // TODO: return error to user
+ }
+ };
+ } else {
+ // TODO: return error to user
+ tracing::error!("got tool call '{}' with no arguments?", tool_name);
+ }
}
+
+ tracing::debug!("stream closed");
});
}
}
@@ -228,3 +329,147 @@ impl notedeck::App for Dave {
self.render(ui);
}
}
+
+#[derive(Debug, Clone)]
+enum ArgType {
+ String,
+ Number,
+ Enum(Vec<&'static str>),
+}
+
+impl ArgType {
+ pub fn type_string(&self) -> &'static str {
+ match self {
+ Self::String => "string",
+ Self::Number => "number",
+ Self::Enum(_) => "string",
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+struct ToolArg {
+ typ: ArgType,
+ name: &'static str,
+ required: bool,
+ description: &'static str,
+}
+
+#[derive(Debug, Clone)]
+pub struct Tool {
+ parse_call: fn(&str) -> Result<ToolCall, ToolCallError>,
+ name: &'static str,
+ description: &'static str,
+ arguments: Vec<ToolArg>,
+}
+
+impl Tool {
+ pub fn to_function_object(&self) -> FunctionObject {
+ let required_args = self
+ .arguments
+ .iter()
+ .filter_map(|arg| {
+ if arg.required {
+ Some(Value::String(arg.name.to_owned()))
+ } else {
+ None
+ }
+ })
+ .collect();
+
+ let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new();
+ parameters.insert("type".to_string(), Value::String("object".to_string()));
+ parameters.insert("required".to_string(), Value::Array(required_args));
+ parameters.insert("additionalProperties".to_string(), Value::Bool(false));
+
+ let mut properties: serde_json::Map<String, Value> = serde_json::Map::new();
+
+ for arg in &self.arguments {
+ let mut props: serde_json::Map<String, Value> = serde_json::Map::new();
+ props.insert(
+ "type".to_string(),
+ Value::String(arg.typ.type_string().to_string()),
+ );
+ props.insert(
+ "description".to_string(),
+ Value::String(arg.description.to_owned()),
+ );
+ if let ArgType::Enum(enums) = &arg.typ {
+ props.insert(
+ "enum".to_string(),
+ Value::Array(
+ enums
+ .into_iter()
+ .map(|s| Value::String((*s).to_owned()))
+ .collect(),
+ ),
+ );
+ }
+
+ properties.insert(arg.name.to_owned(), Value::Object(props));
+ }
+
+ parameters.insert("properties".to_string(), Value::Object(properties));
+
+ FunctionObject {
+ name: self.name.to_owned(),
+ description: Some(self.description.to_owned()),
+ strict: Some(true),
+ parameters: Some(Value::Object(parameters)),
+ }
+ }
+
+ pub fn to_api(&self) -> ChatCompletionTool {
+ ChatCompletionTool {
+ r#type: ChatCompletionToolType::Function,
+ function: self.to_function_object(),
+ }
+ }
+}
+
+fn search_tool() -> Tool {
+ Tool {
+ name: "search",
+ parse_call: SearchCall::parse,
+ description: "Full-text search functionality. Used for finding individual notes with specific terms. Queries with multiple words will only return results with notes that have all of those words.",
+ arguments: vec![
+ ToolArg {
+ name: "query",
+ typ: ArgType::String,
+ required: true,
+ description: "The search query",
+ },
+
+ ToolArg {
+ name: "context",
+ typ: ArgType::Enum(vec!["home", "profile", "any"]),
+ required: true,
+ description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'",
+ }
+ ]
+ }
+}
+
+#[derive(Debug)]
+pub enum ToolCallError {
+ EmptyName,
+ EmptyArgs,
+ NotFound(String),
+ ArgParseFailure(String),
+}
+
+fn parse_tool_call(
+ tools: &HashMap<String, Tool>,
+ name: &str,
+ args: &str,
+) -> Result<ToolCall, ToolCallError> {
+ let Some(tool) = tools.get(name) else {
+ return Err(ToolCallError::NotFound(name.to_owned()));
+ };
+
+ (tool.parse_call)(&args)
+}
+
+fn dave_tools() -> Vec<Tool> {
+ vec![search_tool()]
+}