diff --git a/Cargo.lock b/Cargo.lock index 68cc4d1..853bee2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,25 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +[[package]] +name = "bytes" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +dependencies = [ + "serde", +] + +[[package]] +name = "cc" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" +dependencies = [ + "jobserver", + "libc", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -166,6 +185,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "indexmap" version = "1.6.0" @@ -191,6 +216,15 @@ dependencies = [ "either", ] +[[package]] +name = "jobserver" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" +dependencies = [ + "libc", +] + [[package]] name = "libc" version = "0.2.79" @@ -205,12 +239,15 @@ checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a" [[package]] name = "markov" -version = "0.2.0" +version = "0.2.1" dependencies = [ + "bytes", "clap", "markov 1.1.0", + "num_cpus", "serde", "serde_cbor", + "zstd", ] [[package]] @@ -228,6 +265,16 @@ dependencies = [ "serde_yaml", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "once_cell" version = "1.20.3" @@ -244,6 +291,12 @@ dependencies = [ "indexmap", ] +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "ppv-lite86" version = "0.2.9" @@ -473,3 +526,31 @@ checksum = "39f0c922f1a334134dc2f7a8b67dc5d25f0735263feec974345ff706bcf20b0d" dependencies = [ "linked-hash-map", ] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 9406f0e..ab11ee7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "markov" -version = "0.2.0" +version = "0.2.1" description = "Generate string of text from Markov chain fed by stdin or file(s)" authors = ["Avril "] edition = "2018" @@ -12,9 +12,23 @@ opt-level = 3 lto = true #"fat" codegen-units = 1 +strip = true + +[profile.symbols] +inherits = "release" +lto = "fat" +strip = false + +[features] +default = ["threads"] + +threads = ["zstd/zstdmt", "dep:num_cpus"] [dependencies] +bytes = { version = "1.10.0", features = ["serde"] } 
chain = {package = "markov", version = "1.1.0" } clap = { version = "4.5.29", features = ["derive"] } +num_cpus = { version = "1.16.0", optional = true } serde = { version = "1.0.217", features = ["derive"] } serde_cbor = { version = "0.11.2", features = ["alloc"] } +zstd = { version = "0.13.2", features = [] } diff --git a/src/format.rs b/src/format.rs new file mode 100644 index 0000000..df6610d --- /dev/null +++ b/src/format.rs @@ -0,0 +1,379 @@ +//! Handles the chain load/save format +use super::*; +use std::{ + io::{ + self, + Read, Write, BufRead, + }, + fmt, +}; +use bytes::{ + Buf, BufMut, Bytes, +}; +use zstd::{ + Encoder, Decoder, +}; + +/// The chain that can be saved / loaded. +pub type Chain = crate::Chain; + +/// The version of the encoded format stream +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Copy)] +#[repr(packed)] +pub struct Version(pub u8,pub u8,pub u8,pub u8); + +impl fmt::Display for Version +{ + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + write!(f, "{}.{}.{}", self.0,self.1, self.2)?; + if self.3 != 0 { + write!(f, "r{}", self.3) + } else { + Ok(()) + } + } +} + + +impl Version { + /// Current save version + pub const CURRENT: Self = Version(0,0,0,0); + + /// Current value as a native integer + const CURRENT_VALUE: u32 = Self::CURRENT.as_native(); + + pub const fn as_native(&self) -> u32 { + u32::from_be_bytes([self.0, self.1, self.2, self.3]) + } + + #[inline] + pub const fn from_native(value: u32) -> Self { + let [a,b,c,d] = u32::to_be_bytes(value); + Self(a,b,c,d) + } +} + +impl Default for Version +{ + #[inline] + fn default() -> Self + { + Self::CURRENT + } +} + +pub unsafe trait AutoBinaryFormat: Sized { + #[inline] + fn as_raw_for_encode(&self) -> *const [u8] { + let ptr = self as *const Self; + std::ptr::slice_from_raw_parts(ptr as *const u8, std::mem::size_of::()) + } + #[inline] + fn as_raw_for_decode(&mut self) -> *mut [u8] { + let ptr = self as *mut Self; + 
std::ptr::slice_from_raw_parts_mut(ptr as *mut u8, std::mem::size_of::()) + } + + fn raw_format_read_size(&mut self) -> usize { + std::mem::size_of::() + } + fn raw_format_write_size(&self) -> usize { + std::mem::size_of::() + } +} + +unsafe impl AutoBinaryFormat for [T; N] {} + +pub trait BinaryFormat { + fn read_from(&mut self, stream: &mut S) -> io::Result + where S: io::Read; + + fn write_to(&self, stream: &mut S) -> io::Result + where S: io::Write; + + fn binary_format_read_size(&mut self) -> Option; + fn binary_format_write_size(&self) -> usize; +} + +impl BinaryFormat for T +where T: AutoBinaryFormat +{ + #[inline(always)] + fn read_from(&mut self, stream: &mut S) -> io::Result + where S: io::Read { + let ptr = self.as_raw_for_decode(); + // SAFETY: The read data is guaranteed to be valid here. + Ok(unsafe { + stream.read_exact(&mut *ptr)?; + (*ptr).len() + }) + } + #[inline(always)] + fn write_to(&self, stream: &mut S) -> io::Result + where S: io::Write { + let ptr = self.as_raw_for_encode(); + // SAFETY: The written data is guaranteed to be valid here. 
+        Ok(unsafe {
+            stream.write_all(&*ptr)?;
+            (*ptr).len()
+        })
+    }
+    #[inline]
+    fn binary_format_read_size(&mut self) -> Option {
+        Some(self.raw_format_read_size())
+    }
+    #[inline]
+    fn binary_format_write_size(&self) -> usize {
+        self.raw_format_write_size()
+    }
+}
+
+impl BinaryFormat for [T]
+where T: BinaryFormat
+{
+    #[inline]
+    fn read_from(&mut self, stream: &mut S) -> io::Result
+    where S: io::Read {
+        let mut sz = 0;
+        for i in self.iter_mut() {
+            sz += i.read_from(stream)?;
+        }
+        Ok(sz)
+    }
+    #[inline]
+    fn write_to(&self, stream: &mut S) -> io::Result
+    where S: io::Write {
+        let mut sz =0;
+        for i in self.iter() {
+            sz += i.write_to(stream)?;
+        }
+        Ok(sz)
+    }
+    #[inline]
+    fn binary_format_read_size(&mut self) -> Option {
+        self.iter_mut().map(|x| x.binary_format_read_size()).try_fold(0, |x, y| {
+            match (x, y) {
+                (x, Some(y)) => Some(x + y),
+                _ => None,
+            }
+        })
+    }
+    #[inline]
+    fn binary_format_write_size(&self) -> usize {
+        self.iter().map(|x| x.binary_format_write_size()).sum()
+    }
+}
+
+impl BinaryFormat for Version {
+    #[inline]
+    fn read_from(&mut self, stream: &mut S) -> io::Result
+    where S: io::Read {
+        let mut vi = [0u8; 4];
+        stream.read_exact(&mut vi[..])?;
+        // Store the decoded bytes so the caller sees the version that is actually in the stream.
+        *self = Self(vi[0], vi[1], vi[2], vi[3]);
+        Ok(4)
+    }
+    #[inline]
+    fn write_to(&self, stream: &mut S) -> io::Result
+    where S: io::Write {
+        let vi = [self.0,self.1,self.2,self.3];
+        stream.write_all(&vi[..])?;
+        Ok(4)
+    }
+    #[inline]
+    fn binary_format_read_size(&mut self) -> Option {
+        Some(std::mem::size_of::() * 4)
+    }
+    #[inline]
+    fn binary_format_write_size(&self) -> usize {
+        std::mem::size_of::() * 4
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Default)]
+#[repr(u32)]
+pub enum Compressed {
+    #[default]
+    No = 0,
+    Zstd = 1,
+}
+
+impl Compressed {
+    #[inline]
+    const fn to_int(&self) -> u32 {
+        *self as u32
+    }
+    #[inline]
+    fn try_from_int(val: u32) -> Option {
+        match val {
+            // SAFETY: These variants are known
+            0..=1 => Some(unsafe {
+                std::mem::transmute(val)
+            }),
+            _ => None,
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
+pub struct FormatMetadata {
+    pub version: Version,
+
+    pub compressed: Compressed,
+    pub chain_size: usize, // NOTE: Unused
+    pub checksum: u64, // NOTE: Unused
+}
+
+impl FormatMetadata {
+    const MAGIC_NUM: &[u8; 8] = b"MARKOV\x00\xcf";
+}
+
+impl BinaryFormat for FormatMetadata
+{
+    fn write_to(&self, mut stream: &mut S) -> io::Result
+    where S: io::Write {
+        let sz = self.version.write_to(&mut stream)?;
+
+        let mut obuf = [0u8; std::mem::size_of::() +std::mem::size_of::() + std::mem::size_of::()];
+        {
+            let mut obuf = &mut obuf[..];
+            use std::convert::TryInto;
+
+            obuf.put_u32(self.compressed.to_int());
+            obuf.put_u64(self.chain_size.try_into().map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "Chain size attribute out-of-bounds for format size"))?);
+            obuf.put_u64(self.checksum);
+        }
+        stream.write_all(&obuf[..])?;
+
+        Ok(sz + obuf.len())
+    }
+
+    fn read_from(&mut self, mut stream: &mut S) -> io::Result
+    where S: io::Read {
+        let sz = self.version.read_from(&mut stream)?;
+
+        if self.version > Version::CURRENT {
+            return Err(io::Error::new(io::ErrorKind::Unsupported, format!("Unknown format version {}", self.version)));
+        }
+
+        let mut ibuf = [0u8; std::mem::size_of::() +std::mem::size_of::() + std::mem::size_of::()];
+        stream.read_exact(&mut ibuf[..])?;
+        {
+            let mut ibuf = &ibuf[..];
+            use std::convert::TryInto;
+
+            self.compressed = Compressed::try_from_int(ibuf.get_u32()).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid compression attribute"))?;
+            self.chain_size = ibuf.get_u64().try_into().map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Chain size attribute out-of-bounds for native size"))?;
+            self.checksum = ibuf.get_u64();
+        }
+
+        Ok(sz + ibuf.len())
+    }
+
+    #[inline]
+    fn binary_format_read_size(&mut self) -> Option {
+        let szm = self.version.binary_format_read_size()?;
+        Some(szm + std::mem::size_of::()
+             + std::mem::size_of::()
+             + std::mem::size_of::())
+    }
+    #[inline(always)]
+    fn binary_format_write_size(&self) -> usize {
+        self.version.binary_format_write_size()
+            + std::mem::size_of::()
+            + std::mem::size_of::()
+            + std::mem::size_of::()
+    }
+}
+
+/// Load a chain from a stream.
+/// Fails with `InvalidData` on a bad magic number or an undecodable payload, and with `Unsupported` on an unknown version.
+#[inline]
+pub fn load_chain_from_sync(stream: &mut S) -> io::Result>
+where S: io::Read + ?Sized
+{
+    let mut stream = io::BufReader::new(stream);
+    {
+        let mut magic = FormatMetadata::MAGIC_NUM.clone();
+        stream.read_exact(&mut magic[..])?;
+
+        if &magic != FormatMetadata::MAGIC_NUM {
+            return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid file header tag magic number"));
+        }
+    }
+
+    let metadata = {
+        let mut metadata = FormatMetadata::default();
+        metadata.read_from(&mut stream)?;
+        metadata
+    };
+
+    match metadata.version {
+        Version::CURRENT => {
+            let read = |read: &mut (dyn io::Read)| serde_cbor::from_reader(read).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); // Decode failures go back to the caller instead of panicking. TODO: dedicated error type
+
+            match metadata.compressed {
+                Compressed::No =>
+                    read(&mut stream),
+                Compressed::Zstd => {
+                    let mut stream = zstd::Decoder::with_buffer(stream)?;
+
+                    //#[cfg(feature="threads")]
+                    //stream.multithread(num_cpus::get() as i32);
+
+                    //NOTE: Not required here: //stream.finish()?;
+                    read(&mut stream)
+                },
+            }
+        },
+        unsup => {
+            return Err(io::Error::new(io::ErrorKind::Unsupported, format!("Unsupported payload version {}", unsup)));
+        },
+    }
+}
+
+/// Save a chain to a stream with optional compression.
+#[inline]
+pub fn save_chain_to_sync(stream: &mut S, chain: &Chain, compress: bool) -> io::Result<()>
+where S: io::Write + ?Sized
+{
+    let mut stream = io::BufWriter::new(stream);
+
+    let metadata = FormatMetadata {
+        compressed: compress
+            .then_some(Compressed::Zstd)
+            .unwrap_or(Compressed::No),
+        ..Default::default()
+    };
+
+    stream.write_all(FormatMetadata::MAGIC_NUM)?;
+    metadata.write_to(&mut stream)?;
+
+    let write = |stream: &mut (dyn io::Write)| serde_cbor::to_writer(stream, chain).map_err(|e| io::Error::new(io::ErrorKind::Other, e)); // Encode failures go back to the caller instead of panicking. TODO: dedicated error type
+
+    let mut stream = match metadata.compressed {
+        Compressed::No => {
+            write(&mut stream)?;
+            stream
+        },
+        Compressed::Zstd => {
+            let mut stream = zstd::Encoder::new(stream, 22)?;
+
+            #[cfg(feature="threads")]
+            stream.multithread(num_cpus::get() as u32)?;
+
+            write(&mut stream)?;
+            // XXX: Should we flush after write here..?
+
+            // NOTE: Required here.
+            stream.finish()?
+        },
+    };
+    stream.flush()
+}
+
+//TODO: Add `tokio_uring` version of `save_chain_to_file()`/`load_chain_from_file()` that spawns the `tokio_uring` runtime internally to queue reads/writes from/to a file.
diff --git a/src/main.rs b/src/main.rs index 208a57a..abb757a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,7 +31,7 @@ pub struct Cli { force: bool, /// The number of lines to output from the chain. - lines: Option, + lines: Option, // TODO: Allow 0 for only save + load operations. //TODO: Num of lines, etc.
} @@ -57,21 +57,18 @@ fn parse_cli() -> Cli { Cli::parse() } +mod format; + fn load_chain(stream: &mut S) -> io::Result> where S: io::Read + ?Sized { - let mut stream = io::BufReader::new(stream); - Ok(serde_cbor::from_reader(&mut stream).expect("Failed to read chain from input stream")) // TODO: Error type + format::load_chain_from_sync(stream) } fn save_chain(stream: &mut S, chain: &Chain) -> io::Result<()> where S: io::Write + ?Sized { - use io::Write; - - let mut stream = io::BufWriter::new(stream); - serde_cbor::to_writer(&mut stream, chain).expect("Failed to write chain to output stream"); // TODO: Error type - stream.flush() + format::save_chain_to_sync(stream, chain, true) //TODO: Change compression to be off for small chains...? We will need to store the chain size info somewhere else. } fn create_chain(cli: &Cli) -> Chain @@ -109,6 +106,7 @@ fn main() { let mut stdin = stdin.lock(); let mut chain = create_chain(&cli); + //TODO: When chain is not empty (i.e. loaded,) it is okay for the input to be empty. (XXX: Should we *ignore* stdin for this? Empty stdin (CTRL+D on TTY) works. Do we actually need to change this behaviour? I don't think so.) (TODO: Add option to skip reading when loading file, maybe?) buffered_read_all_lines(&mut stdin, |string| { chain.feed(&string.split_whitespace() .filter(|word| !word.is_empty())