session_loader.rs (15282B)
1 //! Load a previous session's conversation from nostr events in ndb. 2 //! 3 //! Queries for kind-1988 events with a matching `d` tag (session ID), 4 //! orders them by created_at, and converts them into `Message` variants 5 //! for populating the chat UI. 6 7 use crate::messages::{AssistantMessage, ExecutedTool, PermissionRequest, PermissionResponseType}; 8 use crate::session::PermissionTracker; 9 use crate::session_events::{get_tag_value, is_conversation_role, AI_CONVERSATION_KIND}; 10 use crate::tools::ToolResponse; 11 use crate::Message; 12 use nostrdb::{Filter, Ndb, NoteKey, Transaction}; 13 use std::collections::{HashMap, HashSet}; 14 15 /// Query replaceable events via `ndb.fold`, deduplicating by `d` tag. 16 /// 17 /// nostrdb doesn't deduplicate replaceable events internally, so multiple 18 /// revisions of the same (kind, pubkey, d-tag) tuple may exist. This 19 /// folds over all matching notes and keeps only the one with the highest 20 /// `created_at` for each unique `d` tag value. 21 /// 22 /// Returns a Vec of `NoteKey`s for the winning notes (one per unique d-tag). 23 pub fn query_replaceable(ndb: &Ndb, txn: &Transaction, filters: &[Filter]) -> Vec<NoteKey> { 24 query_replaceable_filtered(ndb, txn, filters, |_| true) 25 } 26 27 /// Like `query_replaceable`, but with a predicate to filter notes. 28 /// 29 /// The predicate is called on the latest revision of each d-tag group. 30 /// If it returns false, that d-tag is removed from results (even if an 31 /// older revision would have passed). 32 pub fn query_replaceable_filtered( 33 ndb: &Ndb, 34 txn: &Transaction, 35 filters: &[Filter], 36 predicate: impl Fn(&nostrdb::Note) -> bool, 37 ) -> Vec<NoteKey> { 38 // Fold: for each d-tag value, track the latest created_at and optionally 39 // a NoteKey (only if the latest revision passes the predicate). 40 // Notes may arrive in any order from ndb.fold, so we always track the 41 // highest timestamp and only keep a key when that revision is valid. 42 let best = ndb.fold( 43 txn, 44 filters, 45 std::collections::HashMap::<String, (u64, Option<NoteKey>)>::new(), 46 |mut acc, note| { 47 let Some(d_tag) = get_tag_value(¬e, "d") else { 48 return acc; 49 }; 50 51 let created_at = note.created_at(); 52 53 if let Some((existing_ts, _)) = acc.get(d_tag) { 54 if created_at <= *existing_ts { 55 return acc; 56 } 57 } 58 59 let key = if predicate(¬e) { 60 Some(note.key().expect("note key")) 61 } else { 62 None 63 }; 64 65 acc.insert(d_tag.to_string(), (created_at, key)); 66 acc 67 }, 68 ); 69 70 match best { 71 Ok(map) => map.into_values().filter_map(|(_, key)| key).collect(), 72 Err(_) => vec![], 73 } 74 } 75 76 /// Result of loading session messages, including threading info for live events. 77 pub struct LoadedSession { 78 pub messages: Vec<Message>, 79 pub root_note_id: Option<[u8; 32]>, 80 pub last_note_id: Option<[u8; 32]>, 81 pub event_count: u32, 82 /// Permission state loaded from events (responded set + request note IDs). 83 pub permissions: PermissionTracker, 84 /// All note IDs found, for seeding dedup in live polling. 85 pub note_ids: HashSet<[u8; 32]>, 86 } 87 88 /// Load conversation messages from ndb for a given session ID. 89 /// 90 /// This queries for kind-1988 events with a `d` tag matching the session ID, 91 /// sorts them chronologically, and converts relevant roles into Messages. 92 pub fn load_session_messages(ndb: &Ndb, txn: &Transaction, session_id: &str) -> LoadedSession { 93 let filter = Filter::new() 94 .kinds([AI_CONVERSATION_KIND as u64]) 95 .tags([session_id], 'd') 96 .build(); 97 98 let results = match ndb.query(txn, &[filter], 10000) { 99 Ok(r) => r, 100 Err(_) => { 101 return LoadedSession { 102 messages: vec![], 103 root_note_id: None, 104 last_note_id: None, 105 event_count: 0, 106 permissions: PermissionTracker::new(), 107 note_ids: HashSet::new(), 108 }; 109 } 110 }; 111 112 // Collect notes with their created_at for sorting 113 let mut notes: Vec<_> = results 114 .iter() 115 .filter_map(|qr| ndb.get_note_by_key(txn, qr.note_key).ok()) 116 .collect(); 117 118 // Sort by created_at first, then by seq tag as tiebreaker for events 119 // within the same second (seq is per-session, not globally ordered) 120 notes.sort_by_key(|note| { 121 let seq = get_tag_value(note, "seq") 122 .and_then(|s| s.parse::<u32>().ok()) 123 .unwrap_or(0); 124 (note.created_at(), seq) 125 }); 126 127 let event_count = notes.len() as u32; 128 let note_ids: HashSet<[u8; 32]> = notes.iter().map(|n| *n.id()).collect(); 129 130 // Find the first conversation note (skip metadata like queue-operation) 131 // so the threading root is a real message. 132 let root_note_id = notes 133 .iter() 134 .find(|n| { 135 get_tag_value(n, "role") 136 .map(is_conversation_role) 137 .unwrap_or(false) 138 }) 139 .map(|n| *n.id()); 140 let last_note_id = notes.last().map(|n| *n.id()); 141 142 // First pass: collect responded permission IDs and perm request note IDs 143 let mut permissions = PermissionTracker::new(); 144 for note in ¬es { 145 let role = get_tag_value(note, "role"); 146 if role == Some("permission_response") { 147 if let Some(perm_id_str) = get_tag_value(note, "perm-id") { 148 if let Ok(perm_id) = uuid::Uuid::parse_str(perm_id_str) { 149 permissions.responded.insert(perm_id); 150 } 151 } 152 } else if role == Some("permission_request") { 153 if let Some(perm_id_str) = get_tag_value(note, "perm-id") { 154 if let Ok(perm_id) = uuid::Uuid::parse_str(perm_id_str) { 155 permissions.request_note_ids.insert(perm_id, *note.id()); 156 } 157 } 158 } 159 } 160 161 // Second pass: convert to messages 162 let mut messages = Vec::new(); 163 for note in ¬es { 164 let content = note.content(); 165 let role = get_tag_value(note, "role"); 166 167 let msg = match role { 168 Some("user") => Some(Message::User(content.to_string())), 169 Some("assistant") | Some("tool_call") => Some(Message::Assistant( 170 AssistantMessage::from_text(content.to_string()), 171 )), 172 Some("tool_result") => { 173 let summary = truncate(content, 200); 174 Some(Message::ToolResponse(ToolResponse::executed_tool( 175 ExecutedTool { 176 tool_name: get_tag_value(note, "tool-name") 177 .unwrap_or("tool") 178 .to_string(), 179 summary, 180 parent_task_id: None, 181 file_update: None, 182 }, 183 ))) 184 } 185 Some("permission_request") => { 186 if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) { 187 let tool_name = content_json["tool_name"] 188 .as_str() 189 .unwrap_or("unknown") 190 .to_string(); 191 let tool_input = content_json 192 .get("tool_input") 193 .cloned() 194 .unwrap_or(serde_json::Value::Null); 195 let perm_id = get_tag_value(note, "perm-id") 196 .and_then(|s| uuid::Uuid::parse_str(s).ok()) 197 .unwrap_or_else(uuid::Uuid::new_v4); 198 199 let response = if permissions.responded.contains(&perm_id) { 200 Some(PermissionResponseType::Allowed) 201 } else { 202 None 203 }; 204 205 // Parse plan markdown for ExitPlanMode requests 206 let cached_plan = if tool_name == "ExitPlanMode" { 207 tool_input 208 .get("plan") 209 .and_then(|v| v.as_str()) 210 .map(crate::messages::ParsedMarkdown::parse) 211 } else { 212 None 213 }; 214 215 Some(Message::PermissionRequest(PermissionRequest { 216 id: perm_id, 217 tool_name, 218 tool_input, 219 response, 220 answer_summary: None, 221 cached_plan, 222 })) 223 } else { 224 None 225 } 226 } 227 // Skip permission_response, progress, queue-operation, etc. 228 _ => None, 229 }; 230 231 if let Some(msg) = msg { 232 messages.push(msg); 233 } 234 } 235 236 LoadedSession { 237 messages, 238 root_note_id, 239 last_note_id, 240 event_count, 241 permissions, 242 note_ids, 243 } 244 } 245 246 /// A persisted session state from a kind-31988 event. 247 pub struct SessionState { 248 pub claude_session_id: String, 249 pub title: String, 250 pub custom_title: Option<String>, 251 pub cwd: String, 252 pub status: String, 253 pub indicator: Option<String>, 254 pub hostname: String, 255 pub home_dir: String, 256 pub backend: Option<String>, 257 pub permission_mode: Option<String>, 258 pub created_at: u64, 259 /// Real CLI session ID when the d-tag is a provisional UUID. 260 /// Present only for sessions created via spawn commands. 261 /// Empty string means the backend hasn't started yet. 262 pub cli_session_id: Option<String>, 263 } 264 265 impl SessionState { 266 /// Build a SessionState from a kind-31988 note's tags. 267 /// 268 /// Returns None if the note has no d-tag (session ID). 269 pub fn from_note(note: &nostrdb::Note, session_id: Option<&str>) -> Option<Self> { 270 let claude_session_id = session_id 271 .map(|s| s.to_string()) 272 .or_else(|| get_tag_value(note, "d").map(|s| s.to_string()))?; 273 274 Some(SessionState { 275 claude_session_id, 276 title: get_tag_value(note, "title") 277 .unwrap_or("Untitled") 278 .to_string(), 279 custom_title: get_tag_value(note, "custom_title").map(|s| s.to_string()), 280 cwd: get_tag_value(note, "cwd").unwrap_or("").to_string(), 281 status: get_tag_value(note, "status").unwrap_or("idle").to_string(), 282 indicator: get_tag_value(note, "indicator").map(|s| s.to_string()), 283 hostname: get_tag_value(note, "hostname").unwrap_or("").to_string(), 284 home_dir: get_tag_value(note, "home_dir").unwrap_or("").to_string(), 285 backend: get_tag_value(note, "backend").map(|s| s.to_string()), 286 permission_mode: get_tag_value(note, "permission-mode").map(|s| s.to_string()), 287 created_at: note.created_at(), 288 cli_session_id: get_tag_value(note, "cli_session").map(|s| s.to_string()), 289 }) 290 } 291 } 292 293 /// Load all session states from kind-31988 events in ndb. 294 /// 295 /// Uses `query_replaceable_filtered` to deduplicate by d-tag, keeping 296 /// only the most recent non-deleted revision of each session state. 297 pub fn load_session_states(ndb: &Ndb, txn: &Transaction) -> Vec<SessionState> { 298 use crate::session_events::AI_SESSION_STATE_KIND; 299 300 let filter = Filter::new().kinds([AI_SESSION_STATE_KIND as u64]).build(); 301 302 let is_valid = |note: &nostrdb::Note| { 303 // Skip deleted sessions 304 if get_tag_value(note, "status") == Some("deleted") { 305 return false; 306 } 307 // Skip old JSON-content format events 308 if note.content().starts_with('{') { 309 return false; 310 } 311 true 312 }; 313 314 let note_keys = query_replaceable_filtered(ndb, txn, &[filter], is_valid); 315 316 let mut states = Vec::new(); 317 for key in note_keys { 318 let Ok(note) = ndb.get_note_by_key(txn, key) else { 319 continue; 320 }; 321 322 let Some(state) = SessionState::from_note(¬e, None) else { 323 continue; 324 }; 325 states.push(state); 326 } 327 328 states 329 } 330 331 /// Look up the latest valid revision of a single session by d-tag. 332 /// 333 /// PNS wrapping causes relays to store all revisions of replaceable 334 /// events. This queries for the latest revision and returns it only 335 /// if it's non-deleted and in the current format. 336 pub fn latest_valid_session( 337 ndb: &Ndb, 338 txn: &Transaction, 339 session_id: &str, 340 ) -> Option<SessionState> { 341 use crate::session_events::AI_SESSION_STATE_KIND; 342 343 let filter = Filter::new() 344 .kinds([AI_SESSION_STATE_KIND as u64]) 345 .tags([session_id], 'd') 346 .limit(1) 347 .build(); 348 349 let results = ndb.query(txn, &[filter], 1).ok()?; 350 let note = &results.first()?.note; 351 352 if get_tag_value(note, "status") == Some("deleted") { 353 return None; 354 } 355 if note.content().starts_with('{') { 356 return None; 357 } 358 359 SessionState::from_note(note, Some(session_id)) 360 } 361 362 /// Extract recent working directories grouped by hostname from kind-31988 363 /// session state events. 364 /// 365 /// Returns up to `MAX_RECENT_PER_HOST` unique paths per hostname, ordered 366 /// by most recently seen first. Useful for populating the directory picker 367 /// with previously used paths (both local and remote hosts). 368 pub fn load_recent_paths_by_host( 369 ndb: &Ndb, 370 txn: &Transaction, 371 ) -> HashMap<String, Vec<std::path::PathBuf>> { 372 use crate::session_events::AI_SESSION_STATE_KIND; 373 374 const MAX_RECENT_PER_HOST: usize = 10; 375 376 let filter = Filter::new().kinds([AI_SESSION_STATE_KIND as u64]).build(); 377 378 let is_valid = |note: &nostrdb::Note| { 379 if get_tag_value(note, "status") == Some("deleted") { 380 return false; 381 } 382 if note.content().starts_with('{') { 383 return false; 384 } 385 true 386 }; 387 388 let note_keys = query_replaceable_filtered(ndb, txn, &[filter], is_valid); 389 390 // Collect (hostname, cwd, created_at) triples 391 let mut entries: Vec<(String, String, u64)> = Vec::new(); 392 for key in note_keys { 393 let Ok(note) = ndb.get_note_by_key(txn, key) else { 394 continue; 395 }; 396 let hostname = get_tag_value(¬e, "hostname").unwrap_or("").to_string(); 397 let cwd = get_tag_value(¬e, "cwd").unwrap_or("").to_string(); 398 if cwd.is_empty() { 399 continue; 400 } 401 entries.push((hostname, cwd, note.created_at())); 402 } 403 404 // Sort by created_at descending (most recent first) 405 entries.sort_by(|a, b| b.2.cmp(&a.2)); 406 407 // Group by hostname, dedup cwds, cap per host 408 let mut result: HashMap<String, Vec<std::path::PathBuf>> = HashMap::new(); 409 for (hostname, cwd, _) in entries { 410 let paths = result.entry(hostname).or_default(); 411 let path = std::path::PathBuf::from(&cwd); 412 if !paths.contains(&path) && paths.len() < MAX_RECENT_PER_HOST { 413 paths.push(path); 414 } 415 } 416 417 result 418 } 419 420 pub(crate) fn truncate(s: &str, max_chars: usize) -> String { 421 if s.chars().count() <= max_chars { 422 s.to_string() 423 } else { 424 let truncated: String = s.chars().take(max_chars).collect(); 425 format!("{}...", truncated) 426 } 427 }