notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

openai.rs (6586B)


      1 use crate::backend::traits::AiBackend;
      2 use crate::messages::DaveApiResponse;
      3 use crate::tools::{PartialToolCall, Tool, ToolCall};
      4 use crate::Message;
      5 use async_openai::{
      6     config::OpenAIConfig,
      7     types::{ChatCompletionRequestMessage, CreateChatCompletionRequest},
      8     Client,
      9 };
     10 use claude_agent_sdk_rs::PermissionMode;
     11 use futures::StreamExt;
     12 use nostrdb::{Ndb, Transaction};
     13 use std::collections::HashMap;
     14 use std::path::PathBuf;
     15 use std::sync::mpsc;
     16 use std::sync::Arc;
     17 
     18 pub struct OpenAiBackend {
     19     client: Client<OpenAIConfig>,
     20     ndb: Ndb,
     21 }
     22 
     23 impl OpenAiBackend {
     24     pub fn new(client: Client<OpenAIConfig>, ndb: Ndb) -> Self {
     25         Self { client, ndb }
     26     }
     27 }
     28 
     29 impl AiBackend for OpenAiBackend {
     30     fn stream_request(
     31         &self,
     32         messages: Vec<Message>,
     33         tools: Arc<HashMap<String, Tool>>,
     34         model: String,
     35         user_id: String,
     36         _session_id: String,
     37         _cwd: Option<PathBuf>,
     38         _resume_session_id: Option<String>,
     39         ctx: egui::Context,
     40     ) -> (
     41         mpsc::Receiver<DaveApiResponse>,
     42         Option<tokio::task::JoinHandle<()>>,
     43     ) {
     44         let (tx, rx) = mpsc::channel();
     45 
     46         let api_messages: Vec<ChatCompletionRequestMessage> = {
     47             let txn = Transaction::new(&self.ndb).expect("txn");
     48             messages
     49                 .iter()
     50                 .filter_map(|c| c.to_api_msg(&txn, &self.ndb))
     51                 .collect()
     52         };
     53 
     54         let client = self.client.clone();
     55         let tool_list: Vec<_> = tools.values().map(|t| t.to_api()).collect();
     56 
     57         let handle = tokio::spawn(async move {
     58             let mut token_stream = match client
     59                 .chat()
     60                 .create_stream(CreateChatCompletionRequest {
     61                     model,
     62                     stream: Some(true),
     63                     messages: api_messages,
     64                     tools: Some(tool_list),
     65                     user: Some(user_id),
     66                     ..Default::default()
     67                 })
     68                 .await
     69             {
     70                 Err(err) => {
     71                     tracing::error!("openai chat error: {err}");
     72                     let _ = tx.send(DaveApiResponse::Failed(err.to_string()));
     73                     return;
     74                 }
     75 
     76                 Ok(stream) => stream,
     77             };
     78 
     79             let mut all_tool_calls: HashMap<u32, PartialToolCall> = HashMap::new();
     80 
     81             while let Some(token) = token_stream.next().await {
     82                 let token = match token {
     83                     Ok(token) => token,
     84                     Err(err) => {
     85                         tracing::error!("failed to get token: {err}");
     86                         let _ = tx.send(DaveApiResponse::Failed(err.to_string()));
     87                         return;
     88                     }
     89                 };
     90 
     91                 for choice in &token.choices {
     92                     let resp = &choice.delta;
     93 
     94                     // if we have tool call arg chunks, collect them here
     95                     if let Some(tool_calls) = &resp.tool_calls {
     96                         for tool in tool_calls {
     97                             let entry = all_tool_calls.entry(tool.index).or_default();
     98 
     99                             if let Some(id) = &tool.id {
    100                                 entry.id_mut().get_or_insert(id.clone());
    101                             }
    102 
    103                             if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref())
    104                             {
    105                                 entry.name_mut().get_or_insert(name.to_string());
    106                             }
    107 
    108                             if let Some(argchunk) =
    109                                 tool.function.as_ref().and_then(|f| f.arguments.as_ref())
    110                             {
    111                                 entry
    112                                     .arguments_mut()
    113                                     .get_or_insert_with(String::new)
    114                                     .push_str(argchunk);
    115                             }
    116                         }
    117                     }
    118 
    119                     if let Some(content) = &resp.content {
    120                         if let Err(err) = tx.send(DaveApiResponse::Token(content.to_owned())) {
    121                             tracing::error!("failed to send dave response token to ui: {err}");
    122                         }
    123                         ctx.request_repaint();
    124                     }
    125                 }
    126             }
    127 
    128             let mut parsed_tool_calls = vec![];
    129             for (_index, partial) in all_tool_calls {
    130                 let Some(unknown_tool_call) = partial.complete() else {
    131                     tracing::error!("could not complete partial tool call: {:?}", partial);
    132                     continue;
    133                 };
    134 
    135                 match unknown_tool_call.parse(&tools) {
    136                     Ok(tool_call) => {
    137                         parsed_tool_calls.push(tool_call);
    138                     }
    139                     Err(err) => {
    140                         tracing::error!(
    141                             "failed to parse tool call {:?}: {}",
    142                             unknown_tool_call,
    143                             err,
    144                         );
    145 
    146                         if let Some(id) = partial.id() {
    147                             parsed_tool_calls.push(ToolCall::invalid(
    148                                 id.to_string(),
    149                                 partial.name,
    150                                 partial.arguments,
    151                                 err.to_string(),
    152                             ));
    153                         }
    154                     }
    155                 };
    156             }
    157 
    158             if !parsed_tool_calls.is_empty()
    159                 && tx
    160                     .send(DaveApiResponse::ToolCalls(parsed_tool_calls))
    161                     .is_ok()
    162             {
    163                 ctx.request_repaint();
    164             }
    165 
    166             tracing::debug!("stream closed");
    167         });
    168 
    169         (rx, Some(handle))
    170     }
    171 
    172     fn cleanup_session(&self, _session_id: String) {
    173         // OpenAI backend doesn't maintain persistent connections per session
    174         // No cleanup needed
    175     }
    176 
    177     fn interrupt_session(&self, _session_id: String, _ctx: egui::Context) {
    178         // OpenAI backend doesn't support interrupts - requests complete atomically
    179         // The JoinHandle can be aborted from the session side if needed
    180     }
    181 
    182     fn set_permission_mode(&self, _session_id: String, _mode: PermissionMode, _ctx: egui::Context) {
    183         // OpenAI backend doesn't support permission modes / plan mode
    184         tracing::warn!("Plan mode is not supported with the OpenAI backend");
    185     }
    186 }