diff --git a/Cargo.lock b/Cargo.lock index 214d9de9..464623f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4472,6 +4472,7 @@ dependencies = [ "bytes", "clap 3.1.3", "colored", + "crossbeam-channel", "csv", "dirs", "either", @@ -4480,6 +4481,7 @@ dependencies = [ "itertools", "libc", "num", + "num_cpus", "once_cell", "pretty_assertions", "rayon", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index 290e9174..81b3e209 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -22,12 +22,14 @@ backtrace = "0.3" bytes = { version = "1", optional = true } clap = { version = "3", features = ["derive"] } colored = "2" +crossbeam-channel = "0.5" csv = "1" dirs = "4" either = "1" hyper = { version = "0.14", optional = true } itertools = "0.10" num = "0.4" +num_cpus = "1" once_cell = "1" rayon = "1.5" serde = { version = "1", features = ["derive"] } diff --git a/crates/cli/predict.rs b/crates/cli/predict.rs index 7ed782d0..9e3ad61a 100644 --- a/crates/cli/predict.rs +++ b/crates/cli/predict.rs @@ -1,7 +1,11 @@ use crate::PredictArgs; use anyhow::Result; +use crossbeam_channel::Sender; +use csv::StringRecord; use either::Either; use itertools::Itertools; +use std::sync::Arc; +use tangram_core::predict::PredictOutput; use tangram_core::predict::{PredictInput, PredictInputValue, PredictOptions}; use tangram_zip::zip; @@ -68,25 +72,89 @@ pub fn predict(args: PredictArgs) -> Result<()> { } } }; + let header = reader.headers()?.to_owned(); - for records in &reader.records().chunks(PREDICT_CHUNK_SIZE) { - let input: Vec = records - .into_iter() - .map(|record| -> Result { - let record = record?; - let input = zip!(header.iter(), record.into_iter()) - .map(|(column_name, value)| { - ( - column_name.to_owned(), - PredictInputValue::String(value.to_owned()), - ) + let chunk_count = num_cpus::get() * 2; + let (input_tx, input_rx): ( + Sender<( + Vec, + Sender, anyhow::Error>>, + )>, + _, + ) = crossbeam_channel::bounded(chunk_count); + let (output_tx, output_rx) = crossbeam_channel::bounded(chunk_count); + + let header = Arc::new(header); + let model = Arc::new(model); + let options = Arc::new(options); + + let mut threads = Vec::new(); + + for _ in 0..num_cpus::get() { + let header = header.clone(); + let model = model.clone(); + let options = options.clone(); + let input_rx = input_rx.clone(); + + threads.push(std::thread::spawn(move || { + while let Ok((records, chunk_tx)) = input_rx.recv() { + let input: Result, _> = records + .into_iter() + .map(|record| -> Result { + let input = zip!(header.iter(), record.into_iter()) + .map(|(column_name, value)| { + ( + column_name.to_owned(), + PredictInputValue::String(value.to_owned()), + ) + }) + .collect(); + Ok(PredictInput(input)) }) .collect(); - Ok(PredictInput(input)) - }) - .collect::>()?; - let output = tangram_core::predict::predict(&model, &input, &options); - for output in output { + + let output = + input.map(|input| tangram_core::predict::predict(&model, &input, &options)); + + if chunk_tx.send(output).is_err() { + break; + }; + } + })); + } + + threads.push(std::thread::spawn(move || { + for records_chunk in &reader.records().chunks(PREDICT_CHUNK_SIZE) { + let records_chunk: Result, _> = records_chunk.collect(); + let records_chunk = match records_chunk { + Ok(records_chunk) => records_chunk, + Err(error) => { + let error: anyhow::Error = error.into(); + let _ = output_tx.send(Err(error)); + break; + } + }; + + // Here we create a single use channel which will allow the CSV writer + // to wait for the prediction results in-order while still allowing + // the prediction for future chunks to run in parallel. + let (chunk_tx, chunk_rx) = crossbeam_channel::bounded(1); + if let Err(error) = input_tx.send((records_chunk, chunk_tx)) { + let error: anyhow::Error = error.into(); + let _ = output_tx.send(Err(error)); + break; + } + if output_tx.send(Ok(chunk_rx)).is_err() { + break; + } + } + })); + + while let Ok(output) = output_rx.recv() { + let chunk_rx = output?; + let outputs = chunk_rx.recv()??; + + for output in outputs { let output = match output { tangram_core::predict::PredictOutput::Regression(output) => { vec![output.value.to_string()] @@ -129,5 +197,10 @@ pub fn predict(args: PredictArgs) -> Result<()> { writer.write_record(&output)?; } } + + for thread in threads { + thread.join().unwrap(); + } + Ok(()) }