notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck

commit 52aedaf762f496225dc317d39f094db147b6379e
parent 5be02ab0139a498042b89f9d801b138c85f9aa55
Author: William Casarin <jb55@jb55.com>
Date:   Sun, 25 Jan 2026 15:37:47 -0800

dave: add pluggable AI backend abstraction with Claude support

Refactor Dave to use a pluggable backend architecture, enabling support
for multiple AI providers. This implements Phase 1 (Backend Abstraction)
and Phase 2 (Basic Claude Integration) of the agent SDK integration plan.
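
Concretely, Dave::new now picks and boxes a backend at construction time
(condensed from the new code in lib.rs; Dave::new also gained an Ndb handle
parameter, which the OpenAI backend uses to build its API messages):

    let backend: Box<dyn AiBackend> = match model_config.backend {
        BackendType::OpenAI => {
            let client = Client::with_config(model_config.to_api());
            Box::new(OpenAiBackend::new(client, ndb.clone()))
        }
        BackendType::Claude => {
            let api_key = model_config
                .anthropic_api_key
                .as_ref()
                .expect("Claude backend requires ANTHROPIC_API_KEY or CLAUDE_API_KEY");
            Box::new(ClaudeBackend::new(api_key.clone()))
        }
    };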

Backend Abstraction:
- Extract the OpenAI logic into an OpenAiBackend implementing the new AiBackend trait (sketched below)
- Create a clean abstraction boundary at the mpsc channel level
- Keep the existing streaming pattern and tool execution unchanged
- Add BackendType enum for backend selection
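
The abstraction boundary is intentionally small: both backends hand tokens
back over the same mpsc channel the UI already polls. From the new
backend/traits.rs:

    /// Trait for AI backend implementations
    pub trait AiBackend: Send + Sync {
        /// Stream a request; tokens and tool calls arrive on the returned receiver
        fn stream_request(
            &self,
            messages: Vec<crate::Message>,
            tools: Arc<HashMap<String, Tool>>,
            model: String,
            user_id: String,
            ctx: egui::Context,
        ) -> mpsc::Receiver<DaveApiResponse>;
    }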

Claude Integration:
- Add ClaudeBackend using claude-agent-sdk-rs (v0.6)
- Support text streaming from the Claude Code CLI
- Add environment variable support: ANTHROPIC_API_KEY, CLAUDE_API_KEY
- Add DAVE_BACKEND env var for explicit backend selection
- Auto-detect the backend from available API keys (selection sketch below)
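
Backend selection in config.rs, condensed (the real code also logs a warning
for unrecognized DAVE_BACKEND values):

    let anthropic_api_key = std::env::var("ANTHROPIC_API_KEY")
        .ok()
        .or(std::env::var("CLAUDE_API_KEY").ok());

    // Explicit DAVE_BACKEND takes precedence; otherwise auto-detect from keys
    let backend = if let Ok(s) = std::env::var("DAVE_BACKEND") {
        match s.to_lowercase().as_str() {
            "claude" | "anthropic" => BackendType::Claude,
            _ => BackendType::OpenAI, // "openai" and unknown values
        }
    } else if anthropic_api_key.is_some() {
        BackendType::Claude
    } else {
        BackendType::OpenAI
    };

So DAVE_BACKEND=claude forces the Claude backend even when an OpenAI key is
set, and setting only ANTHROPIC_API_KEY (or CLAUDE_API_KEY) selects it
automatically.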

Configuration:
- Default models: gpt-4o (OpenAI), claude-sonnet-4.5 (Claude); see the sketch after this list
- The OpenAI backend remains the default and keeps the built-in trial key
- Backend selection: DAVE_BACKEND=claude, or auto-detected from available API keys
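
Model and trial-key defaults, from the same Default impl in config.rs:

    // The bundled trial key is only used for the OpenAI backend
    let trial = api_key.is_none() && backend == BackendType::OpenAI;

    // DAVE_MODEL still overrides; otherwise pick a per-backend default
    let model = std::env::var("DAVE_MODEL")
        .ok()
        .unwrap_or_else(|| match backend {
            BackendType::OpenAI => "gpt-4o".to_string(),
            BackendType::Claude => "claude-sonnet-4.5".to_string(),
        });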

Tool support will be added in Phase 3.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Diffstat:
M Cargo.lock | 204 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------
M crates/notedeck_chrome/src/chrome.rs | 2 +-
M crates/notedeck_dave/Cargo.toml | 1 +
A crates/notedeck_dave/src/backend/claude.rs | 134 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A crates/notedeck_dave/src/backend/mod.rs | 7 +++++++
A crates/notedeck_dave/src/backend/openai.rs | 160 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A crates/notedeck_dave/src/backend/traits.rs | 27 +++++++++++++++++++++++++++
M crates/notedeck_dave/src/config.rs | 52 +++++++++++++++++++++++++++++++++++++++++++++++-----
M crates/notedeck_dave/src/lib.rs | 166 ++++++++++++++-----------------------------------------------------------------
M crates/notedeck_dave/src/tools.rs | 25 +++++++++++++++++++++++--
10 files changed, 607 insertions(+), 171 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock @@ -453,6 +453,28 @@ dependencies = [ ] [[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] name = "async-task" version = "4.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -651,7 +673,7 @@ dependencies = [ "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -674,7 +696,7 @@ dependencies = [ "bitflags 2.9.1", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "log", "prettyplease", "proc-macro2", @@ -1077,7 +1099,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -1109,6 +1131,30 @@ dependencies = [ ] [[package]] +name = "claude-agent-sdk-rs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47e08b9f18ec1810d0355942b4a2b76b777c4d368ac0d72825e1b45f07b8fe1" +dependencies = [ + "anyhow", + "async-stream", + "async-trait", + "dashmap", + "flume", + "futures", + "paste", + "path-absolutize", + "pin-project", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", + "tracing", + "typed-builder", + "uuid", +] + +[[package]] name = "clipboard-win" version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1359,6 +1405,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" [[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] name = "data-encoding" version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1997,6 +2057,9 @@ name = "fastrand" version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +dependencies = [ + "getrandom 0.2.16", +] [[package]] name = "fdeflate" @@ -2123,6 +2186,18 @@ dependencies = [ ] [[package]] +name = "flume" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e139bc46ca777eb5efaf62df0ab8cc5fd400866427e56c68b22e414e53bd3be" +dependencies = [ + "fastrand", + "futures-core", + "futures-sink", + "spin", +] + +[[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2571,6 +2646,12 @@ checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" [[package]] name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + +[[package]] +name = "hashbrown" version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" @@ -2780,7 +2861,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.5.10", "tokio", "tower-service", "tracing", @@ -3315,7 +3396,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.53.2", ] [[package]] @@ -4042,6 +4123,7 @@ dependencies = [ "async-openai", "bytemuck", "chrono", + "claude-agent-sdk-rs", "eframe", "egui", "egui-wgpu", @@ -4691,6 +4773,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] +name = "path-absolutize" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4af381fe79fa195b4909485d99f73a80792331df0625188e707854f0b3383f5" +dependencies = [ + "path-dedot", +] + +[[package]] +name = "path-dedot" +version = "3.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07ba0ad7e047712414213ff67533e6dd477af0a4e1d14fb52343e53d30ea9397" +dependencies = [ + "once_cell", +] + +[[package]] name = "pbkdf2" version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -5026,7 +5126,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls", - "socket2", + "socket2 0.5.10", "thiserror 2.0.12", "tokio", "tracing", @@ -5063,7 +5163,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.5.10", "tracing", "windows-sys 0.59.0", ] @@ -6109,6 +6209,25 @@ dependencies = [ ] [[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] name = "spirv" version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -6441,25 +6560,25 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.1" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" dependencies = [ - "backtrace", "bytes", "libc", "mio", "pin-project-lite", - "socket2", + "signal-hook-registry", + "socket2 0.6.2", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] name = "tokio-macros" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" dependencies = [ "proc-macro2", "quote", @@ -6710,7 +6829,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "319c70195101a93f56db4c74733e272d720768e13471f400c78406a326b172b0" dependencies = [ "cc", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -6775,6 +6894,26 @@ dependencies = [ ] [[package]] +name = "typed-builder" +version = "0.23.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31aa81521b70f94402501d848ccc0ecaa8f93c8eb6999eb9747e72287757ffda" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076a02dc54dd46795c2e9c8282ed40bcfb1e22747e955de9389a1de28190fb26" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + +[[package]] name = "typenum" version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -7006,9 +7145,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.17.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" dependencies = [ "getrandom 0.3.3", "js-sys", @@ -7492,7 +7631,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -7540,7 +7679,7 @@ dependencies = [ "windows-collections", "windows-core 0.61.2", "windows-future", - "windows-link", + "windows-link 0.1.3", "windows-numerics", ] @@ -7593,7 +7732,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement 0.60.0", "windows-interface 0.59.1", - "windows-link", + "windows-link 0.1.3", "windows-result 0.3.4", "windows-strings 0.4.2", ] @@ -7605,7 +7744,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" dependencies = [ "windows-core 0.61.2", - "windows-link", + "windows-link 0.1.3", "windows-threading", ] @@ -7660,13 +7799,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" [[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] name = "windows-numerics" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" dependencies = [ "windows-core 0.61.2", - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -7693,7 +7838,7 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -7712,7 +7857,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] @@ -7761,6 +7906,15 @@ dependencies = [ ] [[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + +[[package]] name = "windows-targets" version = "0.42.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" @@ -7828,7 +7982,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" dependencies = [ - "windows-link", + "windows-link 0.1.3", ] [[package]] diff --git a/crates/notedeck_chrome/src/chrome.rs b/crates/notedeck_chrome/src/chrome.rs @@ -152,7 +152,7 @@ impl Chrome { stop_debug_mode(notedeck.options()); let context = &mut notedeck.app_context(); - let dave = Dave::new(cc.wgpu_render_state.as_ref()); + let dave = Dave::new(cc.wgpu_render_state.as_ref(), context.ndb.clone()); let mut chrome = Chrome::default(); if !app_args.iter().any(|arg| arg == "--no-columns-app") { diff --git a/crates/notedeck_dave/Cargo.toml b/crates/notedeck_dave/Cargo.toml @@ -5,6 +5,7 @@ version.workspace = true [dependencies] async-openai = { version = "0.28.0", features = ["rustls-webpki-roots"] } +claude-agent-sdk-rs = "0.6" egui = { workspace = true } sha2 = { workspace = true } notedeck = { workspace = true } diff --git a/crates/notedeck_dave/src/backend/claude.rs b/crates/notedeck_dave/src/backend/claude.rs @@ -0,0 +1,134 @@ +use crate::backend::traits::AiBackend; +use crate::messages::DaveApiResponse; +use crate::tools::Tool; +use crate::Message; +use claude_agent_sdk_rs::{query_stream, ContentBlock, Message as ClaudeMessage, TextBlock}; +use futures::StreamExt; +use std::collections::HashMap; +use std::sync::mpsc; +use std::sync::Arc; + +pub struct ClaudeBackend { + api_key: String, +} + +impl ClaudeBackend { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + /// Convert our messages to a prompt for Claude Code + fn messages_to_prompt(messages: &[Message]) -> String { + let mut prompt = String::new(); + + // Include system message if present + for msg in messages { + if let Message::System(content) = msg { + prompt.push_str(content); + prompt.push_str("\n\n"); + break; + } + } + + // Format conversation history + for msg in messages { + match msg { + Message::System(_) => {} // Already handled + Message::User(content) => { + prompt.push_str("Human: "); + prompt.push_str(content); + prompt.push_str("\n\n"); + } + Message::Assistant(content) => { + prompt.push_str("Assistant: "); + prompt.push_str(content); + prompt.push_str("\n\n"); + } + Message::ToolCalls(_) | Message::ToolResponse(_) | Message::Error(_) => { + // Skip tool-related and error messages + } + } + } + + // Get the last user message as the actual query + if let Some(Message::User(user_msg)) = messages + .iter() + .rev() + .find(|m| matches!(m, Message::User(_))) + { + user_msg.clone() + } else { + prompt + } + } +} + +impl AiBackend for ClaudeBackend { + fn stream_request( + &self, + messages: Vec<Message>, + _tools: Arc<HashMap<String, Tool>>, + _model: String, + _user_id: String, + ctx: egui::Context, + ) -> mpsc::Receiver<DaveApiResponse> { + let (tx, rx) = mpsc::channel(); + let _api_key = self.api_key.clone(); + + tokio::spawn(async move { + let prompt = ClaudeBackend::messages_to_prompt(&messages); + + tracing::debug!( + "Sending request to Claude Code: prompt length: {}", + prompt.len() + ); + + let mut stream = match query_stream(prompt, None).await { + Ok(stream) => stream, + Err(err) => { + tracing::error!("Claude Code error: {}", err); + let _ = tx.send(DaveApiResponse::Failed(err.to_string())); + return; + } + }; + + while let Some(result) = stream.next().await { + match result { + Ok(message) => match message { + ClaudeMessage::Assistant(assistant_msg) 
=> { + for block in &assistant_msg.message.content { + if let ContentBlock::Text(TextBlock { text }) = block { + if let Err(err) = tx.send(DaveApiResponse::Token(text.clone())) + { + tracing::error!("Failed to send token to UI: {}", err); + return; + } + ctx.request_repaint(); + } + } + } + ClaudeMessage::Result(result_msg) => { + if result_msg.is_error { + let error_text = result_msg + .result + .unwrap_or_else(|| "Unknown error".to_string()); + let _ = tx.send(DaveApiResponse::Failed(error_text)); + } + break; + } + _ => {} + }, + Err(err) => { + tracing::error!("Claude stream error: {}", err); + let _ = tx.send(DaveApiResponse::Failed(err.to_string())); + return; + } + } + } + + tracing::debug!("Claude stream closed"); + }); + + rx + } +} diff --git a/crates/notedeck_dave/src/backend/mod.rs b/crates/notedeck_dave/src/backend/mod.rs @@ -0,0 +1,7 @@ +mod claude; +mod openai; +mod traits; + +pub use claude::ClaudeBackend; +pub use openai::OpenAiBackend; +pub use traits::{AiBackend, BackendType}; diff --git a/crates/notedeck_dave/src/backend/openai.rs b/crates/notedeck_dave/src/backend/openai.rs @@ -0,0 +1,160 @@ +use crate::backend::traits::AiBackend; +use crate::messages::DaveApiResponse; +use crate::tools::{PartialToolCall, Tool, ToolCall}; +use crate::Message; +use async_openai::{ + config::OpenAIConfig, + types::{ChatCompletionRequestMessage, CreateChatCompletionRequest}, + Client, +}; +use futures::StreamExt; +use nostrdb::{Ndb, Transaction}; +use std::collections::HashMap; +use std::sync::mpsc; +use std::sync::Arc; + +pub struct OpenAiBackend { + client: Client<OpenAIConfig>, + ndb: Ndb, +} + +impl OpenAiBackend { + pub fn new(client: Client<OpenAIConfig>, ndb: Ndb) -> Self { + Self { client, ndb } + } +} + +impl AiBackend for OpenAiBackend { + fn stream_request( + &self, + messages: Vec<Message>, + tools: Arc<HashMap<String, Tool>>, + model: String, + user_id: String, + ctx: egui::Context, + ) -> mpsc::Receiver<DaveApiResponse> { + let (tx, rx) = mpsc::channel(); + + let api_messages: Vec<ChatCompletionRequestMessage> = { + let txn = Transaction::new(&self.ndb).expect("txn"); + messages + .iter() + .filter_map(|c| c.to_api_msg(&txn, &self.ndb)) + .collect() + }; + + let client = self.client.clone(); + let tool_list: Vec<_> = tools.values().map(|t| t.to_api()).collect(); + + tokio::spawn(async move { + let mut token_stream = match client + .chat() + .create_stream(CreateChatCompletionRequest { + model, + stream: Some(true), + messages: api_messages, + tools: Some(tool_list), + user: Some(user_id), + ..Default::default() + }) + .await + { + Err(err) => { + tracing::error!("openai chat error: {err}"); + return; + } + + Ok(stream) => stream, + }; + + let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new(); + + while let Some(token) = token_stream.next().await { + let token = match token { + Ok(token) => token, + Err(err) => { + tracing::error!("failed to get token: {err}"); + let _ = tx.send(DaveApiResponse::Failed(err.to_string())); + return; + } + }; + + for choice in &token.choices { + let resp = &choice.delta; + + // if we have tool call arg chunks, collect them here + if let Some(tool_calls) = &resp.tool_calls { + for tool in tool_calls { + let entry = all_tool_calls.entry(tool.index).or_default(); + + if let Some(id) = &tool.id { + entry.id_mut().get_or_insert(id.clone()); + } + + if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref()) + { + entry.name_mut().get_or_insert(name.to_string()); + } + + if let Some(argchunk) = + 
tool.function.as_ref().and_then(|f| f.arguments.as_ref()) + { + entry + .arguments_mut() + .get_or_insert_with(String::new) + .push_str(argchunk); + } + } + } + + if let Some(content) = &resp.content { + if let Err(err) = tx.send(DaveApiResponse::Token(content.to_owned())) { + tracing::error!("failed to send dave response token to ui: {err}"); + } + ctx.request_repaint(); + } + } + } + + let mut parsed_tool_calls = vec![]; + for (_index, partial) in all_tool_calls { + let Some(unknown_tool_call) = partial.complete() else { + tracing::error!("could not complete partial tool call: {:?}", partial); + continue; + }; + + match unknown_tool_call.parse(&tools) { + Ok(tool_call) => { + parsed_tool_calls.push(tool_call); + } + Err(err) => { + tracing::error!( + "failed to parse tool call {:?}: {}", + unknown_tool_call, + err, + ); + + if let Some(id) = partial.id() { + parsed_tool_calls.push(ToolCall::invalid( + id.to_string(), + partial.name, + partial.arguments, + err.to_string(), + )); + } + } + }; + } + + if !parsed_tool_calls.is_empty() { + tx.send(DaveApiResponse::ToolCalls(parsed_tool_calls)) + .unwrap(); + ctx.request_repaint(); + } + + tracing::debug!("stream closed"); + }); + + rx + } +} diff --git a/crates/notedeck_dave/src/backend/traits.rs b/crates/notedeck_dave/src/backend/traits.rs @@ -0,0 +1,27 @@ +use crate::messages::DaveApiResponse; +use crate::tools::Tool; +use std::collections::HashMap; +use std::sync::mpsc; +use std::sync::Arc; + +/// Backend type selection +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BackendType { + OpenAI, + Claude, +} + +/// Trait for AI backend implementations +pub trait AiBackend: Send + Sync { + /// Stream a request to the AI backend + /// + /// Returns a receiver that will receive tokens and tool calls as they arrive + fn stream_request( + &self, + messages: Vec<crate::Message>, + tools: Arc<HashMap<String, Tool>>, + model: String, + user_id: String, + ctx: egui::Context, + ) -> mpsc::Receiver<DaveApiResponse>; +} diff --git a/crates/notedeck_dave/src/config.rs b/crates/notedeck_dave/src/config.rs @@ -1,3 +1,4 @@ +use crate::backend::BackendType; use async_openai::config::OpenAIConfig; /// Available AI providers for Dave @@ -124,9 +125,11 @@ impl DaveSettings { #[derive(Debug)] pub struct ModelConfig { pub trial: bool, + pub backend: BackendType, endpoint: Option<String>, model: String, api_key: Option<String>, + pub anthropic_api_key: Option<String>, } // short-term trial key for testing @@ -152,17 +155,54 @@ impl Default for ModelConfig { .ok() .or(std::env::var("OPENAI_API_KEY").ok()); + let anthropic_api_key = std::env::var("ANTHROPIC_API_KEY") + .ok() + .or(std::env::var("CLAUDE_API_KEY").ok()); + + // Determine backend: explicit env var takes precedence, otherwise auto-detect + let backend = if let Ok(backend_str) = std::env::var("DAVE_BACKEND") { + match backend_str.to_lowercase().as_str() { + "claude" | "anthropic" => BackendType::Claude, + "openai" => BackendType::OpenAI, + _ => { + tracing::warn!( + "Unknown DAVE_BACKEND value: {}, defaulting to OpenAI", + backend_str + ); + BackendType::OpenAI + } + } + } else { + // Auto-detect: prefer Claude if key is available, otherwise OpenAI + if anthropic_api_key.is_some() { + BackendType::Claude + } else { + BackendType::OpenAI + } + }; + // trial mode? 
- let trial = api_key.is_none(); - let api_key = api_key.or(Some(DAVE_TRIAL.to_string())); + let trial = api_key.is_none() && backend == BackendType::OpenAI; + let api_key = if backend == BackendType::OpenAI { + api_key.or(Some(DAVE_TRIAL.to_string())) + } else { + api_key + }; + + let model = std::env::var("DAVE_MODEL") + .ok() + .unwrap_or_else(|| match backend { + BackendType::OpenAI => "gpt-4o".to_string(), + BackendType::Claude => "claude-sonnet-4.5".to_string(), + }); ModelConfig { trial, + backend, endpoint: std::env::var("DAVE_ENDPOINT").ok(), - model: std::env::var("DAVE_MODEL") - .ok() - .unwrap_or("gpt-4o".to_string()), + model, api_key, + anthropic_api_key, } } } @@ -183,9 +223,11 @@ impl ModelConfig { pub fn ollama() -> Self { ModelConfig { trial: false, + backend: BackendType::OpenAI, // Ollama uses OpenAI-compatible API endpoint: std::env::var("OLLAMA_HOST").ok().map(|h| h + "/v1"), model: "hhao/qwen2.5-coder-tools:latest".to_string(), api_key: None, + anthropic_api_key: None, } } diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs @@ -1,4 +1,5 @@ mod avatar; +mod backend; mod config; pub(crate) mod mesh; mod messages; @@ -8,20 +9,14 @@ mod tools; mod ui; mod vec3; -use async_openai::{ - config::OpenAIConfig, - types::{ChatCompletionRequestMessage, CreateChatCompletionRequest}, - Client, -}; +use backend::{AiBackend, BackendType, ClaudeBackend, OpenAiBackend}; use chrono::{Duration, Local}; use egui_wgpu::RenderState; use enostr::KeypairUnowned; -use futures::StreamExt; use nostrdb::Transaction; use notedeck::{ui::is_narrow, AppAction, AppContext, AppResponse}; use std::collections::HashMap; use std::string::ToString; -use std::sync::mpsc; use std::sync::Arc; pub use avatar::DaveAvatar; @@ -46,8 +41,8 @@ pub struct Dave { avatar: Option<DaveAvatar>, /// Shared tools available to all sessions tools: Arc<HashMap<String, Tool>>, - /// Shared API client - client: async_openai::Client<OpenAIConfig>, + /// AI backend (OpenAI, Claude, etc.) + backend: Box<dyn AiBackend>, /// Model configuration model_config: ModelConfig, /// Whether to show session list on mobile @@ -100,10 +95,25 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr )) } - pub fn new(render_state: Option<&RenderState>) -> Self { + pub fn new(render_state: Option<&RenderState>, ndb: nostrdb::Ndb) -> Self { let model_config = ModelConfig::default(); //let model_config = ModelConfig::ollama(); - let client = Client::with_config(model_config.to_api()); + + // Create backend based on configuration + let backend: Box<dyn AiBackend> = match model_config.backend { + BackendType::OpenAI => { + use async_openai::Client; + let client = Client::with_config(model_config.to_api()); + Box::new(OpenAiBackend::new(client, ndb.clone())) + } + BackendType::Claude => { + let api_key = model_config + .anthropic_api_key + .as_ref() + .expect("Claude backend requires ANTHROPIC_API_KEY or CLAUDE_API_KEY"); + Box::new(ClaudeBackend::new(api_key.clone())) + } + }; let avatar = render_state.map(DaveAvatar::new); let mut tools: HashMap<String, Tool> = HashMap::new(); @@ -114,7 +124,7 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr let settings = DaveSettings::from_model_config(&model_config); Dave { - client, + backend, avatar, session_manager: SessionManager::new(), tools: Arc::new(tools), @@ -335,137 +345,17 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. 
nostr return; }; - let messages: Vec<ChatCompletionRequestMessage> = { - let txn = Transaction::new(app_ctx.ndb).expect("txn"); - session - .chat - .iter() - .filter_map(|c| c.to_api_msg(&txn, app_ctx.ndb)) - .collect() - }; - tracing::debug!("sending messages, latest: {:?}", messages.last().unwrap()); - let user_id = calculate_user_id(app_ctx.accounts.get_selected_account().keypair()); - - let ctx = ctx.clone(); - let client = self.client.clone(); + let messages = session.chat.clone(); let tools = self.tools.clone(); let model_name = self.model_config.model().to_owned(); + let ctx = ctx.clone(); - let (tx, rx) = mpsc::channel(); + // Use backend to stream request + let rx = self + .backend + .stream_request(messages, tools, model_name, user_id, ctx); session.incoming_tokens = Some(rx); - - tokio::spawn(async move { - let mut token_stream = match client - .chat() - .create_stream(CreateChatCompletionRequest { - model: model_name, - stream: Some(true), - messages, - tools: Some(tools::dave_tools().iter().map(|t| t.to_api()).collect()), - user: Some(user_id), - ..Default::default() - }) - .await - { - Err(err) => { - tracing::error!("openai chat error: {err}"); - return; - } - - Ok(stream) => stream, - }; - - let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new(); - - while let Some(token) = token_stream.next().await { - let token = match token { - Ok(token) => token, - Err(err) => { - tracing::error!("failed to get token: {err}"); - let _ = tx.send(DaveApiResponse::Failed(err.to_string())); - return; - } - }; - - for choice in &token.choices { - let resp = &choice.delta; - - // if we have tool call arg chunks, collect them here - if let Some(tool_calls) = &resp.tool_calls { - for tool in tool_calls { - let entry = all_tool_calls.entry(tool.index).or_default(); - - if let Some(id) = &tool.id { - entry.id_mut().get_or_insert(id.clone()); - } - - if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref()) - { - entry.name_mut().get_or_insert(name.to_string()); - } - - if let Some(argchunk) = - tool.function.as_ref().and_then(|f| f.arguments.as_ref()) - { - entry - .arguments_mut() - .get_or_insert_with(String::new) - .push_str(argchunk); - } - } - } - - if let Some(content) = &resp.content { - if let Err(err) = tx.send(DaveApiResponse::Token(content.to_owned())) { - tracing::error!("failed to send dave response token to ui: {err}"); - } - ctx.request_repaint(); - } - } - } - - let mut parsed_tool_calls = vec![]; - for (_index, partial) in all_tool_calls { - let Some(unknown_tool_call) = partial.complete() else { - tracing::error!("could not complete partial tool call: {:?}", partial); - continue; - }; - - match unknown_tool_call.parse(&tools) { - Ok(tool_call) => { - parsed_tool_calls.push(tool_call); - } - Err(err) => { - // TODO: we should be - tracing::error!( - "failed to parse tool call {:?}: {}", - unknown_tool_call, - err, - ); - - if let Some(id) = partial.id() { - // we have an id, so we can communicate the error - // back to the ai - parsed_tool_calls.push(ToolCall::invalid( - id.to_string(), - partial.name, - partial.arguments, - err.to_string(), - )); - } - } - }; - } - - if !parsed_tool_calls.is_empty() { - tx.send(DaveApiResponse::ToolCalls(parsed_tool_calls)) - .unwrap(); - ctx.request_repaint(); - } - - tracing::debug!("stream closed"); - }); } } diff --git a/crates/notedeck_dave/src/tools.rs b/crates/notedeck_dave/src/tools.rs @@ -14,6 +14,10 @@ pub struct ToolCall { } impl ToolCall { + pub fn new(id: String, typ: ToolCalls) -> Self { + Self { id, 
typ } + } + pub fn id(&self) -> &str { &self.id } @@ -86,7 +90,7 @@ impl PartialToolCall { /// The query response from nostrdb for a given context #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryResponse { - notes: Vec<u64>, + pub notes: Vec<u64>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -161,7 +165,16 @@ impl ToolCalls { } } - fn arguments(&self) -> String { + /// Returns the tool name as defined in the tool registry (for prompt-based tool calls) + pub fn tool_name(&self) -> &'static str { + match self { + Self::Query(_) => "query", + Self::Invalid(_) => "invalid", + Self::PresentNotes(_) => "present_notes", + } + } + + pub fn arguments(&self) -> String { match self { Self::Query(search) => serde_json::to_string(search).unwrap(), Self::Invalid(partial) => serde_json::to_string(partial).unwrap(), @@ -232,6 +245,14 @@ impl Tool { self.name } + pub fn description(&self) -> &'static str { + self.description + } + + pub fn parse_call(&self) -> fn(&str) -> Result<ToolCalls, ToolCallError> { + self.parse_call + } + pub fn to_function_object(&self) -> FunctionObject { let required_args = self .arguments