jobs-vibe-coded

A vibe-coded work queue in Rust
git clone git://jb55.com/jobs-vibe-coded
Log | Files | Refs

jobs.rs (5919B)


      1 
      2 use std::sync::{mpsc, Arc, Mutex};
      3 use std::thread;
      4 
      5 // A worker that stores no T data directly.
      6 struct Worker {
      7     id: usize,
      8     handle: Option<thread::JoinHandle<()>>,
      9 }
     10 
     11 impl Worker {
     12     /// Generic function that spawns a thread capable of processing `Message<T>`.
     13     /// This does the same job as `Worker<T>::new` did, but doesn't make `Worker` itself generic.
     14     fn spawn_worker_for<T: Send + 'static>(
     15         id: usize,
     16         receiver: Arc<Mutex<mpsc::Receiver<Message<T>>>>,
     17     ) -> Self {
     18         let handle = thread::spawn(move || loop {
     19             match receiver.lock().expect("lock poisoned").recv() {
     20                 Ok(Message::NewJob(data, job)) => {
     21                     println!("Worker {id} got a job; executing on data...");
     22                     (job)(data);
     23                 }
     24                 Ok(Message::Terminate) => {
     25                     println!("Worker {id} told to terminate.");
     26                     break;
     27                 }
     28                 Err(_) => {
     29                     println!("Worker {id} failed to receive message; terminating.");
     30                     break;
     31                 }
     32             }
     33         });
     34 
     35         Worker {
     36             id,
     37             handle: Some(handle),
     38         }
     39     }
     40 }
     41 
     42 // Our job and message definitions
     43 type Job<T> = Box<dyn FnOnce(T) + Send + 'static>;
     44 
     45 pub enum Message<T> {
     46     NewJob(T, Job<T>),
     47     Terminate,
     48 }
     49 
     50 // The thread pool is still generic over T
     51 pub struct ThreadPool<T> {
     52     workers: Vec<Worker>,
     53     sender: mpsc::Sender<Message<T>>,
     54 }
     55 
     56 impl<T: Send + 'static> ThreadPool<T> {
     57     pub fn new(size: usize) -> Self {
     58         assert!(size > 0);
     59         let (sender, receiver) = mpsc::channel::<Message<T>>();
     60         let receiver = Arc::new(Mutex::new(receiver));
     61 
     62         // Create plain `Worker` structs using a generic spawn function
     63         let mut workers = Vec::with_capacity(size);
     64         for id in 0..size {
     65             workers.push(Worker::spawn_worker_for(id, Arc::clone(&receiver)));
     66         }
     67 
     68         Self { workers, sender }
     69     }
     70 
     71     pub fn execute<F>(&self, data: T, f: F)
     72     where
     73         F: FnOnce(T) + Send + 'static,
     74     {
     75         let job = Box::new(f) as Job<T>;
     76         self.sender
     77             .send(Message::NewJob(data, job))
     78             .expect("Failed to send job to worker.");
     79     }
     80 }
     81 
     82 impl<T> Drop for ThreadPool<T> {
     83     fn drop(&mut self) {
     84         // Ask each worker to terminate
     85         for _ in &self.workers {
     86             let _ = self.sender.send(Message::Terminate);
     87         }
     88 
     89         // Join them
     90         for worker in &mut self.workers {
     91             if let Some(handle) = worker.handle.take() {
     92                 handle.join().expect("Worker thread panicked.");
     93             }
     94         }
     95     }
     96 }
     97 
     98 // Example usage
     99 fn main() {
    100     let pool = ThreadPool::new(4);
    101 
    102     for i in 0..8 {
    103         pool.execute(i, move |num| {
    104             println!("Data {num} processed in a worker thread.");
    105         });
    106     }
    107 }
    108 
    109 
    110 #[cfg(test)]
    111 mod tests {
    112     use super::*;
    113     use std::sync::{mpsc, Arc, Mutex};
    114     use std::thread;
    115     use std::time::Duration;
    116 
    117     /// Test basic task execution and verify that we can gather all results.
    118     #[test]
    119     fn test_execute_tasks() {
    120         let pool = ThreadPool::new(4);
    121         let (tx, rx) = mpsc::channel();
    122 
    123         // Send 8 jobs to the thread pool. Each job sends back its data via tx.
    124         for i in 0..8 {
    125             let tx_clone = tx.clone();
    126             pool.execute(i, move |num| {
    127                 // Simulate some work
    128                 thread::sleep(Duration::from_millis(10));
    129                 tx_clone.send(num).expect("Failed to send result");
    130             });
    131         }
    132 
    133         // Gather results
    134         let mut results = Vec::new();
    135         for _ in 0..8 {
    136             results.push(rx.recv().expect("Failed to receive result"));
    137         }
    138         results.sort();
    139         assert_eq!(results, (0..8).collect::<Vec<_>>());
    140     }
    141 
    142     /// Test that the thread pool does not allow a size of 0.
    143     /// This test should panic due to the assert!(size > 0) in `ThreadPool::new`.
    144     #[test]
    145     #[should_panic(expected = "size > 0")]
    146     fn test_zero_size_pool_should_panic() {
    147         let _ = ThreadPool::<i32>::new(0);
    148     }
    149 
    150     /// Test dropping the pool after submitting jobs.
    151     /// If everything is correct, the workers will gracefully terminate.
    152     #[test]
    153     fn test_dropping_pool() {
    154         let pool = ThreadPool::new(2);
    155         let (tx, rx) = mpsc::channel();
    156 
    157         for i in 0..4 {
    158             let tx_clone = tx.clone();
    159             pool.execute(i, move |num| {
    160                 // Simulate work
    161                 thread::sleep(Duration::from_millis(5));
    162                 tx_clone.send(num).unwrap();
    163             });
    164         }
    165 
    166         // Scope so pool goes out of scope and triggers drop after sending messages
    167         drop(pool);
    168 
    169         // We should still be able to receive all 4 messages
    170         let mut results = vec![];
    171         for _ in 0..4 {
    172             results.push(rx.recv().expect("Failed to receive result"));
    173         }
    174         results.sort();
    175         assert_eq!(results, vec![0, 1, 2, 3]);
    176     }
    177 
    178     /// A test that ensures multiple data types can be handled.
    179     /// In reality, we just need T: Send + 'static, but let's be explicit.
    180     #[test]
    181     fn test_execute_string_tasks() {
    182         let pool = ThreadPool::new(2);
    183         let (tx, rx) = mpsc::channel();
    184 
    185         let strings = vec!["alpha", "beta", "gamma", "delta"];
    186         for s in strings {
    187             let tx_clone = tx.clone();
    188             pool.execute(s.to_string(), move |st: String| {
    189                 // Just pass the string back
    190                 tx_clone.send(st).unwrap();
    191             });
    192         }
    193 
    194         // Collect results
    195         let mut results = Vec::new();
    196         for _ in 0..4 {
    197             let r = rx.recv().unwrap();
    198             results.push(r);
    199         }
    200         results.sort();
    201         assert_eq!(results, vec!["alpha", "beta", "delta", "gamma"]);
    202     }
    203 }