openai.rs (6586B)
use crate::backend::traits::AiBackend;
use crate::messages::DaveApiResponse;
use crate::tools::{PartialToolCall, Tool, ToolCall};
use crate::Message;
use async_openai::{
    config::OpenAIConfig,
    types::{ChatCompletionRequestMessage, CreateChatCompletionRequest},
    Client,
};
use claude_agent_sdk_rs::PermissionMode;
use futures::StreamExt;
use nostrdb::{Ndb, Transaction};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::mpsc;
use std::sync::Arc;

/// [`AiBackend`] implementation backed by the OpenAI chat-completions API.
///
/// Holds a configured [`async_openai`] client plus a nostrdb handle, which is
/// used to convert stored [`Message`]s into API request messages inside a
/// read transaction.
pub struct OpenAiBackend {
    client: Client<OpenAIConfig>,
    ndb: Ndb,
}

impl OpenAiBackend {
    /// Creates a backend from a preconfigured OpenAI client and a nostrdb handle.
    pub fn new(client: Client<OpenAIConfig>, ndb: Ndb) -> Self {
        Self { client, ndb }
    }
}

impl AiBackend for OpenAiBackend {
    /// Starts a streaming chat-completion request on a spawned tokio task.
    ///
    /// Converts `messages` into API form, issues a streaming request for
    /// `model`, and forwards results to the UI over an mpsc channel:
    /// - text deltas are sent immediately as [`DaveApiResponse::Token`]
    ///   (one `ctx.request_repaint()` per token so the UI redraws);
    /// - tool-call deltas are accumulated per `tool.index` into
    ///   [`PartialToolCall`]s and sent as a single
    ///   [`DaveApiResponse::ToolCalls`] batch after the stream ends;
    /// - request/stream errors are reported as [`DaveApiResponse::Failed`]
    ///   and terminate the task early.
    ///
    /// The `_session_id`, `_cwd`, and `_resume_session_id` parameters are part
    /// of the trait interface but unused by this backend.
    ///
    /// Returns the receiving end of the channel and the task's `JoinHandle`
    /// (always `Some` for this backend).
    fn stream_request(
        &self,
        messages: Vec<Message>,
        tools: Arc<HashMap<String, Tool>>,
        model: String,
        user_id: String,
        _session_id: String,
        _cwd: Option<PathBuf>,
        _resume_session_id: Option<String>,
        ctx: egui::Context,
    ) -> (
        mpsc::Receiver<DaveApiResponse>,
        Option<tokio::task::JoinHandle<()>>,
    ) {
        let (tx, rx) = mpsc::channel();

        // Convert stored messages to API messages inside one ndb read
        // transaction. Messages that don't map to an API message are skipped
        // (filter_map). NOTE(review): `expect` panics if the transaction
        // can't be opened — presumably only on programmer error; confirm.
        let api_messages: Vec<ChatCompletionRequestMessage> = {
            let txn = Transaction::new(&self.ndb).expect("txn");
            messages
                .iter()
                .filter_map(|c| c.to_api_msg(&txn, &self.ndb))
                .collect()
        };

        // Clone what the spawned task needs to own ('static future).
        let client = self.client.clone();
        // NOTE(review): if `tools` is empty this still sends `Some(vec![])`
        // to the API — confirm callers always register at least one tool.
        let tool_list: Vec<_> = tools.values().map(|t| t.to_api()).collect();

        let handle = tokio::spawn(async move {
            let mut token_stream = match client
                .chat()
                .create_stream(CreateChatCompletionRequest {
                    model,
                    stream: Some(true),
                    messages: api_messages,
                    tools: Some(tool_list),
                    // Forwarded to OpenAI's `user` field (end-user identifier).
                    user: Some(user_id),
                    ..Default::default()
                })
                .await
            {
                Err(err) => {
                    // Request setup failed: surface the error to the UI and stop.
                    tracing::error!("openai chat error: {err}");
                    let _ = tx.send(DaveApiResponse::Failed(err.to_string()));
                    return;
                }

                Ok(stream) => stream,
            };

            // Tool-call fragments arrive incrementally across stream chunks;
            // they are keyed by the delta's `index` and merged here.
            let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new();

            while let Some(token) = token_stream.next().await {
                let token = match token {
                    Ok(token) => token,
                    Err(err) => {
                        // Mid-stream failure: report and abort; any tool calls
                        // collected so far are intentionally dropped.
                        tracing::error!("failed to get token: {err}");
                        let _ = tx.send(DaveApiResponse::Failed(err.to_string()));
                        return;
                    }
                };

                for choice in &token.choices {
                    let resp = &choice.delta;

                    // if we have tool call arg chunks, collect them here
                    if let Some(tool_calls) = &resp.tool_calls {
                        for tool in tool_calls {
                            let entry = all_tool_calls.entry(tool.index).or_default();

                            // id and name arrive once; keep the first value seen.
                            if let Some(id) = &tool.id {
                                entry.id_mut().get_or_insert(id.clone());
                            }

                            if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref())
                            {
                                entry.name_mut().get_or_insert(name.to_string());
                            }

                            // Argument JSON is chunked; append each fragment.
                            if let Some(argchunk) =
                                tool.function.as_ref().and_then(|f| f.arguments.as_ref())
                            {
                                entry
                                    .arguments_mut()
                                    .get_or_insert_with(String::new)
                                    .push_str(argchunk);
                            }
                        }
                    }

                    // Plain text delta: forward to the UI right away and
                    // request a repaint so it renders as it streams.
                    if let Some(content) = &resp.content {
                        if let Err(err) = tx.send(DaveApiResponse::Token(content.to_owned())) {
                            tracing::error!("failed to send dave response token to ui: {err}");
                        }
                        ctx.request_repaint();
                    }
                }
            }

            // Stream finished: finalize the accumulated partial tool calls.
            let mut parsed_tool_calls = vec![];
            for (_index, partial) in all_tool_calls {
                let Some(unknown_tool_call) = partial.complete() else {
                    tracing::error!("could not complete partial tool call: {:?}", partial);
                    continue;
                };

                match unknown_tool_call.parse(&tools) {
                    Ok(tool_call) => {
                        parsed_tool_calls.push(tool_call);
                    }
                    Err(err) => {
                        tracing::error!(
                            "failed to parse tool call {:?}: {}",
                            unknown_tool_call,
                            err,
                        );

                        // Keep unparseable calls as `invalid` when we have an
                        // id — presumably so an error result can be attributed
                        // back to the call; confirm against ToolCall::invalid
                        // consumers. Calls without an id are dropped (logged
                        // above).
                        if let Some(id) = partial.id() {
                            parsed_tool_calls.push(ToolCall::invalid(
                                id.to_string(),
                                partial.name,
                                partial.arguments,
                                err.to_string(),
                            ));
                        }
                    }
                };
            }

            // Send tool calls as one batch; repaint only if the UI side still
            // holds the receiver (send succeeded).
            if !parsed_tool_calls.is_empty()
                && tx
                    .send(DaveApiResponse::ToolCalls(parsed_tool_calls))
                    .is_ok()
            {
                ctx.request_repaint();
            }

            tracing::debug!("stream closed");
        });

        (rx, Some(handle))
    }

    fn cleanup_session(&self, _session_id: String) {
        // OpenAI backend doesn't maintain persistent connections per session
        // No cleanup needed
    }

    fn interrupt_session(&self, _session_id: String, _ctx: egui::Context) {
        // OpenAI backend doesn't support interrupts - requests complete atomically
        // The JoinHandle can be aborted from the session side if needed
    }

    fn set_permission_mode(&self, _session_id: String, _mode: PermissionMode, _ctx: egui::Context) {
        // OpenAI backend doesn't support permission modes / plan mode
        tracing::warn!("Plan mode is not supported with the OpenAI backend");
    }
}