claude.rs (33153B)
1 use crate::auto_accept::AutoAcceptRules; 2 use crate::backend::session_info::parse_session_info; 3 use crate::backend::tool_summary::{ 4 extract_response_content, format_tool_summary, truncate_output, 5 }; 6 use crate::backend::traits::AiBackend; 7 use crate::messages::{ 8 CompactionInfo, DaveApiResponse, PendingPermission, PermissionRequest, PermissionResponse, 9 SubagentInfo, SubagentStatus, ToolResult, 10 }; 11 use crate::tools::Tool; 12 use crate::Message; 13 use claude_agent_sdk_rs::{ 14 ClaudeAgentOptions, ClaudeClient, ContentBlock, Message as ClaudeMessage, PermissionMode, 15 PermissionResult, PermissionResultAllow, PermissionResultDeny, ToolUseBlock, UserContentBlock, 16 }; 17 use dashmap::DashMap; 18 use futures::future::BoxFuture; 19 use futures::StreamExt; 20 use std::collections::HashMap; 21 use std::path::PathBuf; 22 use std::sync::mpsc; 23 use std::sync::Arc; 24 use tokio::sync::mpsc as tokio_mpsc; 25 use tokio::sync::oneshot; 26 use uuid::Uuid; 27 28 /// Commands sent to a session's actor task 29 enum SessionCommand { 30 Query { 31 prompt: String, 32 response_tx: mpsc::Sender<DaveApiResponse>, 33 ctx: egui::Context, 34 }, 35 /// Interrupt the current query - stops the stream but preserves session 36 Interrupt { 37 ctx: egui::Context, 38 }, 39 /// Set the permission mode (Default or Plan) 40 SetPermissionMode { 41 mode: PermissionMode, 42 ctx: egui::Context, 43 }, 44 Shutdown, 45 } 46 47 /// Handle to a session's actor 48 struct SessionHandle { 49 command_tx: tokio_mpsc::Sender<SessionCommand>, 50 } 51 52 pub struct ClaudeBackend { 53 #[allow(dead_code)] // May be used in the future for API key validation 54 api_key: String, 55 /// Registry of active sessions (using dashmap for lock-free access) 56 sessions: DashMap<String, SessionHandle>, 57 } 58 59 impl ClaudeBackend { 60 pub fn new(api_key: String) -> Self { 61 Self { 62 api_key, 63 sessions: DashMap::new(), 64 } 65 } 66 67 /// Convert our messages to a prompt for Claude Code 68 fn messages_to_prompt(messages: &[Message]) -> String { 69 let mut prompt = String::new(); 70 71 // Include system message if present 72 for msg in messages { 73 if let Message::System(content) = msg { 74 prompt.push_str(content); 75 prompt.push_str("\n\n"); 76 break; 77 } 78 } 79 80 // Format conversation history 81 for msg in messages { 82 match msg { 83 Message::System(_) => {} // Already handled 84 Message::User(content) => { 85 prompt.push_str("Human: "); 86 prompt.push_str(content); 87 prompt.push_str("\n\n"); 88 } 89 Message::Assistant(content) => { 90 prompt.push_str("Assistant: "); 91 prompt.push_str(content); 92 prompt.push_str("\n\n"); 93 } 94 Message::ToolCalls(_) 95 | Message::ToolResponse(_) 96 | Message::Error(_) 97 | Message::PermissionRequest(_) 98 | Message::ToolResult(_) 99 | Message::CompactionComplete(_) 100 | Message::Subagent(_) => { 101 // Skip tool-related, error, permission, tool result, compaction, and subagent messages 102 } 103 } 104 } 105 106 prompt 107 } 108 109 /// Extract only the latest user message for session continuation 110 fn get_latest_user_message(messages: &[Message]) -> String { 111 messages 112 .iter() 113 .rev() 114 .find_map(|m| match m { 115 Message::User(content) => Some(content.clone()), 116 _ => None, 117 }) 118 .unwrap_or_default() 119 } 120 } 121 122 /// Permission request forwarded from the callback to the actor 123 struct PermissionRequestInternal { 124 tool_name: String, 125 tool_input: serde_json::Value, 126 response_tx: oneshot::Sender<PermissionResult>, 127 } 128 129 /// Session actor task that owns a single ClaudeClient with persistent connection 130 async fn session_actor( 131 session_id: String, 132 cwd: Option<PathBuf>, 133 resume_session_id: Option<String>, 134 mut command_rx: tokio_mpsc::Receiver<SessionCommand>, 135 ) { 136 // Permission channel - the callback sends to perm_tx, actor receives on perm_rx 137 let (perm_tx, mut perm_rx) = tokio_mpsc::channel::<PermissionRequestInternal>(16); 138 139 // Create the can_use_tool callback that forwards to our permission channel 140 let can_use_tool: Arc< 141 dyn Fn( 142 String, 143 serde_json::Value, 144 claude_agent_sdk_rs::ToolPermissionContext, 145 ) -> BoxFuture<'static, PermissionResult> 146 + Send 147 + Sync, 148 > = Arc::new({ 149 let perm_tx = perm_tx.clone(); 150 move |tool_name: String, 151 tool_input: serde_json::Value, 152 _context: claude_agent_sdk_rs::ToolPermissionContext| { 153 let perm_tx = perm_tx.clone(); 154 Box::pin(async move { 155 let (resp_tx, resp_rx) = oneshot::channel(); 156 if perm_tx 157 .send(PermissionRequestInternal { 158 tool_name: tool_name.clone(), 159 tool_input, 160 response_tx: resp_tx, 161 }) 162 .await 163 .is_err() 164 { 165 return PermissionResult::Deny(PermissionResultDeny { 166 message: "Session actor channel closed".to_string(), 167 interrupt: true, 168 }); 169 } 170 // Wait for response from session actor (which forwards from UI) 171 match resp_rx.await { 172 Ok(result) => result, 173 Err(_) => PermissionResult::Deny(PermissionResultDeny { 174 message: "Permission response cancelled".to_string(), 175 interrupt: true, 176 }), 177 } 178 }) 179 } 180 }); 181 182 // A stderr callback to prevent the subprocess from blocking 183 let stderr_callback = Arc::new(|msg: String| { 184 tracing::trace!("Claude CLI stderr: {}", msg); 185 }); 186 187 // Log if we're resuming a session 188 if let Some(ref resume_id) = resume_session_id { 189 tracing::info!( 190 "Session {} will resume Claude session: {}", 191 session_id, 192 resume_id 193 ); 194 } 195 196 // Create client once - this maintains the persistent connection 197 // Using match to handle the TypedBuilder's strict type requirements 198 let options = match (&cwd, &resume_session_id) { 199 (Some(dir), Some(resume_id)) => ClaudeAgentOptions::builder() 200 .permission_mode(PermissionMode::Default) 201 .stderr_callback(stderr_callback) 202 .can_use_tool(can_use_tool) 203 .include_partial_messages(true) 204 .cwd(dir) 205 .resume(resume_id) 206 .build(), 207 (Some(dir), None) => ClaudeAgentOptions::builder() 208 .permission_mode(PermissionMode::Default) 209 .stderr_callback(stderr_callback) 210 .can_use_tool(can_use_tool) 211 .include_partial_messages(true) 212 .cwd(dir) 213 .build(), 214 (None, Some(resume_id)) => ClaudeAgentOptions::builder() 215 .permission_mode(PermissionMode::Default) 216 .stderr_callback(stderr_callback) 217 .can_use_tool(can_use_tool) 218 .include_partial_messages(true) 219 .resume(resume_id) 220 .build(), 221 (None, None) => ClaudeAgentOptions::builder() 222 .permission_mode(PermissionMode::Default) 223 .stderr_callback(stderr_callback) 224 .can_use_tool(can_use_tool) 225 .include_partial_messages(true) 226 .build(), 227 }; 228 let mut client = ClaudeClient::new(options); 229 230 // Connect once - this starts the subprocess 231 if let Err(err) = client.connect().await { 232 tracing::error!("Session {} failed to connect: {}", session_id, err); 233 // Process any pending commands to report the error 234 while let Some(cmd) = command_rx.recv().await { 235 if let SessionCommand::Query { 236 ref response_tx, .. 237 } = cmd 238 { 239 let _ = response_tx.send(DaveApiResponse::Failed(format!( 240 "Failed to connect to Claude: {}", 241 err 242 ))); 243 } 244 if matches!(cmd, SessionCommand::Shutdown) { 245 break; 246 } 247 } 248 return; 249 } 250 251 tracing::debug!("Session {} connected successfully", session_id); 252 253 // Process commands 254 while let Some(cmd) = command_rx.recv().await { 255 match cmd { 256 SessionCommand::Query { 257 prompt, 258 response_tx, 259 ctx, 260 } => { 261 // Send query using session_id for context 262 if let Err(err) = client.query_with_session(&prompt, &session_id).await { 263 tracing::error!("Session {} query error: {}", session_id, err); 264 let _ = response_tx.send(DaveApiResponse::Failed(err.to_string())); 265 continue; 266 } 267 268 // Track pending tool uses: tool_use_id -> (tool_name, tool_input) 269 let mut pending_tools: HashMap<String, (String, serde_json::Value)> = 270 HashMap::new(); 271 272 // Stream response with select! to handle stream, permission requests, and interrupts 273 let mut stream = client.receive_response(); 274 let mut stream_done = false; 275 276 while !stream_done { 277 tokio::select! { 278 biased; 279 280 // Check for interrupt command (highest priority) 281 Some(cmd) = command_rx.recv() => { 282 match cmd { 283 SessionCommand::Interrupt { ctx: interrupt_ctx } => { 284 tracing::debug!("Session {} received interrupt", session_id); 285 if let Err(err) = client.interrupt().await { 286 tracing::error!("Failed to send interrupt: {}", err); 287 } 288 // Let the stream end naturally - it will send a Result message 289 // The session history is preserved by the CLI 290 interrupt_ctx.request_repaint(); 291 } 292 SessionCommand::Query { response_tx: new_tx, .. } => { 293 // A new query came in while we're still streaming - shouldn't happen 294 // but handle gracefully by rejecting it 295 let _ = new_tx.send(DaveApiResponse::Failed( 296 "Query already in progress".to_string() 297 )); 298 } 299 SessionCommand::SetPermissionMode { mode, ctx: mode_ctx } => { 300 // Permission mode change during query - apply it 301 tracing::debug!("Session {} setting permission mode to {:?} during query", session_id, mode); 302 if let Err(err) = client.set_permission_mode(mode).await { 303 tracing::error!("Failed to set permission mode: {}", err); 304 } 305 mode_ctx.request_repaint(); 306 } 307 SessionCommand::Shutdown => { 308 tracing::debug!("Session actor {} shutting down during query", session_id); 309 // Drop stream and disconnect - break to exit loop first 310 drop(stream); 311 if let Err(err) = client.disconnect().await { 312 tracing::warn!("Error disconnecting session {}: {}", session_id, err); 313 } 314 tracing::debug!("Session {} actor exited", session_id); 315 return; 316 } 317 } 318 } 319 320 // Handle permission requests (they're blocking the SDK) 321 Some(perm_req) = perm_rx.recv() => { 322 // Check auto-accept rules 323 let auto_accept_rules = AutoAcceptRules::default(); 324 if auto_accept_rules.should_auto_accept(&perm_req.tool_name, &perm_req.tool_input) { 325 tracing::debug!("Auto-accepting {}: matched auto-accept rule", perm_req.tool_name); 326 let _ = perm_req.response_tx.send(PermissionResult::Allow(PermissionResultAllow::default())); 327 continue; 328 } 329 330 // Forward permission request to UI 331 let request_id = Uuid::new_v4(); 332 let (ui_resp_tx, ui_resp_rx) = oneshot::channel(); 333 334 let request = PermissionRequest { 335 id: request_id, 336 tool_name: perm_req.tool_name.clone(), 337 tool_input: perm_req.tool_input.clone(), 338 response: None, 339 answer_summary: None, 340 }; 341 342 let pending = PendingPermission { 343 request, 344 response_tx: ui_resp_tx, 345 }; 346 347 if response_tx.send(DaveApiResponse::PermissionRequest(pending)).is_err() { 348 tracing::error!("Failed to send permission request to UI"); 349 let _ = perm_req.response_tx.send(PermissionResult::Deny(PermissionResultDeny { 350 message: "UI channel closed".to_string(), 351 interrupt: true, 352 })); 353 continue; 354 } 355 356 ctx.request_repaint(); 357 358 // Wait for UI response inline - blocking is OK since stream is 359 // waiting for permission result anyway 360 let tool_name = perm_req.tool_name.clone(); 361 let result = match ui_resp_rx.await { 362 Ok(PermissionResponse::Allow { message }) => { 363 if let Some(msg) = &message { 364 tracing::debug!("User allowed tool {} with message: {}", tool_name, msg); 365 // Inject user message into conversation so AI sees it 366 if let Err(err) = client.query_with_content_and_session( 367 vec![UserContentBlock::text(msg.as_str())], 368 &session_id 369 ).await { 370 tracing::error!("Failed to inject user message: {}", err); 371 } 372 } else { 373 tracing::debug!("User allowed tool: {}", tool_name); 374 } 375 PermissionResult::Allow(PermissionResultAllow::default()) 376 } 377 Ok(PermissionResponse::Deny { reason }) => { 378 tracing::debug!("User denied tool {}: {}", tool_name, reason); 379 PermissionResult::Deny(PermissionResultDeny { 380 message: reason, 381 interrupt: false, 382 }) 383 } 384 Err(_) => { 385 tracing::error!("Permission response channel closed"); 386 PermissionResult::Deny(PermissionResultDeny { 387 message: "Permission request cancelled".to_string(), 388 interrupt: true, 389 }) 390 } 391 }; 392 let _ = perm_req.response_tx.send(result); 393 } 394 395 stream_result = stream.next() => { 396 match stream_result { 397 Some(Ok(message)) => { 398 match message { 399 ClaudeMessage::Assistant(assistant_msg) => { 400 for block in &assistant_msg.message.content { 401 if let ContentBlock::ToolUse(ToolUseBlock { id, name, input }) = block { 402 pending_tools.insert(id.clone(), (name.clone(), input.clone())); 403 404 // Emit SubagentSpawned for Task tool calls 405 if name == "Task" { 406 let description = input 407 .get("description") 408 .and_then(|v| v.as_str()) 409 .unwrap_or("task") 410 .to_string(); 411 let subagent_type = input 412 .get("subagent_type") 413 .and_then(|v| v.as_str()) 414 .unwrap_or("unknown") 415 .to_string(); 416 417 let subagent_info = SubagentInfo { 418 task_id: id.clone(), 419 description, 420 subagent_type, 421 status: SubagentStatus::Running, 422 output: String::new(), 423 max_output_size: 4000, 424 }; 425 let _ = response_tx.send(DaveApiResponse::SubagentSpawned(subagent_info)); 426 ctx.request_repaint(); 427 } 428 } 429 } 430 } 431 ClaudeMessage::StreamEvent(event) => { 432 if let Some(event_type) = event.event.get("type").and_then(|v| v.as_str()) { 433 if event_type == "content_block_delta" { 434 if let Some(text) = event 435 .event 436 .get("delta") 437 .and_then(|d| d.get("text")) 438 .and_then(|t| t.as_str()) 439 { 440 if response_tx.send(DaveApiResponse::Token(text.to_string())).is_err() { 441 tracing::error!("Failed to send token to UI"); 442 // Setting stream_done isn't needed since we break immediately 443 break; 444 } 445 ctx.request_repaint(); 446 } 447 } 448 } 449 } 450 ClaudeMessage::Result(result_msg) => { 451 if result_msg.is_error { 452 let error_text = result_msg 453 .result 454 .unwrap_or_else(|| "Unknown error".to_string()); 455 let _ = response_tx.send(DaveApiResponse::Failed(error_text)); 456 } 457 stream_done = true; 458 } 459 ClaudeMessage::User(user_msg) => { 460 if let Some(tool_use_result) = user_msg.extra.get("tool_use_result") { 461 let tool_use_id = user_msg 462 .extra 463 .get("message") 464 .and_then(|m| m.get("content")) 465 .and_then(|c| c.as_array()) 466 .and_then(|arr| arr.first()) 467 .and_then(|item| item.get("tool_use_id")) 468 .and_then(|id| id.as_str()); 469 470 if let Some(tool_use_id) = tool_use_id { 471 if let Some((tool_name, tool_input)) = pending_tools.remove(tool_use_id) { 472 // Check if this is a Task tool completion 473 if tool_name == "Task" { 474 let result_text = extract_response_content(tool_use_result) 475 .unwrap_or_else(|| "completed".to_string()); 476 let _ = response_tx.send(DaveApiResponse::SubagentCompleted { 477 task_id: tool_use_id.to_string(), 478 result: truncate_output(&result_text, 2000), 479 }); 480 } 481 482 let summary = format_tool_summary(&tool_name, &tool_input, tool_use_result); 483 let tool_result = ToolResult { tool_name, summary }; 484 let _ = response_tx.send(DaveApiResponse::ToolResult(tool_result)); 485 ctx.request_repaint(); 486 } 487 } 488 } 489 } 490 ClaudeMessage::System(system_msg) => { 491 // Handle system init message - extract session info 492 if system_msg.subtype == "init" { 493 let session_info = parse_session_info(&system_msg); 494 let _ = response_tx.send(DaveApiResponse::SessionInfo(session_info)); 495 ctx.request_repaint(); 496 } else if system_msg.subtype == "status" { 497 // Handle status messages (compaction start/end) 498 let status = system_msg.data.get("status") 499 .and_then(|v| v.as_str()); 500 if status == Some("compacting") { 501 let _ = response_tx.send(DaveApiResponse::CompactionStarted); 502 ctx.request_repaint(); 503 } 504 // status: null means compaction finished (handled by compact_boundary) 505 } else if system_msg.subtype == "compact_boundary" { 506 // Compaction completed - extract token savings info 507 let pre_tokens = system_msg.data.get("pre_tokens") 508 .and_then(|v| v.as_u64()) 509 .unwrap_or(0); 510 let info = CompactionInfo { pre_tokens }; 511 let _ = response_tx.send(DaveApiResponse::CompactionComplete(info)); 512 ctx.request_repaint(); 513 } else { 514 tracing::debug!("Received system message subtype: {}", system_msg.subtype); 515 } 516 } 517 ClaudeMessage::ControlCancelRequest(_) => { 518 // Ignore internal control messages 519 } 520 } 521 } 522 Some(Err(err)) => { 523 tracing::error!("Claude stream error: {}", err); 524 let _ = response_tx.send(DaveApiResponse::Failed(err.to_string())); 525 stream_done = true; 526 } 527 None => { 528 stream_done = true; 529 } 530 } 531 } 532 } 533 } 534 535 tracing::debug!("Query complete for session {}", session_id); 536 // Don't disconnect - keep the connection alive for subsequent queries 537 } 538 SessionCommand::Interrupt { ctx } => { 539 // Interrupt received when not in a query - just request repaint 540 tracing::debug!( 541 "Session {} received interrupt but no query active", 542 session_id 543 ); 544 ctx.request_repaint(); 545 } 546 SessionCommand::SetPermissionMode { mode, ctx } => { 547 tracing::debug!( 548 "Session {} setting permission mode to {:?}", 549 session_id, 550 mode 551 ); 552 if let Err(err) = client.set_permission_mode(mode).await { 553 tracing::error!("Failed to set permission mode: {}", err); 554 } 555 ctx.request_repaint(); 556 } 557 SessionCommand::Shutdown => { 558 tracing::debug!("Session actor {} shutting down", session_id); 559 break; 560 } 561 } 562 } 563 564 // Disconnect when shutting down 565 if let Err(err) = client.disconnect().await { 566 tracing::warn!("Error disconnecting session {}: {}", session_id, err); 567 } 568 tracing::debug!("Session {} actor exited", session_id); 569 } 570 571 impl AiBackend for ClaudeBackend { 572 fn stream_request( 573 &self, 574 messages: Vec<Message>, 575 _tools: Arc<HashMap<String, Tool>>, 576 _model: String, 577 _user_id: String, 578 session_id: String, 579 cwd: Option<PathBuf>, 580 resume_session_id: Option<String>, 581 ctx: egui::Context, 582 ) -> ( 583 mpsc::Receiver<DaveApiResponse>, 584 Option<tokio::task::JoinHandle<()>>, 585 ) { 586 let (response_tx, response_rx) = mpsc::channel(); 587 588 // Determine if this is the first message in the session 589 let is_first_message = messages 590 .iter() 591 .filter(|m| matches!(m, Message::User(_))) 592 .count() 593 == 1; 594 595 // For first message, send full prompt; for continuation, just the latest message 596 let prompt = if is_first_message { 597 Self::messages_to_prompt(&messages) 598 } else { 599 Self::get_latest_user_message(&messages) 600 }; 601 602 tracing::debug!( 603 "Sending request to Claude Code: session={}, is_first={}, prompt length: {}, preview: {:?}", 604 session_id, 605 is_first_message, 606 prompt.len(), 607 &prompt[..prompt.len().min(100)] 608 ); 609 610 // Get or create session actor 611 let command_tx = { 612 let entry = self.sessions.entry(session_id.clone()); 613 let handle = entry.or_insert_with(|| { 614 let (command_tx, command_rx) = tokio_mpsc::channel(16); 615 616 // Spawn session actor with cwd and optional resume session ID 617 let session_id_clone = session_id.clone(); 618 let cwd_clone = cwd.clone(); 619 let resume_session_id_clone = resume_session_id.clone(); 620 tokio::spawn(async move { 621 session_actor( 622 session_id_clone, 623 cwd_clone, 624 resume_session_id_clone, 625 command_rx, 626 ) 627 .await; 628 }); 629 630 SessionHandle { command_tx } 631 }); 632 handle.command_tx.clone() 633 }; 634 635 // Spawn a task to send the query command 636 let handle = tokio::spawn(async move { 637 if let Err(err) = command_tx 638 .send(SessionCommand::Query { 639 prompt, 640 response_tx, 641 ctx, 642 }) 643 .await 644 { 645 tracing::error!("Failed to send query command to session actor: {}", err); 646 } 647 }); 648 649 (response_rx, Some(handle)) 650 } 651 652 fn cleanup_session(&self, session_id: String) { 653 if let Some((_, handle)) = self.sessions.remove(&session_id) { 654 tokio::spawn(async move { 655 if let Err(err) = handle.command_tx.send(SessionCommand::Shutdown).await { 656 tracing::warn!("Failed to send shutdown command: {}", err); 657 } 658 }); 659 } 660 } 661 662 fn interrupt_session(&self, session_id: String, ctx: egui::Context) { 663 if let Some(handle) = self.sessions.get(&session_id) { 664 let command_tx = handle.command_tx.clone(); 665 tokio::spawn(async move { 666 if let Err(err) = command_tx.send(SessionCommand::Interrupt { ctx }).await { 667 tracing::warn!("Failed to send interrupt command: {}", err); 668 } 669 }); 670 } 671 } 672 673 fn set_permission_mode(&self, session_id: String, mode: PermissionMode, ctx: egui::Context) { 674 if let Some(handle) = self.sessions.get(&session_id) { 675 let command_tx = handle.command_tx.clone(); 676 tokio::spawn(async move { 677 if let Err(err) = command_tx 678 .send(SessionCommand::SetPermissionMode { mode, ctx }) 679 .await 680 { 681 tracing::warn!("Failed to send set_permission_mode command: {}", err); 682 } 683 }); 684 } else { 685 tracing::debug!( 686 "Session {} not active, permission mode will apply on next query", 687 session_id 688 ); 689 } 690 } 691 }