jobs.rs (5919B)
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

/// One pool thread plus the handle needed to join it on shutdown.
///
/// `Worker` is deliberately not generic: the payload type `T` only matters to
/// the closure running on the thread, so it is fixed by the generic
/// `Worker::spawn_worker_for` instead of by the struct.
struct Worker {
    /// Index of this worker within its pool; used in log messages.
    id: usize,
    /// Join handle of the worker thread; `None` once it has been joined.
    handle: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a thread that repeatedly takes one `Message<T>` from the shared
    /// `receiver` and handles it.
    ///
    /// `NewJob(data, job)` runs `job(data)`. `Terminate`, or a disconnected
    /// channel, makes the thread leave its loop and exit.
    ///
    /// # Panics
    ///
    /// The spawned thread (not the caller) panics if the receiver mutex is
    /// poisoned.
    fn spawn_worker_for<T: Send + 'static>(
        id: usize,
        receiver: Arc<Mutex<mpsc::Receiver<Message<T>>>>,
    ) -> Self {
        let handle = thread::spawn(move || loop {
            // Receive in its own `let` statement so the `MutexGuard` temporary
            // is dropped at the end of this statement. Writing
            // `match receiver.lock()...recv() { .. }` keeps the guard alive for
            // the whole `match`, i.e. while the job runs. Every other worker
            // would then block on the lock, and jobs would execute one at a
            // time no matter how many workers the pool has.
            let message = receiver.lock().expect("lock poisoned").recv();

            match message {
                Ok(Message::NewJob(data, job)) => {
                    println!("Worker {id} got a job; executing on data...");
                    job(data);
                }
                Ok(Message::Terminate) => {
                    println!("Worker {id} told to terminate.");
                    break;
                }
                Err(_) => {
                    // Every sender is gone; no more messages can arrive.
                    println!("Worker {id} failed to receive message; terminating.");
                    break;
                }
            }
        });

        Worker {
            id,
            handle: Some(handle),
        }
    }
}

/// A unit of work: a boxed closure that consumes one value of type `T`.
type Job<T> = Box<dyn FnOnce(T) + Send + 'static>;

/// Messages sent from a `ThreadPool` to its workers.
pub enum Message<T> {
    /// Run the job, passing it the accompanying data.
    NewJob(T, Job<T>),
    /// Ask the worker that receives this to exit its loop.
    Terminate,
}

/// A fixed-size pool of worker threads that run jobs taking a `T` value.
///
/// Jobs go through a single `mpsc` channel whose receiver is shared by all
/// workers behind a mutex, so each message is handled by exactly one worker.
pub struct ThreadPool<T> {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Message<T>>,
}

impl<T: Send + 'static> ThreadPool<T> {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel::<Message<T>>();
        // Shared by every worker; the mutex ensures a message goes to one worker only.
        let receiver = Arc::new(Mutex::new(receiver));

        let workers: Vec<Worker> = (0..size)
            .map(|id| Worker::spawn_worker_for(id, Arc::clone(&receiver)))
            .collect();

        Self { workers, sender }
    }

    /// Queues `f` to be called with `data` on one of the worker threads.
    ///
    /// Returns as soon as the job is queued; it does not wait for it to run.
    ///
    /// # Panics
    ///
    /// Panics if the job cannot be sent, which happens only when every worker
    /// thread has already exited (so the receiving end of the channel is gone).
    pub fn execute<F>(&self, data: T, f: F)
    where
        F: FnOnce(T) + Send + 'static,
    {
        let job = Box::new(f) as Job<T>;
        self.sender
            .send(Message::NewJob(data, job))
            .expect("Failed to send job to worker.");
    }
}

impl<T> Drop for ThreadPool<T> {
    /// Asks every worker to terminate and waits for each thread to exit.
    ///
    /// The `Terminate` messages are queued behind any pending jobs, so jobs
    /// already submitted are dequeued and run before the workers stop.
    ///
    /// A worker whose thread panicked (because one of its jobs panicked) is
    /// reported on stderr instead of panicking here. A panic inside `drop`
    /// aborts the process when it happens during unwinding. It would also
    /// leave the remaining workers unjoined.
    fn drop(&mut self) {
        // One Terminate per worker: each live worker exits after consuming one.
        // A send error only means no worker is left to receive, so it is ignored.
        for _ in &self.workers {
            let _ = self.sender.send(Message::Terminate);
        }

        for worker in &mut self.workers {
            if let Some(handle) = worker.handle.take() {
                if handle.join().is_err() {
                    eprintln!("Worker {} panicked.", worker.id);
                }
            }
        }
    }
}

