diff --git a/tensorboard-rs/Cargo.toml b/tensorboard-rs/Cargo.toml index 2efe349..f43636f 100644 --- a/tensorboard-rs/Cargo.toml +++ b/tensorboard-rs/Cargo.toml @@ -26,7 +26,11 @@ tensorboard-proto = { version = "0.5.7" } gethostname = "0.2.1" -image = "0.23.4" +image = "0.25.6" +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +hex = "0.4.3" +sha2 = "0.10" [dev-dependencies] #protobuf-codegen = "2.14" diff --git a/tensorboard-rs/examples/draw_graph.rs b/tensorboard-rs/examples/draw_graph.rs index 7dff307..419e5c0 100644 --- a/tensorboard-rs/examples/draw_graph.rs +++ b/tensorboard-rs/examples/draw_graph.rs @@ -2,9 +2,9 @@ use tensorboard_rs::summary_writer::SummaryWriter; //use tensorboard_proto::event::{Event, TaggedRunMetadata}; //use tensorboard_proto::summary::{Summary}; //use tensorboard_proto::graph::{GraphDef, }; -use tensorboard_proto::node_def::{NodeDef, }; +use tensorboard_proto::node_def::NodeDef; //use tensorboard_proto::versions::{VersionDef, }; -use tensorboard_proto::attr_value::{AttrValue, }; +use tensorboard_proto::attr_value::AttrValue; //use tensorboard_proto::tensor_shape::{TensorShapeProto, }; //use tensorboard_proto::step_stats::{RunMetadata, }; use protobuf::RepeatedField; @@ -16,10 +16,10 @@ pub fn main() { let mut node1 = NodeDef::new(); node1.set_name("node1".to_string()); node1.set_op("op1".to_string()); - + let inputs = RepeatedField::from(vec![]); node1.set_input(inputs); - + let mut attrs = HashMap::new(); let mut v1 = AttrValue::new(); v1.set_i(16); @@ -27,6 +27,6 @@ pub fn main() { node1.set_attr(attrs); writer.add_graph(&[node1]); - + writer.flush(); } diff --git a/tensorboard-rs/examples/draw_histo.rs b/tensorboard-rs/examples/draw_histo.rs index e4fab19..e5e1f85 100644 --- a/tensorboard-rs/examples/draw_histo.rs +++ b/tensorboard-rs/examples/draw_histo.rs @@ -2,7 +2,6 @@ use tensorboard_rs::summary_writer::SummaryWriter; //use image::{open, }; pub fn main() { - let mut writer = SummaryWriter::new(&("./logdir".to_string())); let min = 1.001; @@ -10,31 +9,54 @@ pub fn main() { let num = 435.; let sum = 8555.435; let sum_squares = 189242.110435; - let bucket_limits = [3.8009999999999997, 6.600999999999999, 9.400999999999998, 12.200999999999999, 15.001, 17.801, 20.601, 23.401, 26.201, 29.001]; - let bucket_counts = [ 6., 15., 24., 33., 27., 48., 57., 66., 75., 84.]; - - writer.add_histogram_raw("run1/histo1", - min, max, - num, - sum, sum_squares, - &bucket_limits, &bucket_counts, - 1 + let bucket_limits = [ + 3.8009999999999997, + 6.600999999999999, + 9.400999999999998, + 12.200999999999999, + 15.001, + 17.801, + 20.601, + 23.401, + 26.201, + 29.001, + ]; + let bucket_counts = [6., 15., 24., 33., 27., 48., 57., 66., 75., 84.]; + + writer.add_histogram_raw( + "run1/histo1", + min, + max, + num, + sum, + sum_squares, + &bucket_limits, + &bucket_counts, + 1, ); - writer.add_histogram_raw("run1/histo1", - min, max, - num, - sum, sum_squares, - &bucket_limits, &bucket_counts, - 2 + writer.add_histogram_raw( + "run1/histo1", + min, + max, + num, + sum, + sum_squares, + &bucket_limits, + &bucket_counts, + 2, ); - writer.add_histogram_raw("run1/histo1", - min, max, - num, - sum, sum_squares, - &bucket_limits, &bucket_counts, - 3 + writer.add_histogram_raw( + "run1/histo1", + min, + max, + num, + sum, + sum_squares, + &bucket_limits, + &bucket_counts, + 3, ); writer.flush(); } diff --git a/tensorboard-rs/examples/draw_hparams.rs b/tensorboard-rs/examples/draw_hparams.rs new file mode 100644 index 0000000..6f4b1f1 --- /dev/null +++ b/tensorboard-rs/examples/draw_hparams.rs @@ -0,0 +1,43 @@ +use std::collections::HashMap; +use tensorboard_rs::hparams::{GenericValue, HyperParameter, Metric}; +use tensorboard_rs::status::JobStatus; +use tensorboard_rs::summary_writer::SummaryWriter; + +pub fn main() { + for i in 0..10 { + let path = format!("./logdir/test-{:}", i); + let mut writer = SummaryWriter::new(path); + let options = ["A", "B", "C", "D"]; + let hparams = vec![ + HyperParameter::new("M1"), + HyperParameter::with_string("M2", "A Metric"), + HyperParameter::with_bool("M3", true), + HyperParameter::with_f64s("M4", &[1f64, 2f64]), + HyperParameter::with_strings("M5", &options), + ]; + let metrics = vec![Metric::new("Test Metric")]; + writer.add_hparams_config(&hparams, &metrics); + + let mut data = HashMap::new(); + data.insert("M1".to_string(), GenericValue::Number(i as f64)); + data.insert( + "M2".to_string(), + GenericValue::String("A Metric".to_string()), + ); + data.insert("M3".to_string(), GenericValue::Bool(true)); + data.insert("M4".to_string(), GenericValue::Number((i % 2) as f64)); + data.insert( + "M5".to_string(), + GenericValue::String(options[i % 4].to_string()), + ); + writer.add_hparams(data, Some(format!("test-{:}", i).to_string()), Some(0)); + + writer.add_scalar("Test Metric", 0f32, 0); + writer.add_scalar("Test Metric", 1f32, 1); + writer.add_scalar("Test Metric", 2f32, 2); + + writer.add_job_status(&JobStatus::Success, None); + + writer.flush(); + } +} diff --git a/tensorboard-rs/examples/draw_image.rs b/tensorboard-rs/examples/draw_image.rs index 014a9a0..1bbf8e1 100644 --- a/tensorboard-rs/examples/draw_image.rs +++ b/tensorboard-rs/examples/draw_image.rs @@ -1,8 +1,7 @@ +use image::open; use tensorboard_rs::summary_writer::SummaryWriter; -use image::{open, }; pub fn main() { - let mut writer = SummaryWriter::new(&("./logdir".to_string())); let stop_image = "./examples/stop.jpg"; @@ -10,7 +9,11 @@ pub fn main() { let img = img.into_rgb8(); let (width, height) = img.dimensions(); - - writer.add_image(&"test_image".to_string(), &img.into_raw()[..], &vec![3, width as usize, height as usize][..], 12); + writer.add_image( + &"test_image".to_string(), + &img.into_raw()[..], + &vec![3, width as usize, height as usize][..], + 12, + ); writer.flush(); } diff --git a/tensorboard-rs/examples/draw_scalar.rs b/tensorboard-rs/examples/draw_scalar.rs index 407960d..33817e1 100644 --- a/tensorboard-rs/examples/draw_scalar.rs +++ b/tensorboard-rs/examples/draw_scalar.rs @@ -1,15 +1,15 @@ -use tensorboard_rs::summary_writer::SummaryWriter; use std::collections::HashMap; +use tensorboard_rs::summary_writer::SummaryWriter; pub fn main() { let mut writer = SummaryWriter::new(&("./logdir".to_string())); let name = "run1"; let mut scalar = 2.3; - let mut step = 12; + let mut step = 12; for i in 0..2 { println!("{}", i); - scalar += (i as f32)*0.1; + scalar += (i as f32) * 0.1; step += i; writer.add_scalar(name, scalar, step); diff --git a/tensorboard-rs/src/event_file_writer.rs b/tensorboard-rs/src/event_file_writer.rs index 38f6cda..3239a13 100644 --- a/tensorboard-rs/src/event_file_writer.rs +++ b/tensorboard-rs/src/event_file_writer.rs @@ -1,15 +1,15 @@ -use std::path::{PathBuf, Path}; -use std::fs; -use std::time::SystemTime; use gethostname::gethostname; -use std::process::id; -use std::fs::File; use protobuf::Message; -use std::thread::{spawn, JoinHandle}; +use std::fs; +use std::fs::File; +use std::path::{Path, PathBuf}; +use std::process::id; use std::sync::mpsc::{channel, Sender}; +use std::thread::{spawn, JoinHandle}; +use std::time::SystemTime; -use tensorboard_proto::event::Event; use crate::record_writer::RecordWriter; +use tensorboard_proto::event::Event; enum EventSignal { Data(Vec), @@ -37,8 +37,11 @@ impl EventFileWriter { } let hostname = gethostname().into_string().expect(""); let pid = id(); - - let file_name = format!("events.out.tfevents.{:010}.{}.{}.{}", time, hostname, pid, 0); + + let file_name = format!( + "events.out.tfevents.{:010}.{}.{}.{}", + time, hostname, pid, 0 + ); //let file_writer = File::create(logdir.join(file_name)).expect(""); //let writer = RecordWriter::new(file_writer); @@ -47,20 +50,24 @@ impl EventFileWriter { let child = spawn(move || { let file_writer = File::create(logdir_move.join(file_name)).expect(""); let mut writer = RecordWriter::new(file_writer); - + loop { let result: EventSignal = rx.recv().unwrap(); match result { EventSignal::Data(d) => { writer.write(&d).expect("write error"); - }, - EventSignal::Flush => {writer.flush().expect("flush error");}, - EventSignal::Stop => {break;}, + } + EventSignal::Flush => { + writer.flush().expect("flush error"); + } + EventSignal::Stop => { + break; + } } - }; + } writer.flush().expect("flush error"); }); - + let mut ret = EventFileWriter { logdir, writer: tx, @@ -81,13 +88,13 @@ impl EventFileWriter { pub fn get_logdir(&self) -> PathBuf { self.logdir.to_path_buf() } - + pub fn add_event(&mut self, event: &Event) { let mut data: Vec = Vec::new(); event.write_to_vec(&mut data).expect(""); self.writer.send(EventSignal::Data(data)).expect(""); } - + pub fn flush(&mut self) { self.writer.send(EventSignal::Flush).expect(""); } diff --git a/tensorboard-rs/src/hparams.rs b/tensorboard-rs/src/hparams.rs new file mode 100644 index 0000000..8fed0a6 --- /dev/null +++ b/tensorboard-rs/src/hparams.rs @@ -0,0 +1,297 @@ +use crate::status::JobStatus; +use protobuf::well_known_types::{ListValue, Value}; +use protobuf::{Message, RepeatedField, SingularPtrField}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use tensorboard_proto::api::{DataType, Experiment, HParamInfo, MetricInfo, MetricName}; +use tensorboard_proto::plugin_hparams::{HParamsPluginData, SessionEndInfo, SessionStartInfo}; +use tensorboard_proto::summary::{ + Summary, SummaryMetadata, SummaryMetadata_PluginData, Summary_Value, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum GenericValue { + Number(f64), + String(String), + Bool(bool), + List(Vec), + None, +} + +impl From for Value { + fn from(value: GenericValue) -> Self { + let mut ret = Value::new(); + match value { + GenericValue::Number(n) => { + ret.set_number_value(n); + } + GenericValue::String(s) => { + ret.set_string_value(s); + } + GenericValue::Bool(b) => { + ret.set_bool_value(b); + } + GenericValue::None => {} + GenericValue::List(v) => { + let v = v.iter().map(|v| v.clone().into()).collect::>(); + let mut list_value = ListValue::new(); + list_value.set_values(RepeatedField::from_vec(v)); + ret.set_list_value(list_value); + } + } + ret + } +} +impl From<&GenericValue> for Value { + fn from(value: &GenericValue) -> Self { + value.clone().into() + } +} +#[derive(Debug, Clone)] +pub struct HyperParameter { + name: String, + value: GenericValue, +} + +impl HyperParameter { + + pub fn as_kv(&self) -> (String,GenericValue) { + (self.name.clone(),self.value.clone()) + } + + pub fn with_values(name: &str, values: Vec) -> Self { + HyperParameter { + name: name.to_string(), + value: GenericValue::List(values), + } + } + pub fn new(name: &str) -> Self { + HyperParameter { + name: name.to_string(), + value: GenericValue::None, + } + } + pub fn with_bool(name: &str, value: bool) -> Self { + HyperParameter { + name: name.to_string(), + value: GenericValue::Bool(value), + } + } + pub fn with_bools(name: &str) -> Self { + let value = vec![GenericValue::Bool(true), GenericValue::Bool(false)]; + Self::with_values(name, value) + } + pub fn with_string(name: &str, value: &str) -> Self { + HyperParameter { + name: name.to_string(), + value: GenericValue::String(value.to_string()), + } + } + pub fn with_strings(name: &str, values: &[&str]) -> Self { + let value = values + .iter() + .map(|v| GenericValue::String(v.to_string())) + .collect::>(); + Self::with_values(name, value) + } + pub fn with_f64(name: &str, value: f64) -> Self { + HyperParameter { + name: name.to_string(), + value: GenericValue::Number(value), + } + } + pub fn with_f64s(name: &str, values: &[f64]) -> Self { + let value = values + .iter() + .map(|v| GenericValue::Number(*v)) + .collect::>(); + Self::with_values(name, value) + } +} +fn get_type_and_value(value: &GenericValue) -> (DataType, RepeatedField) { + let (type_, value) = match value { + GenericValue::Number(n) => { + let mut value = Value::new(); + value.set_number_value(*n); + ( + DataType::DATA_TYPE_FLOAT64, + RepeatedField::from_vec(vec![value]), + ) + } + GenericValue::String(s) => { + let mut value = Value::new(); + value.set_string_value(s.clone()); + ( + DataType::DATA_TYPE_STRING, + RepeatedField::from_vec(vec![value]), + ) + } + GenericValue::Bool(b) => { + let mut value = Value::new(); + value.set_bool_value(*b); + ( + DataType::DATA_TYPE_BOOL, + RepeatedField::from_vec(vec![value]), + ) + } + GenericValue::None => (DataType::DATA_TYPE_UNSET, RepeatedField::new()), + GenericValue::List(l) => { + if l.is_empty() { + (DataType::DATA_TYPE_UNSET, RepeatedField::new()) + } else { + let v = l.first().unwrap(); + let type_ = match v { + GenericValue::Number(_) => DataType::DATA_TYPE_FLOAT64, + GenericValue::String(_) => DataType::DATA_TYPE_STRING, + GenericValue::Bool(_) => DataType::DATA_TYPE_BOOL, + _ => { + panic!("Not supported") + } + }; + let values = l.iter().flat_map(|v| get_type_and_value(v).1).collect(); + (type_, RepeatedField::from_vec(values)) + } + } + }; + (type_, value) +} +impl From<&HyperParameter> for HParamInfo { + fn from(value: &HyperParameter) -> Self { + let value = value.clone(); + let mut ret = HParamInfo::new(); + ret.set_name(value.name); + + let (type_, value) = get_type_and_value(&value.value); + + ret.set_field_type(type_); + let mut list_value = ListValue::new(); + list_value.set_values(value); + ret.set_domain_discrete(list_value); + ret + } +} + +#[derive(Debug, Clone)] +pub struct Metric { + name: String, +} + +impl Metric { + pub fn new(name: &str) -> Self { + Metric { + name: name.to_string(), + } + } +} + +impl From<&Metric> for MetricInfo { + fn from(value: &Metric) -> Self { + let value = value.clone(); + let mut metric_name = MetricName::new(); + metric_name.set_tag(value.name); + let mut ret = MetricInfo::new(); + ret.set_name(metric_name); + ret + } +} + +const PLUGIN_NAME: &str = "hparams"; +const PLUGIN_DATA_VERSION: i32 = 0; +fn create_summary_metadata(hparams_plugin_data_pb: &HParamsPluginData) -> SummaryMetadata { + let mut content = HParamsPluginData::new(); + content.clone_from(hparams_plugin_data_pb); + content.version = PLUGIN_DATA_VERSION; + + let mut summary_plugin_data = SummaryMetadata_PluginData::new(); + summary_plugin_data.set_content(content.write_to_bytes().unwrap()); + summary_plugin_data.set_plugin_name(PLUGIN_NAME.to_string()); + + let mut ret = SummaryMetadata::new(); + ret.plugin_data = SingularPtrField::from_option(Some(summary_plugin_data)); + ret +} + +fn sumary_pb(tag: &str, hparams_plugin_data: &HParamsPluginData) -> Summary { + let mut summary = Summary::new(); + let summary_metadata = create_summary_metadata(hparams_plugin_data); + let mut summary_value = Summary_Value::new(); + summary_value.set_tag(tag.to_string()); + summary_value.set_metadata(summary_metadata); + summary.set_value(RepeatedField::from_vec(vec![summary_value])); + summary +} + +pub const EXPERIMENT_TAG: &str = "_hparams_/experiment"; +pub const SESSION_START_INFO_TAG: &str = "_hparams_/session_start_info"; +pub const SESSION_END_INFO_TAG: &str = "_hparams_/session_end_info"; +pub fn hparams_config_pb(hparams: Vec, metrics: Vec) -> Summary { + let mut experiment = Experiment::new(); + experiment.set_hparam_infos(RepeatedField::from(hparams)); + experiment.set_metric_infos(RepeatedField::from(metrics)); + let mut hparam_plugin_data = HParamsPluginData::new(); + hparam_plugin_data.set_experiment(experiment); + sumary_pb(EXPERIMENT_TAG, &hparam_plugin_data) +} + +pub fn hparams_config(hyper_parameters: &[HyperParameter], metrics: &[Metric]) -> Summary { + let hyper_parameters = hyper_parameters + .iter() + .map(|h| h.into()) + .collect::>(); + let metrics = metrics + .iter() + .map(|h| h.into()) + .collect::>(); + hparams_config_pb(hyper_parameters, metrics) +} + +fn derive_session_group_name( + trial_id: Option, + hparams: &HashMap, +) -> String { + if let Some(trial_id) = trial_id { + trial_id + } else { + let json_str = serde_json::to_string(&hparams).expect("Failed to serialize to JSON"); + let mut hasher = Sha256::new(); + hasher.update(json_str.as_bytes()); + let result = hasher.finalize(); + hex::encode(result) + } +} + +pub fn hparams( + hparams: HashMap, + trial_id: Option, + start_time_secs: Option, +) -> Summary { + let group_name = derive_session_group_name(trial_id, &hparams); + let mut session_start_info = SessionStartInfo::new(); + session_start_info.set_group_name(group_name); + if let Some(start_time_secs) = start_time_secs { + session_start_info.set_start_time_secs(start_time_secs as f64); + } + let hparams = hparams + .iter() + .map(|(k, v)| { + let value: Value = v.into(); + (k.clone(), value) + }) + .collect::>(); + session_start_info.set_hparams(hparams); + let mut hparams_plugin_data = HParamsPluginData::new(); + hparams_plugin_data.set_session_start_info(session_start_info); + sumary_pb(SESSION_START_INFO_TAG, &hparams_plugin_data) +} + +pub fn status_config(job_status: &JobStatus, end_time_secs: Option) -> Summary { + let mut session_end_info = SessionEndInfo::new(); + if let Some(end_time_secs) = end_time_secs { + session_end_info.set_end_time_secs(end_time_secs as f64); + } + session_end_info.set_status(job_status.clone().into()); + let mut hparams_plugin_data = HParamsPluginData::new(); + hparams_plugin_data.set_session_end_info(session_end_info); + sumary_pb(SESSION_END_INFO_TAG, &hparams_plugin_data) +} diff --git a/tensorboard-rs/src/lib.rs b/tensorboard-rs/src/lib.rs index 728bbf1..dc3c2dd 100644 --- a/tensorboard-rs/src/lib.rs +++ b/tensorboard-rs/src/lib.rs @@ -1,4 +1,3 @@ - //! Write data for Tensorboard from Rust. //! ============================================================= //! @@ -15,12 +14,10 @@ //! Licese //! ------------ - - - - +pub mod event_file_writer; +pub mod hparams; pub mod masked_crc32c; pub mod record_writer; -pub mod event_file_writer; -pub mod summary_writer; +pub mod status; pub mod summary; +pub mod summary_writer; diff --git a/tensorboard-rs/src/masked_crc32c.rs b/tensorboard-rs/src/masked_crc32c.rs index f33e838..66259e0 100644 --- a/tensorboard-rs/src/masked_crc32c.rs +++ b/tensorboard-rs/src/masked_crc32c.rs @@ -1,3 +1,4 @@ +#[allow(clippy::manual_rotate)] pub fn masked_crc32c(data: &[u8]) -> u32 { let x = crc32c(data); ((x >> 15) | (x << 17)).overflowing_add(0xa282ead8).0 @@ -8,70 +9,38 @@ pub fn masked_crc32c(data: &[u8]) -> u32 { //} const CRC_TABLE: [u32; 256] = [ - 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, - 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, - 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, - 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, - 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, - 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, - 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, - 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, - 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, - 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, - 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, - 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, - 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, - 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, - 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, - 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, - 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, - 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, - 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, - 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, - 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, - 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, - 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, - 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, - 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, - 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, - 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, - 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, - 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, - 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, - 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, - 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, - 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, - 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, - 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, - 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, - 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, - 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, - 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, - 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, - 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, - 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, - 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, - 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, - 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, - 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, - 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, - 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, - 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, - 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, - 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, - 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, - 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, - 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, - 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, - 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, - 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, - 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, - 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, - 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, - 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, - 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, - 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, - 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, + 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, + 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, + 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, + 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, + 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, + 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, + 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, + 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, + 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, + 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, + 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, + 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, + 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, + 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, + 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, + 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, + 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, + 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, + 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, + 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, + 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, + 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, + 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, + 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, + 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, + 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, + 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, + 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, + 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, + 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, + 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, + 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, ]; const CRC_INIT: u32 = 0; @@ -81,13 +50,13 @@ const _MASK: u32 = 0xFFFFFFFF; pub fn crc_update(crc: u32, data: &[u8]) -> u32 { let mut crc = crc ^ _MASK; for b in data { - let table_index = ((crc & 0xff ) as u8 )^ b; + let table_index = ((crc & 0xff) as u8) ^ b; crc = (CRC_TABLE[table_index as usize] ^ (crc >> 8)) & _MASK; } crc ^ _MASK } -pub fn crc_finalize(crc: u32) -> u32{ +pub fn crc_finalize(crc: u32) -> u32 { crc & _MASK } diff --git a/tensorboard-rs/src/record_writer.rs b/tensorboard-rs/src/record_writer.rs index ceb4c4f..d836549 100644 --- a/tensorboard-rs/src/record_writer.rs +++ b/tensorboard-rs/src/record_writer.rs @@ -12,8 +12,8 @@ impl RecordWriter { } pub fn write(&mut self, data: &[u8]) -> std::io::Result<()>{ let header = data.len() as u64; - let header_crc = (masked_crc32c(&(header.to_le_bytes())) as u32).to_le_bytes(); - let footer_crc = (masked_crc32c(data) as u32).to_le_bytes(); + let header_crc = masked_crc32c(&(header.to_le_bytes())).to_le_bytes(); + let footer_crc = masked_crc32c(data).to_le_bytes(); let header = header.to_le_bytes(); self._writer.write_all(&header)?; diff --git a/tensorboard-rs/src/status.rs b/tensorboard-rs/src/status.rs new file mode 100644 index 0000000..9e9ee3b --- /dev/null +++ b/tensorboard-rs/src/status.rs @@ -0,0 +1,20 @@ +use tensorboard_proto::api::Status; + +#[derive(Debug, Clone)] +pub enum JobStatus { + Unknown, + Success, + Failure, + Running, +} + +impl From for Status { + fn from(value: JobStatus) -> Self { + match value { + JobStatus::Unknown => Status::STATUS_UNKNOWN, + JobStatus::Success => Status::STATUS_SUCCESS, + JobStatus::Failure => Status::STATUS_FAILURE, + JobStatus::Running => Status::STATUS_RUNNING, + } + } +} diff --git a/tensorboard-rs/src/summary.rs b/tensorboard-rs/src/summary.rs index a85a08b..f8446e9 100644 --- a/tensorboard-rs/src/summary.rs +++ b/tensorboard-rs/src/summary.rs @@ -1,12 +1,16 @@ #![allow(clippy::too_many_arguments)] -use tensorboard_proto::summary::{Summary, Summary_Value, Summary_Image, SummaryMetadata, SummaryMetadata_PluginData, HistogramProto}; -use tensorboard_proto::layout::{Layout, Category}; + +use std::io::Cursor; use protobuf::RepeatedField; +use tensorboard_proto::layout::{Category, Layout}; +use tensorboard_proto::summary::{ + HistogramProto, Summary, SummaryMetadata, SummaryMetadata_PluginData, Summary_Image, + Summary_Value, +}; -use image::{RgbImage, DynamicImage, ImageOutputFormat}; +use image::{DynamicImage, ImageFormat, RgbImage}; pub fn scalar(name: &str, scalar_value: f32) -> Summary { - let mut value = Summary_Value::new(); value.set_tag(name.to_string()); value.set_simple_value(scalar_value); @@ -18,12 +22,15 @@ pub fn scalar(name: &str, scalar_value: f32) -> Summary { summary } -pub fn histogram_raw(name: &str, - min: f64, max: f64, - num: f64, - sum: f64, sum_squares: f64, - bucket_limits: &[f64], - bucket_counts: &[f64], +pub fn histogram_raw( + name: &str, + min: f64, + max: f64, + num: f64, + sum: f64, + sum_squares: f64, + bucket_limits: &[f64], + bucket_counts: &[f64], ) -> Summary { let mut hist = HistogramProto::new(); hist.set_min(min); @@ -33,7 +40,7 @@ pub fn histogram_raw(name: &str, hist.set_sum_squares(sum_squares); hist.set_bucket_limit(bucket_limits.to_vec()); hist.set_bucket(bucket_counts.to_vec()); - + let mut value = Summary_Value::new(); value.set_tag(name.to_string()); value.set_histo(hist); @@ -53,21 +60,23 @@ pub fn image(tag: &str, data: &[u8], dim: &[usize]) -> Summary { if dim[0] != 3 { panic!("needs rgb"); } - if data.len() != dim[0]*dim[1]*dim[2] { + if data.len() != dim[0] * dim[1] * dim[2] { panic!("length of data should matches with dim."); } - + let mut img = RgbImage::new(dim[1] as u32, dim[2] as u32); img.clone_from_slice(data); let dimg = DynamicImage::ImageRgb8(img); - let mut output_buf = Vec::::new(); - dimg.write_to(&mut output_buf, ImageOutputFormat::Png).expect(""); + let output_buf = Vec::::new(); + let mut c = Cursor::new(output_buf); + dimg.write_to(&mut c, ImageFormat::Png) + .expect(""); let mut output_image = Summary_Image::new(); output_image.set_height(dim[1] as i32); output_image.set_width(dim[2] as i32); output_image.set_colorspace(3); - output_image.set_encoded_image_string(output_buf); + output_image.set_encoded_image_string(c.into_inner()); let mut value = Summary_Value::new(); value.set_tag(tag.to_string()); value.set_image(output_image); @@ -88,6 +97,4 @@ pub fn custom_scalars(_layout: f32) { plugin_data.set_plugin_name("custom_scalars".to_string()); let mut smd = SummaryMetadata::new(); smd.set_plugin_data(plugin_data); - - } diff --git a/tensorboard-rs/src/summary_writer.rs b/tensorboard-rs/src/summary_writer.rs index 64e7931..392177a 100644 --- a/tensorboard-rs/src/summary_writer.rs +++ b/tensorboard-rs/src/summary_writer.rs @@ -1,24 +1,36 @@ #![allow(clippy::too_many_arguments)] -use std::path::{PathBuf, Path}; -use std::time::SystemTime; -use std::collections::HashMap; +use crate::event_file_writer::EventFileWriter; +use crate::hparams::{ + hparams, hparams_config, status_config, GenericValue, HyperParameter, Metric, +}; +use crate::summary::{histogram_raw, image, scalar}; use protobuf::Message; use protobuf::RepeatedField; -use crate::event_file_writer::EventFileWriter; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; use tensorboard_proto::event::{Event, TaggedRunMetadata}; -use tensorboard_proto::summary::{Summary}; -use tensorboard_proto::graph::{GraphDef, }; -use tensorboard_proto::node_def::{NodeDef, }; -use tensorboard_proto::versions::{VersionDef, }; +use tensorboard_proto::graph::GraphDef; +use tensorboard_proto::node_def::NodeDef; //use tensorboard_proto::attr_value::{AttrValue, }; //use tensorboard_proto::tensor_shape::{TensorShapeProto, }; -use tensorboard_proto::step_stats::{RunMetadata, }; -use crate::summary::{scalar, image, histogram_raw}; - +use crate::status::JobStatus; +use tensorboard_proto::step_stats::RunMetadata; +use tensorboard_proto::summary::Summary; +use tensorboard_proto::versions::VersionDef; pub struct FileWriter { writer: EventFileWriter, } + +impl FileWriter { + pub(crate) fn add_global_summary(&mut self, summary: Summary) { + let mut evn = Event::new(); + evn.set_summary(summary); + self.writer.add_event(&evn); + } +} + impl FileWriter { pub fn new>(logdir: P) -> FileWriter { FileWriter { @@ -30,15 +42,15 @@ impl FileWriter { } pub fn add_event(&mut self, event: &Event, step: usize) { let mut event = event.clone(); - + let mut time_full = 0.0; if let Ok(n) = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH) { time_full = n.as_secs_f64(); } event.set_wall_time(time_full); - + event.set_step(step as i64); - + self.writer.add_event(&event) } pub fn add_summary(&mut self, summary: Summary, step: usize) { @@ -71,14 +83,31 @@ pub struct SummaryWriter { writer: FileWriter, all_writers: HashMap, } + impl SummaryWriter { + pub fn add_job_status(&mut self, job_status: &JobStatus, end_time_secs: Option) { + self.writer + .add_global_summary(status_config(job_status, end_time_secs)) + } pub fn new>(logdir: P) -> SummaryWriter { SummaryWriter { writer: FileWriter::new(logdir), all_writers: HashMap::new(), } } - pub fn add_hparams(&mut self) {unimplemented!();} + pub fn add_hparams_config(&mut self, hyper_parameters: &[HyperParameter], metrics: &[Metric]) { + self.writer + .add_global_summary(hparams_config(hyper_parameters, metrics)); + } + pub fn add_hparams( + &mut self, + hyper_parameters: HashMap, + trial_id: Option, + start_time_secs: Option, + ) { + self.writer + .add_global_summary(hparams(hyper_parameters, trial_id, start_time_secs)); + } pub fn add_scalar(&mut self, tag: &str, scalar_value: f32, step: usize) { self.writer.add_summary(scalar(tag, scalar_value), step); } @@ -86,7 +115,7 @@ impl SummaryWriter { let base_logdir = self.writer.get_logdir(); for (tag, scalar_value) in tag_scalar.iter() { let fw_tag = base_logdir.join(main_tag).join(tag); - if ! self.all_writers.contains_key(&fw_tag) { + if !self.all_writers.contains_key(&fw_tag) { let new_writer = FileWriter::new(fw_tag.clone()); self.all_writers.insert(fw_tag.clone(), new_writer); } @@ -95,54 +124,104 @@ impl SummaryWriter { } } - pub fn export_scalars_to_json(&self) {unimplemented!();} - pub fn add_histogram(&mut self) {unimplemented!();} - pub fn add_histogram_raw(&mut self, - tag: &str, - min: f64, max: f64, - num: f64, - sum: f64, sum_squares: f64, - bucket_limits: &[f64], bucket_counts: &[f64], - step: usize + pub fn export_scalars_to_json(&self) { + unimplemented!(); + } + pub fn add_histogram(&mut self) { + unimplemented!(); + } + pub fn add_histogram_raw( + &mut self, + tag: &str, + min: f64, + max: f64, + num: f64, + sum: f64, + sum_squares: f64, + bucket_limits: &[f64], + bucket_counts: &[f64], + step: usize, ) { if bucket_limits.len() != bucket_counts.len() { panic!("bucket_limits.len() != bucket_counts.len()"); } - self.writer.add_summary(histogram_raw(tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts), step); + self.writer.add_summary( + histogram_raw( + tag, + min, + max, + num, + sum, + sum_squares, + bucket_limits, + bucket_counts, + ), + step, + ); } pub fn add_image(&mut self, tag: &str, data: &[u8], dim: &[usize], step: usize) { self.writer.add_summary(image(tag, data, dim), step); } - pub fn add_images(&mut self) {unimplemented!();} - pub fn add_image_with_boxes(&mut self) {unimplemented!();} - pub fn add_figure(&mut self) {unimplemented!();} - pub fn add_video(&mut self) {unimplemented!();} - pub fn add_audio(&mut self) {unimplemented!();} - pub fn add_text(&mut self) {unimplemented!();} - pub fn add_onnx_graph(&mut self) {unimplemented!();} - pub fn add_openvino_graph(&mut self) {unimplemented!();} + pub fn add_images(&mut self) { + unimplemented!(); + } + pub fn add_image_with_boxes(&mut self) { + unimplemented!(); + } + pub fn add_figure(&mut self) { + unimplemented!(); + } + pub fn add_video(&mut self) { + unimplemented!(); + } + pub fn add_audio(&mut self) { + unimplemented!(); + } + pub fn add_text(&mut self) { + unimplemented!(); + } + pub fn add_onnx_graph(&mut self) { + unimplemented!(); + } + pub fn add_openvino_graph(&mut self) { + unimplemented!(); + } pub fn add_graph(&mut self, node_list: &[NodeDef]) { let mut graph = GraphDef::new(); - + let nodes = RepeatedField::from(node_list.to_vec()); graph.set_node(nodes); - + let mut version = VersionDef::new(); version.set_producer(22); graph.set_versions(version); let stats = RunMetadata::new(); - + self.writer.add_graph(graph, stats); } - pub fn add_embedding(&mut self) {unimplemented!();} - pub fn add_pr_curve(&mut self) {unimplemented!();} - pub fn add_pr_curve_raw(&mut self) {unimplemented!();} - pub fn add_custom_scalars_multilinechart(&mut self) {unimplemented!();} - pub fn add_custom_scalars_marginchart(&mut self) {unimplemented!();} - pub fn add_custom_scalars(&mut self) {unimplemented!();} - pub fn add_mesh(&mut self) {unimplemented!();} + pub fn add_embedding(&mut self) { + unimplemented!(); + } + pub fn add_pr_curve(&mut self) { + unimplemented!(); + } + pub fn add_pr_curve_raw(&mut self) { + unimplemented!(); + } + pub fn add_custom_scalars_multilinechart(&mut self) { + unimplemented!(); + } + pub fn add_custom_scalars_marginchart(&mut self) { + unimplemented!(); + } + pub fn add_custom_scalars(&mut self) { + unimplemented!(); + } + pub fn add_mesh(&mut self) { + unimplemented!(); + } pub fn flush(&mut self) { self.writer.flush();