notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

model_cache.rs (5806B)


      1 use std::collections::HashMap;
      2 use std::path::PathBuf;
      3 
      4 use poll_promise::Promise;
      5 use sha2::{Digest, Sha256};
      6 
      7 /// Status of a model fetch operation.
      8 enum ModelFetchStatus {
      9     /// HTTP download in progress.
     10     Downloading(Promise<Result<PathBuf, String>>),
     11     /// Downloaded to disk, ready for GPU load on next poll.
     12     ReadyToLoad(PathBuf),
     13     /// Model handle assigned; terminal state.
     14     Loaded,
     15     /// Download or load failed; terminal state.
     16     Failed,
     17 }
     18 
     19 /// Manages async downloading and disk caching of remote 3D models.
     20 ///
     21 /// Local file paths are passed through unchanged.
     22 /// HTTP/HTTPS URLs are downloaded via `ehttp`, cached to disk under
     23 /// a sha256-hashed filename, and then loaded from the cache path.
     24 pub struct ModelCache {
     25     cache_dir: PathBuf,
     26     fetches: HashMap<String, ModelFetchStatus>,
     27 }
     28 
     29 impl ModelCache {
     30     pub fn new(cache_dir: PathBuf) -> Self {
     31         let _ = std::fs::create_dir_all(&cache_dir);
     32         Self {
     33             cache_dir,
     34             fetches: HashMap::new(),
     35         }
     36     }
     37 
     38     /// Returns true if `url` is an HTTP or HTTPS URL.
     39     fn is_remote(url: &str) -> bool {
     40         url.starts_with("http://") || url.starts_with("https://")
     41     }
     42 
     43     /// Compute on-disk cache path: `<cache_dir>/<sha256(url)>.<ext>`.
     44     fn cache_path(&self, url: &str) -> PathBuf {
     45         let mut hasher = Sha256::new();
     46         hasher.update(url.as_bytes());
     47         let hash = format!("{:x}", hasher.finalize());
     48 
     49         let ext = std::path::Path::new(url)
     50             .extension()
     51             .and_then(|e| e.to_str())
     52             .unwrap_or("glb");
     53 
     54         self.cache_dir.join(format!("{hash}.{ext}"))
     55     }
     56 
     57     /// Request a model by URL.
     58     ///
     59     /// - Local paths: returns `Some(PathBuf)` immediately.
     60     /// - Cached remote URLs: returns `Some(PathBuf)` from disk cache.
     61     /// - Uncached remote URLs: initiates async download, returns `None`.
     62     ///   The download result will be available via [`poll`] on a later frame.
     63     pub fn request(&mut self, url: &str) -> Option<PathBuf> {
     64         if !Self::is_remote(url) {
     65             return Some(PathBuf::from(url));
     66         }
     67 
     68         if let Some(status) = self.fetches.get(url) {
     69             return match status {
     70                 ModelFetchStatus::ReadyToLoad(path) => Some(path.clone()),
     71                 ModelFetchStatus::Loaded
     72                 | ModelFetchStatus::Failed
     73                 | ModelFetchStatus::Downloading(_) => None,
     74             };
     75         }
     76 
     77         // Check disk cache
     78         let cached = self.cache_path(url);
     79         if cached.exists() {
     80             tracing::info!("Model cache hit: {}", url);
     81             self.fetches.insert(
     82                 url.to_owned(),
     83                 ModelFetchStatus::ReadyToLoad(cached.clone()),
     84             );
     85             return Some(cached);
     86         }
     87 
     88         // Start async download
     89         tracing::info!("Downloading model: {}", url);
     90         let (sender, promise) = Promise::new();
     91         let target_path = cached;
     92         let request = ehttp::Request::get(url);
     93 
     94         let url_owned = url.to_owned();
     95         ehttp::fetch(request, move |response: Result<ehttp::Response, String>| {
     96             let result = (|| -> Result<PathBuf, String> {
     97                 let resp = response.map_err(|e| format!("HTTP error: {e}"))?;
     98                 if !resp.ok {
     99                     return Err(format!("HTTP {}: {}", resp.status, resp.status_text));
    100                 }
    101                 if resp.bytes.is_empty() {
    102                     return Err("Empty response body".to_string());
    103                 }
    104 
    105                 if let Some(parent) = target_path.parent() {
    106                     std::fs::create_dir_all(parent).map_err(|e| format!("mkdir: {e}"))?;
    107                 }
    108 
    109                 // Atomic write: .tmp then rename
    110                 let tmp_path = target_path.with_extension("tmp");
    111                 std::fs::write(&tmp_path, &resp.bytes).map_err(|e| format!("write: {e}"))?;
    112                 std::fs::rename(&tmp_path, &target_path).map_err(|e| format!("rename: {e}"))?;
    113 
    114                 tracing::info!("Cached {} bytes for {}", resp.bytes.len(), url_owned);
    115                 Ok(target_path)
    116             })();
    117             sender.send(result);
    118         });
    119 
    120         self.fetches
    121             .insert(url.to_owned(), ModelFetchStatus::Downloading(promise));
    122         None
    123     }
    124 
    125     /// Poll in-flight downloads. Returns URLs whose files are now ready to load.
    126     pub fn poll(&mut self) -> Vec<(String, PathBuf)> {
    127         let mut ready = Vec::new();
    128         let keys: Vec<String> = self.fetches.keys().cloned().collect();
    129 
    130         for url in keys {
    131             let needs_transition = {
    132                 let status = self.fetches.get_mut(&url).unwrap();
    133                 if let ModelFetchStatus::Downloading(promise) = status {
    134                     promise.ready().is_some()
    135                 } else {
    136                     false
    137                 }
    138             };
    139 
    140             if needs_transition
    141                 && let Some(ModelFetchStatus::Downloading(promise)) = self.fetches.remove(&url)
    142             {
    143                 match promise.block_and_take() {
    144                     Ok(path) => {
    145                         ready.push((url.clone(), path.clone()));
    146                         self.fetches
    147                             .insert(url, ModelFetchStatus::ReadyToLoad(path));
    148                     }
    149                     Err(e) => {
    150                         tracing::warn!("Model download failed for {}: {}", url, e);
    151                         self.fetches.insert(url, ModelFetchStatus::Failed);
    152                     }
    153                 }
    154             }
    155         }
    156 
    157         ready
    158     }
    159 
    160     /// Mark a URL as fully loaded (model handle assigned).
    161     pub fn mark_loaded(&mut self, url: &str) {
    162         self.fetches
    163             .insert(url.to_owned(), ModelFetchStatus::Loaded);
    164     }
    165 }