// Example usage
fn main() {
    let pool = ThreadPool::new(4);

    for i in 0..8 {
        pool.execute(i, move |num| {
            println!("Data {num} processed in a worker thread.");
        });
    }
    // `pool` is dropped here; its `Drop` impl waits for the queued jobs to finish.
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Test basic task execution and verify that we can gather all results.
    #[test]
    fn test_execute_tasks() {
        let pool = ThreadPool::new(4);
        let (tx, rx) = mpsc::channel();

        // Send 8 jobs to the thread pool. Each job sends back its data via tx.
        for i in 0..8 {
            let tx_clone = tx.clone();
            pool.execute(i, move |num| {
                // Simulate some work
                thread::sleep(Duration::from_millis(10));
                tx_clone.send(num).expect("Failed to send result");
            });
        }

        // Gather results
        let mut results = Vec::new();
        for _ in 0..8 {
            results.push(rx.recv().expect("Failed to receive result"));
        }
        results.sort();
        assert_eq!(results, (0..8).collect::<Vec<_>>());
    }

    /// Test that the thread pool does not allow a size of 0.
    /// This test should panic due to the assert!(size > 0) in `ThreadPool::new`.
    #[test]
    #[should_panic(expected = "size > 0")]
    fn test_zero_size_pool_should_panic() {
        let _ = ThreadPool::<i32>::new(0);
    }

    /// Test dropping the pool after submitting jobs.
    /// If everything is correct, the workers will gracefully terminate.
    #[test]
    fn test_dropping_pool() {
        let pool = ThreadPool::new(2);
        let (tx, rx) = mpsc::channel();

        for i in 0..4 {
            let tx_clone = tx.clone();
            pool.execute(i, move |num| {
                // Simulate work
                thread::sleep(Duration::from_millis(5));
                tx_clone.send(num).unwrap();
            });
        }

        // Drop the pool explicitly, right after queueing the jobs.
        drop(pool);

        // We should still be able to receive all 4 messages
        let mut results = vec![];
        for _ in 0..4 {
            results.push(rx.recv().expect("Failed to receive result"));
        }
        results.sort();
        assert_eq!(results, vec![0, 1, 2, 3]);
    }

    /// A test that ensures multiple data types can be handled.
    /// In reality, we just need T: Send + 'static, but let's be explicit.
    #[test]
    fn test_execute_string_tasks() {
        let pool = ThreadPool::new(2);
        let (tx, rx) = mpsc::channel();

        let strings = vec!["alpha", "beta", "gamma", "delta"];
        for s in strings {
            let tx_clone = tx.clone();
            pool.execute(s.to_string(), move |st: String| {
                // Just pass the string back
                tx_clone.send(st).unwrap();
            });
        }

        // Collect results
        let mut results = Vec::new();
        for _ in 0..4 {
            let r = rx.recv().unwrap();
            results.push(r);
        }
        results.sort();
        assert_eq!(results, vec!["alpha", "beta", "delta", "gamma"]);
    }

    /// Regression test: a running job must not block other workers from
    /// picking up jobs (the receiver lock must be released before a job runs).
    #[test]
    fn test_jobs_run_concurrently() {
        let pool = ThreadPool::new(2);
        let (signal_tx, signal_rx) = mpsc::channel::<()>();
        let (done_tx, done_rx) = mpsc::channel::<bool>();

        // The first job can only succeed if the second job runs while it waits.
        pool.execute((), move |_| {
            let got_signal = signal_rx.recv_timeout(Duration::from_secs(2)).is_ok();
            done_tx.send(got_signal).unwrap();
        });
        pool.execute((), move |_| {
            signal_tx.send(()).unwrap();
        });

        assert!(done_rx.recv().unwrap(), "jobs were not executed concurrently");
    }

    /// A panicking job must neither stop the other workers nor make `drop` panic.
    #[test]
    fn test_drop_after_panicking_job() {
        let pool = ThreadPool::new(2);
        let (tx, rx) = mpsc::channel::<u8>();

        pool.execute(0u8, |_: u8| panic!("job failed on purpose"));
        pool.execute(7u8, move |n: u8| tx.send(n).unwrap());

        assert_eq!(rx.recv_timeout(Duration::from_secs(2)), Ok(7));
        drop(pool);
    }
}