notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

commit 0b4807f62d0366c9b3f49b9154f03b46d6a60340
parent 4dfb013d6a9ac159fb74128965e8b364da6872d3
Author: William Casarin <jb55@jb55.com>
Date:   Tue, 25 Mar 2025 19:34:16 -0700

dave: tools working even better

Signed-off-by: William Casarin <jb55@jb55.com>

Diffstat:
MCargo.lock | 11+++++++----
Mcrates/notedeck_dave/Cargo.toml | 3+++
Mcrates/notedeck_dave/src/lib.rs | 409++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------
3 files changed, 335 insertions(+), 88 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock @@ -3302,10 +3302,13 @@ dependencies = [ "egui", "egui-wgpu", "futures", + "hex", + "nostrdb", "notedeck", "reqwest", "serde", "serde_json", + "time", "tokio", "tracing", ] @@ -5261,9 +5264,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.40" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d9c75b47bdff86fa3334a3db91356b8d7d86a9b839dab7d0bdc5c3d3a077618" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -5282,9 +5285,9 @@ checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.21" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29aa485584182073ed57fd5004aa09c371f021325014694e432313345865fd04" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", diff --git a/crates/notedeck_dave/Cargo.toml b/crates/notedeck_dave/Cargo.toml @@ -13,6 +13,9 @@ tracing = { workspace = true } egui-wgpu = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } +nostrdb = { workspace = true } +hex = { workspace = true } +time = "0.3.41" bytemuck = "1.22.0" futures = "0.3.31" reqwest = "0.12.15" diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs @@ -1,21 +1,25 @@ use async_openai::{ config::OpenAIConfig, types::{ - ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent, - ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, - ChatCompletionRequestSystemMessageContent, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, ChatCompletionTool, ChatCompletionToolType, - CreateChatCompletionRequest, FunctionObject, + ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, + ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, + ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FunctionCall, + FunctionObject, }, Client, }; use futures::StreamExt; +use nostrdb::{Ndb, NoteKey, Transaction}; use notedeck::AppContext; -use serde::Deserialize; -use serde_json::Value; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; use std::collections::HashMap; use std::sync::mpsc::{self, Receiver}; use std::sync::Arc; +use time::{format_description::well_known::Rfc3339, OffsetDateTime}; use avatar::DaveAvatar; use egui::{Rect, Vec2}; @@ -28,10 +32,114 @@ pub enum Message { User(String), Assistant(String), System(String), + ToolCalls(Vec<ToolCall>), + ToolResponse(ToolResponse), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResponse { + context: SearchContext, + notes: Vec<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolResponses { + Search(SearchResponse), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResponse { + id: String, + typ: ToolResponses, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + id: String, + typ: ToolCalls, +} + +#[derive(Default, Debug, Clone)] +pub struct PartialToolCall { + id: Option<String>, + name: Option<String>, + arguments: Option<String>, +} + +#[derive(Debug, Clone)] +pub struct UnknownToolCall { + id: String, + name: String, + arguments: String, +} + +impl UnknownToolCall { + pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { + let Some(tool) = tools.get(&self.name) else { + return Err(ToolCallError::NotFound(self.name.to_owned())); + }; + + let parsed_args = (tool.parse_call)(&self.arguments)?; + Ok(ToolCall { + id: self.id.clone(), + typ: parsed_args, + }) + } +} + +impl PartialToolCall { + pub fn complete(&self) -> Option<UnknownToolCall> { + Some(UnknownToolCall { + id: self.id.clone()?, + name: self.name.clone()?, + arguments: self.arguments.clone()?, + }) + } +} + +impl ToolCall { + pub fn to_api(&self) -> ChatCompletionMessageToolCall { + ChatCompletionMessageToolCall { + id: self.id.clone(), + r#type: ChatCompletionToolType::Function, + function: self.typ.to_api(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolCalls { + Search(SearchCall), +} + +impl ToolCalls { + pub fn to_api(&self) -> FunctionCall { + FunctionCall { + name: self.name().to_owned(), + arguments: self.arguments(), + } + } + + fn name(&self) -> &'static str { + match self { + Self::Search(_) => "search", + } + } + + fn arguments(&self) -> String { + match self { + Self::Search(search) => serde_json::to_string(search).unwrap(), + } + } +} + +pub enum DaveResponse { + ToolCalls(Vec<ToolCall>), + Token(String), } impl Message { - pub fn to_api_msg(&self) -> ChatCompletionRequestMessage { + pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage { match self { Message::User(msg) => { ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { @@ -55,11 +163,88 @@ impl Message { ..Default::default() }) } + + Message::ToolCalls(calls) => { + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + tool_calls: Some(calls.iter().map(|c| c.to_api()).collect()), + ..Default::default() + }) + } + + Message::ToolResponse(resp) => { + let tool_response = format_tool_response_for_ai(txn, ndb, &resp.typ); + + ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { + tool_call_id: resp.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text(tool_response), + }) + } + } + } +} + +#[derive(Debug, Serialize)] +struct SimpleNote { + pubkey: String, + name: String, + content: String, + created_at: String, + note_kind: String, // todo: add replying to +} + +fn note_kind_desc(kind: u64) -> String { + match kind { + 1 => "microblog".to_string(), + 0 => "profile".to_string(), + _ => kind.to_string(), + } +} + +/// Take the result of a tool response and present it to the ai so that +/// it can interepret it and take further action +fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { + match resp { + ToolResponses::Search(search_r) => { + let simple_notes: Vec<SimpleNote> = search_r + .notes + .iter() + .filter_map(|nkey| { + let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { + return None; + }; + + let name = ndb + .get_profile_by_pubkey(txn, note.pubkey()) + .ok() + .and_then(|p| p.record().profile()) + .and_then(|p| p.name().or_else(|| p.display_name())) + .unwrap_or("Anonymous") + .to_string(); + + let content = note.content().to_owned(); + let pubkey = hex::encode(note.pubkey()); + let note_kind = note_kind_desc(note.kind() as u64); + let created_at = OffsetDateTime::from_unix_timestamp(note.created_at() as i64) + .unwrap() + .format(&Rfc3339) + .unwrap(); + + Some(SimpleNote { + pubkey, + name, + content, + created_at, + note_kind, + }) + }) + .collect(); + + serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() } } } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] #[serde(rename_all = "lowercase")] pub enum SearchContext { Home, @@ -67,16 +252,35 @@ pub enum SearchContext { Any, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize, Clone)] pub struct SearchCall { context: SearchContext, query: String, } impl SearchCall { - pub fn parse(args: &str) -> Result<ToolCall, ToolCallError> { + pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> SearchResponse { + let limit = 10i32; + let filter = nostrdb::Filter::new() + .search(&self.query) + .limit(limit as u64) + .build(); + let notes = { + if let Ok(results) = ndb.query(&txn, &[filter], limit) { + results.into_iter().map(|r| r.note_key.as_u64()).collect() + } else { + vec![] + } + }; + SearchResponse { + context: self.context.clone(), + notes, + } + } + + pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { match serde_json::from_str::<SearchCall>(args) { - Ok(call) => Ok(ToolCall::Search(call)), + Ok(call) => Ok(ToolCalls::Search(call)), Err(e) => Err(ToolCallError::ArgParseFailure(format!( "Failed to parse args: '{}', error: {}", args, e @@ -85,16 +289,6 @@ impl SearchCall { } } -#[derive(Debug)] -pub enum ToolCall { - Search(SearchCall), -} - -pub enum DaveResponse { - ToolCall(ToolCall), - Token(String), -} - pub struct Dave { chat: Vec<Message>, /// A 3d representation of dave. @@ -116,13 +310,15 @@ impl Dave { let client = Client::with_config(config); let input = "".to_string(); - let pubkey = "test_pubkey".to_string(); + let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string(); let avatar = render_state.map(DaveAvatar::new); let mut tools: HashMap<String, Tool> = HashMap::new(); for tool in dave_tools() { tools.insert(tool.name.to_string(), tool); } + let system_prompt = Message::System(format!("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find and summarize content for users. The current user's pubkey is {}.", &pubkey).to_string()); + Dave { client, pubkey, @@ -130,13 +326,11 @@ impl Dave { incoming_tokens: None, tools: Arc::new(tools), input, - chat: vec![ - Message::System("You are an ai agent for the nostr protocol. You have access to tools that can query the network, so you can help find content for users (TODO: actually implement this)".to_string()), - ], + chat: vec![system_prompt], } } - fn render(&mut self, ui: &mut egui::Ui) { + fn render(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) { if let Some(recvr) = &self.incoming_tokens { while let Ok(res) = recvr.try_recv() { match res { @@ -146,8 +340,23 @@ impl Dave { None => {} }, - DaveResponse::ToolCall(tool) => { - tracing::info!("got tool call: {:?}", tool); + DaveResponse::ToolCalls(toolcalls) => { + tracing::info!("got tool calls: {:?}", toolcalls); + self.chat.push(Message::ToolCalls(toolcalls.clone())); + + let txn = Transaction::new(app_ctx.ndb).unwrap(); + for call in &toolcalls { + // execute toolcall + match &call.typ { + ToolCalls::Search(search_call) => { + let resp = search_call.execute(&txn, app_ctx.ndb); + self.chat.push(Message::ToolResponse(ToolResponse { + id: call.id.clone(), + typ: ToolResponses::Search(resp), + })) + } + } + } } } } @@ -162,7 +371,7 @@ impl Dave { ui.vertical(|ui| { self.render_chat(ui); - self.inputbox(ui); + self.inputbox(app_ctx, ui); }) }); }); @@ -180,20 +389,48 @@ impl Dave { match message { Message::User(msg) => self.user_chat(msg, ui), Message::Assistant(msg) => self.assistant_chat(msg, ui), + Message::ToolResponse(msg) => Self::tool_response_ui(msg, ui), Message::System(_msg) => { // system prompt is not rendered. Maybe we could // have a debug option to show this } + Message::ToolCalls(toolcalls) => { + Self::tool_call_ui(&toolcalls, ui); + } } } } - fn inputbox(&mut self, ui: &mut egui::Ui) { + fn tool_response_ui(tool_response: &ToolResponse, ui: &mut egui::Ui) { + ui.label(format!("tool_response: {:?}", tool_response)); + } + + fn tool_call_ui(toolcalls: &[ToolCall], ui: &mut egui::Ui) { + ui.vertical(|ui| { + for call in toolcalls { + match &call.typ { + ToolCalls::Search(search_call) => { + ui.horizontal(|ui| { + let context = match search_call.context { + SearchContext::Profile => "profile ", + SearchContext::Any => " ", + SearchContext::Home => "home ", + }; + + ui.label(format!("Searching {}for '{}'", context, search_call.query)); + }); + } + } + } + }); + } + + fn inputbox(&mut self, app_ctx: &AppContext, ui: &mut egui::Ui) { ui.horizontal(|ui| { ui.add(egui::TextEdit::multiline(&mut self.input)); if ui.button("Sned").clicked() { self.chat.push(Message::User(self.input.clone())); - self.send_user_message(ui.ctx()); + self.send_user_message(app_ctx, ui.ctx()); self.input.clear(); } }); @@ -217,14 +454,22 @@ impl Dave { }); } - fn send_user_message(&mut self, ctx: &egui::Context) { - let messages = self.chat.iter().map(|c| c.to_api_msg()).collect(); + fn send_user_message(&mut self, app_ctx: &AppContext, ctx: &egui::Context) { + let messages = { + let txn = Transaction::new(app_ctx.ndb).expect("txn"); + self.chat + .iter() + .map(|c| c.to_api_msg(&txn, app_ctx.ndb)) + .collect() + }; let pubkey = self.pubkey.clone(); - let (tx, rx) = mpsc::channel(); - self.incoming_tokens = Some(rx); let ctx = ctx.clone(); let client = self.client.clone(); let tools = self.tools.clone(); + + let (tx, rx) = mpsc::channel(); + self.incoming_tokens = Some(rx); + tokio::spawn(async move { let mut token_stream = match client .chat() @@ -247,8 +492,7 @@ impl Dave { Ok(stream) => stream, }; - let mut tool_call_name: Option<String> = None; - let mut tool_call_chunks: Vec<String> = vec![]; + let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new(); while let Some(token) = token_stream.next().await { let token = match token { @@ -265,19 +509,25 @@ impl Dave { // if we have tool call arg chunks, collect them here if let Some(tool_calls) = &resp.tool_calls { for tool in tool_calls { - let Some(fcall) = &tool.function else { - continue; - }; + let entry = all_tool_calls.entry(tool.index).or_default(); - if let Some(name) = &fcall.name { - tool_call_name = Some(name.clone()); + if let Some(id) = &tool.id { + entry.id.get_or_insert(id.to_string()); } - let Some(argchunk) = &fcall.arguments else { - continue; - }; + if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref()) + { + entry.name.get_or_insert(name.to_string()); + } - tool_call_chunks.push(argchunk.clone()); + if let Some(argchunk) = + tool.function.as_ref().and_then(|f| f.arguments.as_ref()) + { + entry + .arguments + .get_or_insert_with(String::new) + .push_str(&argchunk); + } } } @@ -288,28 +538,31 @@ impl Dave { } } - if let Some(tool_name) = tool_call_name { - if !tool_call_chunks.is_empty() { - let args = tool_call_chunks.join(""); - match parse_tool_call(&tools, &tool_name, &args) { - Ok(tool_call) => { - tx.send(DaveResponse::ToolCall(tool_call)).unwrap(); - ctx.request_repaint(); - } - Err(err) => { - tracing::error!( - "failed to parse tool call err({:?}): name({:?}) args({:?})", - err, - tool_name, - args, - ); - // TODO: return error to user - } - }; - } else { - // TODO: return error to user - tracing::error!("got tool call '{}' with no arguments?", tool_name); - } + let mut parsed_tool_calls = vec![]; + for (_index, partial) in &all_tool_calls { + let Some(unknown_tool_call) = partial.complete() else { + tracing::error!("could not complete partial tool call: {:?}", partial); + continue; + }; + + match unknown_tool_call.parse(&tools) { + Ok(tool_call) => { + parsed_tool_calls.push(tool_call); + } + Err(err) => { + tracing::error!( + "failed to parse tool call {:?}: {:?}", + unknown_tool_call, + err, + ); + // TODO: return error to user + } + }; + } + + if !parsed_tool_calls.is_empty() { + tx.send(DaveResponse::ToolCalls(parsed_tool_calls)).unwrap(); + ctx.request_repaint(); } tracing::debug!("stream closed"); @@ -318,7 +571,7 @@ impl Dave { } impl notedeck::App for Dave { - fn update(&mut self, _ctx: &mut AppContext<'_>, ui: &mut egui::Ui) { + fn update(&mut self, ctx: &mut AppContext<'_>, ui: &mut egui::Ui) { /* self.app .frame_history @@ -326,7 +579,7 @@ impl notedeck::App for Dave { */ //update_dave(self, ctx, ui.ctx()); - self.render(ui); + self.render(ctx, ui); } } @@ -357,7 +610,7 @@ struct ToolArg { #[derive(Debug, Clone)] pub struct Tool { - parse_call: fn(&str) -> Result<ToolCall, ToolCallError>, + parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, name: &'static str, description: &'static str, arguments: Vec<ToolArg>, @@ -458,18 +711,6 @@ pub enum ToolCallError { ArgParseFailure(String), } -fn parse_tool_call( - tools: &HashMap<String, Tool>, - name: &str, - args: &str, -) -> Result<ToolCall, ToolCallError> { - let Some(tool) = tools.get(name) else { - return Err(ToolCallError::NotFound(name.to_owned())); - }; - - (tool.parse_call)(&args) -} - fn dave_tools() -> Vec<Tool> { vec![search_tool()] }