notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

session_loader.rs (15282B)


      1 //! Load a previous session's conversation from nostr events in ndb.
      2 //!
      3 //! Queries for kind-1988 events with a matching `d` tag (session ID),
      4 //! orders them by created_at, and converts them into `Message` variants
      5 //! for populating the chat UI.
      6 
      7 use crate::messages::{AssistantMessage, ExecutedTool, PermissionRequest, PermissionResponseType};
      8 use crate::session::PermissionTracker;
      9 use crate::session_events::{get_tag_value, is_conversation_role, AI_CONVERSATION_KIND};
     10 use crate::tools::ToolResponse;
     11 use crate::Message;
     12 use nostrdb::{Filter, Ndb, NoteKey, Transaction};
     13 use std::collections::{HashMap, HashSet};
     14 
     15 /// Query replaceable events via `ndb.fold`, deduplicating by `d` tag.
     16 ///
     17 /// nostrdb doesn't deduplicate replaceable events internally, so multiple
     18 /// revisions of the same (kind, pubkey, d-tag) tuple may exist. This
     19 /// folds over all matching notes and keeps only the one with the highest
     20 /// `created_at` for each unique `d` tag value.
     21 ///
     22 /// Returns a Vec of `NoteKey`s for the winning notes (one per unique d-tag).
     23 pub fn query_replaceable(ndb: &Ndb, txn: &Transaction, filters: &[Filter]) -> Vec<NoteKey> {
     24     query_replaceable_filtered(ndb, txn, filters, |_| true)
     25 }
     26 
     27 /// Like `query_replaceable`, but with a predicate to filter notes.
     28 ///
     29 /// The predicate is called on the latest revision of each d-tag group.
     30 /// If it returns false, that d-tag is removed from results (even if an
     31 /// older revision would have passed).
     32 pub fn query_replaceable_filtered(
     33     ndb: &Ndb,
     34     txn: &Transaction,
     35     filters: &[Filter],
     36     predicate: impl Fn(&nostrdb::Note) -> bool,
     37 ) -> Vec<NoteKey> {
     38     // Fold: for each d-tag value, track the latest created_at and optionally
     39     // a NoteKey (only if the latest revision passes the predicate).
     40     // Notes may arrive in any order from ndb.fold, so we always track the
     41     // highest timestamp and only keep a key when that revision is valid.
     42     let best = ndb.fold(
     43         txn,
     44         filters,
     45         std::collections::HashMap::<String, (u64, Option<NoteKey>)>::new(),
     46         |mut acc, note| {
     47             let Some(d_tag) = get_tag_value(&note, "d") else {
     48                 return acc;
     49             };
     50 
     51             let created_at = note.created_at();
     52 
     53             if let Some((existing_ts, _)) = acc.get(d_tag) {
     54                 if created_at <= *existing_ts {
     55                     return acc;
     56                 }
     57             }
     58 
     59             let key = if predicate(&note) {
     60                 Some(note.key().expect("note key"))
     61             } else {
     62                 None
     63             };
     64 
     65             acc.insert(d_tag.to_string(), (created_at, key));
     66             acc
     67         },
     68     );
     69 
     70     match best {
     71         Ok(map) => map.into_values().filter_map(|(_, key)| key).collect(),
     72         Err(_) => vec![],
     73     }
     74 }
     75 
     76 /// Result of loading session messages, including threading info for live events.
     77 pub struct LoadedSession {
     78     pub messages: Vec<Message>,
     79     pub root_note_id: Option<[u8; 32]>,
     80     pub last_note_id: Option<[u8; 32]>,
     81     pub event_count: u32,
     82     /// Permission state loaded from events (responded set + request note IDs).
     83     pub permissions: PermissionTracker,
     84     /// All note IDs found, for seeding dedup in live polling.
     85     pub note_ids: HashSet<[u8; 32]>,
     86 }
     87 
     88 /// Load conversation messages from ndb for a given session ID.
     89 ///
     90 /// This queries for kind-1988 events with a `d` tag matching the session ID,
     91 /// sorts them chronologically, and converts relevant roles into Messages.
     92 pub fn load_session_messages(ndb: &Ndb, txn: &Transaction, session_id: &str) -> LoadedSession {
     93     let filter = Filter::new()
     94         .kinds([AI_CONVERSATION_KIND as u64])
     95         .tags([session_id], 'd')
     96         .build();
     97 
     98     let results = match ndb.query(txn, &[filter], 10000) {
     99         Ok(r) => r,
    100         Err(_) => {
    101             return LoadedSession {
    102                 messages: vec![],
    103                 root_note_id: None,
    104                 last_note_id: None,
    105                 event_count: 0,
    106                 permissions: PermissionTracker::new(),
    107                 note_ids: HashSet::new(),
    108             };
    109         }
    110     };
    111 
    112     // Collect notes with their created_at for sorting
    113     let mut notes: Vec<_> = results
    114         .iter()
    115         .filter_map(|qr| ndb.get_note_by_key(txn, qr.note_key).ok())
    116         .collect();
    117 
    118     // Sort by created_at first, then by seq tag as tiebreaker for events
    119     // within the same second (seq is per-session, not globally ordered)
    120     notes.sort_by_key(|note| {
    121         let seq = get_tag_value(note, "seq")
    122             .and_then(|s| s.parse::<u32>().ok())
    123             .unwrap_or(0);
    124         (note.created_at(), seq)
    125     });
    126 
    127     let event_count = notes.len() as u32;
    128     let note_ids: HashSet<[u8; 32]> = notes.iter().map(|n| *n.id()).collect();
    129 
    130     // Find the first conversation note (skip metadata like queue-operation)
    131     // so the threading root is a real message.
    132     let root_note_id = notes
    133         .iter()
    134         .find(|n| {
    135             get_tag_value(n, "role")
    136                 .map(is_conversation_role)
    137                 .unwrap_or(false)
    138         })
    139         .map(|n| *n.id());
    140     let last_note_id = notes.last().map(|n| *n.id());
    141 
    142     // First pass: collect responded permission IDs and perm request note IDs
    143     let mut permissions = PermissionTracker::new();
    144     for note in &notes {
    145         let role = get_tag_value(note, "role");
    146         if role == Some("permission_response") {
    147             if let Some(perm_id_str) = get_tag_value(note, "perm-id") {
    148                 if let Ok(perm_id) = uuid::Uuid::parse_str(perm_id_str) {
    149                     permissions.responded.insert(perm_id);
    150                 }
    151             }
    152         } else if role == Some("permission_request") {
    153             if let Some(perm_id_str) = get_tag_value(note, "perm-id") {
    154                 if let Ok(perm_id) = uuid::Uuid::parse_str(perm_id_str) {
    155                     permissions.request_note_ids.insert(perm_id, *note.id());
    156                 }
    157             }
    158         }
    159     }
    160 
    161     // Second pass: convert to messages
    162     let mut messages = Vec::new();
    163     for note in &notes {
    164         let content = note.content();
    165         let role = get_tag_value(note, "role");
    166 
    167         let msg = match role {
    168             Some("user") => Some(Message::User(content.to_string())),
    169             Some("assistant") | Some("tool_call") => Some(Message::Assistant(
    170                 AssistantMessage::from_text(content.to_string()),
    171             )),
    172             Some("tool_result") => {
    173                 let summary = truncate(content, 200);
    174                 Some(Message::ToolResponse(ToolResponse::executed_tool(
    175                     ExecutedTool {
    176                         tool_name: get_tag_value(note, "tool-name")
    177                             .unwrap_or("tool")
    178                             .to_string(),
    179                         summary,
    180                         parent_task_id: None,
    181                         file_update: None,
    182                     },
    183                 )))
    184             }
    185             Some("permission_request") => {
    186                 if let Ok(content_json) = serde_json::from_str::<serde_json::Value>(content) {
    187                     let tool_name = content_json["tool_name"]
    188                         .as_str()
    189                         .unwrap_or("unknown")
    190                         .to_string();
    191                     let tool_input = content_json
    192                         .get("tool_input")
    193                         .cloned()
    194                         .unwrap_or(serde_json::Value::Null);
    195                     let perm_id = get_tag_value(note, "perm-id")
    196                         .and_then(|s| uuid::Uuid::parse_str(s).ok())
    197                         .unwrap_or_else(uuid::Uuid::new_v4);
    198 
    199                     let response = if permissions.responded.contains(&perm_id) {
    200                         Some(PermissionResponseType::Allowed)
    201                     } else {
    202                         None
    203                     };
    204 
    205                     // Parse plan markdown for ExitPlanMode requests
    206                     let cached_plan = if tool_name == "ExitPlanMode" {
    207                         tool_input
    208                             .get("plan")
    209                             .and_then(|v| v.as_str())
    210                             .map(crate::messages::ParsedMarkdown::parse)
    211                     } else {
    212                         None
    213                     };
    214 
    215                     Some(Message::PermissionRequest(PermissionRequest {
    216                         id: perm_id,
    217                         tool_name,
    218                         tool_input,
    219                         response,
    220                         answer_summary: None,
    221                         cached_plan,
    222                     }))
    223                 } else {
    224                     None
    225                 }
    226             }
    227             // Skip permission_response, progress, queue-operation, etc.
    228             _ => None,
    229         };
    230 
    231         if let Some(msg) = msg {
    232             messages.push(msg);
    233         }
    234     }
    235 
    236     LoadedSession {
    237         messages,
    238         root_note_id,
    239         last_note_id,
    240         event_count,
    241         permissions,
    242         note_ids,
    243     }
    244 }
    245 
    246 /// A persisted session state from a kind-31988 event.
    247 pub struct SessionState {
    248     pub claude_session_id: String,
    249     pub title: String,
    250     pub custom_title: Option<String>,
    251     pub cwd: String,
    252     pub status: String,
    253     pub indicator: Option<String>,
    254     pub hostname: String,
    255     pub home_dir: String,
    256     pub backend: Option<String>,
    257     pub permission_mode: Option<String>,
    258     pub created_at: u64,
    259     /// Real CLI session ID when the d-tag is a provisional UUID.
    260     /// Present only for sessions created via spawn commands.
    261     /// Empty string means the backend hasn't started yet.
    262     pub cli_session_id: Option<String>,
    263 }
    264 
    265 impl SessionState {
    266     /// Build a SessionState from a kind-31988 note's tags.
    267     ///
    268     /// Returns None if the note has no d-tag (session ID).
    269     pub fn from_note(note: &nostrdb::Note, session_id: Option<&str>) -> Option<Self> {
    270         let claude_session_id = session_id
    271             .map(|s| s.to_string())
    272             .or_else(|| get_tag_value(note, "d").map(|s| s.to_string()))?;
    273 
    274         Some(SessionState {
    275             claude_session_id,
    276             title: get_tag_value(note, "title")
    277                 .unwrap_or("Untitled")
    278                 .to_string(),
    279             custom_title: get_tag_value(note, "custom_title").map(|s| s.to_string()),
    280             cwd: get_tag_value(note, "cwd").unwrap_or("").to_string(),
    281             status: get_tag_value(note, "status").unwrap_or("idle").to_string(),
    282             indicator: get_tag_value(note, "indicator").map(|s| s.to_string()),
    283             hostname: get_tag_value(note, "hostname").unwrap_or("").to_string(),
    284             home_dir: get_tag_value(note, "home_dir").unwrap_or("").to_string(),
    285             backend: get_tag_value(note, "backend").map(|s| s.to_string()),
    286             permission_mode: get_tag_value(note, "permission-mode").map(|s| s.to_string()),
    287             created_at: note.created_at(),
    288             cli_session_id: get_tag_value(note, "cli_session").map(|s| s.to_string()),
    289         })
    290     }
    291 }
    292 
    293 /// Load all session states from kind-31988 events in ndb.
    294 ///
    295 /// Uses `query_replaceable_filtered` to deduplicate by d-tag, keeping
    296 /// only the most recent non-deleted revision of each session state.
    297 pub fn load_session_states(ndb: &Ndb, txn: &Transaction) -> Vec<SessionState> {
    298     use crate::session_events::AI_SESSION_STATE_KIND;
    299 
    300     let filter = Filter::new().kinds([AI_SESSION_STATE_KIND as u64]).build();
    301 
    302     let is_valid = |note: &nostrdb::Note| {
    303         // Skip deleted sessions
    304         if get_tag_value(note, "status") == Some("deleted") {
    305             return false;
    306         }
    307         // Skip old JSON-content format events
    308         if note.content().starts_with('{') {
    309             return false;
    310         }
    311         true
    312     };
    313 
    314     let note_keys = query_replaceable_filtered(ndb, txn, &[filter], is_valid);
    315 
    316     let mut states = Vec::new();
    317     for key in note_keys {
    318         let Ok(note) = ndb.get_note_by_key(txn, key) else {
    319             continue;
    320         };
    321 
    322         let Some(state) = SessionState::from_note(&note, None) else {
    323             continue;
    324         };
    325         states.push(state);
    326     }
    327 
    328     states
    329 }
    330 
    331 /// Look up the latest valid revision of a single session by d-tag.
    332 ///
    333 /// PNS wrapping causes relays to store all revisions of replaceable
    334 /// events. This queries for the latest revision and returns it only
    335 /// if it's non-deleted and in the current format.
    336 pub fn latest_valid_session(
    337     ndb: &Ndb,
    338     txn: &Transaction,
    339     session_id: &str,
    340 ) -> Option<SessionState> {
    341     use crate::session_events::AI_SESSION_STATE_KIND;
    342 
    343     let filter = Filter::new()
    344         .kinds([AI_SESSION_STATE_KIND as u64])
    345         .tags([session_id], 'd')
    346         .limit(1)
    347         .build();
    348 
    349     let results = ndb.query(txn, &[filter], 1).ok()?;
    350     let note = &results.first()?.note;
    351 
    352     if get_tag_value(note, "status") == Some("deleted") {
    353         return None;
    354     }
    355     if note.content().starts_with('{') {
    356         return None;
    357     }
    358 
    359     SessionState::from_note(note, Some(session_id))
    360 }
    361 
    362 /// Extract recent working directories grouped by hostname from kind-31988
    363 /// session state events.
    364 ///
    365 /// Returns up to `MAX_RECENT_PER_HOST` unique paths per hostname, ordered
    366 /// by most recently seen first. Useful for populating the directory picker
    367 /// with previously used paths (both local and remote hosts).
    368 pub fn load_recent_paths_by_host(
    369     ndb: &Ndb,
    370     txn: &Transaction,
    371 ) -> HashMap<String, Vec<std::path::PathBuf>> {
    372     use crate::session_events::AI_SESSION_STATE_KIND;
    373 
    374     const MAX_RECENT_PER_HOST: usize = 10;
    375 
    376     let filter = Filter::new().kinds([AI_SESSION_STATE_KIND as u64]).build();
    377 
    378     let is_valid = |note: &nostrdb::Note| {
    379         if get_tag_value(note, "status") == Some("deleted") {
    380             return false;
    381         }
    382         if note.content().starts_with('{') {
    383             return false;
    384         }
    385         true
    386     };
    387 
    388     let note_keys = query_replaceable_filtered(ndb, txn, &[filter], is_valid);
    389 
    390     // Collect (hostname, cwd, created_at) triples
    391     let mut entries: Vec<(String, String, u64)> = Vec::new();
    392     for key in note_keys {
    393         let Ok(note) = ndb.get_note_by_key(txn, key) else {
    394             continue;
    395         };
    396         let hostname = get_tag_value(&note, "hostname").unwrap_or("").to_string();
    397         let cwd = get_tag_value(&note, "cwd").unwrap_or("").to_string();
    398         if cwd.is_empty() {
    399             continue;
    400         }
    401         entries.push((hostname, cwd, note.created_at()));
    402     }
    403 
    404     // Sort by created_at descending (most recent first)
    405     entries.sort_by(|a, b| b.2.cmp(&a.2));
    406 
    407     // Group by hostname, dedup cwds, cap per host
    408     let mut result: HashMap<String, Vec<std::path::PathBuf>> = HashMap::new();
    409     for (hostname, cwd, _) in entries {
    410         let paths = result.entry(hostname).or_default();
    411         let path = std::path::PathBuf::from(&cwd);
    412         if !paths.contains(&path) && paths.len() < MAX_RECENT_PER_HOST {
    413             paths.push(path);
    414         }
    415     }
    416 
    417     result
    418 }
    419 
    420 pub(crate) fn truncate(s: &str, max_chars: usize) -> String {
    421     if s.chars().count() <= max_chars {
    422         s.to_string()
    423     } else {
    424         let truncated: String = s.chars().take(max_chars).collect();
    425         format!("{}...", truncated)
    426     }
    427 }