Added proper file format for save/load of chain. Added internal ZSTD compression of chain stream.

Fortune for genmarkov's current commit: Blessing − 吉
cli
Avril 4 weeks ago
parent c4fc2fde1d
commit 066811444a
Signed by: flanchan
GPG Key ID: 284488987C31F630

83
Cargo.lock generated

@ -58,6 +58,25 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "bytes"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9"
dependencies = [
"serde",
]
[[package]]
name = "cc"
version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7"
dependencies = [
"jobserver",
"libc",
]
[[package]] [[package]]
name = "cfg-if" name = "cfg-if"
version = "0.1.10" version = "0.1.10"
@ -166,6 +185,12 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "hermit-abi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.6.0" version = "1.6.0"
@ -191,6 +216,15 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "jobserver"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.79" version = "0.2.79"
@ -205,12 +239,15 @@ checksum = "8dd5a6d5999d9907cda8ed67bbd137d3af8085216c2ac62de5be860bd41f304a"
[[package]] [[package]]
name = "markov" name = "markov"
version = "0.2.0" version = "0.2.1"
dependencies = [ dependencies = [
"bytes",
"clap", "clap",
"markov 1.1.0", "markov 1.1.0",
"num_cpus",
"serde", "serde",
"serde_cbor", "serde_cbor",
"zstd",
] ]
[[package]] [[package]]
@ -228,6 +265,16 @@ dependencies = [
"serde_yaml", "serde_yaml",
] ]
[[package]]
name = "num_cpus"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi",
"libc",
]
[[package]] [[package]]
name = "once_cell" name = "once_cell"
version = "1.20.3" version = "1.20.3"
@ -244,6 +291,12 @@ dependencies = [
"indexmap", "indexmap",
] ]
[[package]]
name = "pkg-config"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
[[package]] [[package]]
name = "ppv-lite86" name = "ppv-lite86"
version = "0.2.9" version = "0.2.9"
@ -473,3 +526,31 @@ checksum = "39f0c922f1a334134dc2f7a8b67dc5d25f0735263feec974345ff706bcf20b0d"
dependencies = [ dependencies = [
"linked-hash-map", "linked-hash-map",
] ]
[[package]]
name = "zstd"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
dependencies = [
"zstd-safe",
]
[[package]]
name = "zstd-safe"
version = "7.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
dependencies = [
"zstd-sys",
]
[[package]]
name = "zstd-sys"
version = "2.0.13+zstd.1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
dependencies = [
"cc",
"pkg-config",
]

@ -1,6 +1,6 @@
[package] [package]
name = "markov" name = "markov"
version = "0.2.0" version = "0.2.1"
description = "Generate string of text from Markov chain fed by stdin or file(s)" description = "Generate string of text from Markov chain fed by stdin or file(s)"
authors = ["Avril <flanchan@cumallover.me>"] authors = ["Avril <flanchan@cumallover.me>"]
edition = "2018" edition = "2018"
@ -12,9 +12,23 @@ opt-level = 3
lto = true lto = true
#"fat" #"fat"
codegen-units = 1 codegen-units = 1
strip = true
[profile.symbols]
inherits = "release"
lto = "fat"
strip = false
[features]
default = ["threads"]
threads = ["zstd/zstdmt", "dep:num_cpus"]
[dependencies] [dependencies]
bytes = { version = "1.10.0", features = ["serde"] }
chain = {package = "markov", version = "1.1.0" } chain = {package = "markov", version = "1.1.0" }
clap = { version = "4.5.29", features = ["derive"] } clap = { version = "4.5.29", features = ["derive"] }
num_cpus = { version = "1.16.0", optional = true }
serde = { version = "1.0.217", features = ["derive"] } serde = { version = "1.0.217", features = ["derive"] }
serde_cbor = { version = "0.11.2", features = ["alloc"] } serde_cbor = { version = "0.11.2", features = ["alloc"] }
zstd = { version = "0.13.2", features = [] }

@ -0,0 +1,379 @@
//! Handles the chain load/save format
use super::*;
use std::{
io::{
self,
Read, Write, BufRead,
},
fmt,
};
use bytes::{
Buf, BufMut, Bytes,
};
use zstd::{
Encoder, Decoder,
};
/// The chain that can be saved / loaded.
///
/// Alias for the crate-level `Chain` type; the element type parameter defaults to `String`.
pub type Chain<T = String> = crate::Chain<T>;
/// Four-component version tag of the encoded format stream.
///
/// Ordering is derived, so versions compare component by component, first to last.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Copy)]
#[repr(packed)]
pub struct Version(pub u8, pub u8, pub u8, pub u8);

impl fmt::Display for Version {
    /// Renders as `a.b.c`, with an `r<d>` suffix only when the fourth component is non-zero.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Version(major, minor, patch, revision) = *self;
        write!(f, "{}.{}.{}", major, minor, patch)?;
        match revision {
            0 => Ok(()),
            rev => write!(f, "r{}", rev),
        }
    }
}

impl Version {
    /// Current save version
    pub const CURRENT: Self = Version(0, 0, 0, 0);

    /// Current value as a native integer
    const CURRENT_VALUE: u32 = Self::CURRENT.as_native();

    /// Packs the four components into a `u32`, first component in the most significant byte.
    pub const fn as_native(&self) -> u32 {
        let Version(a, b, c, d) = *self;
        u32::from_be_bytes([a, b, c, d])
    }

    /// Inverse of [`Version::as_native`]: splits a `u32` into its four big-endian bytes.
    #[inline]
    pub const fn from_native(value: u32) -> Self {
        let bytes = value.to_be_bytes();
        Self(bytes[0], bytes[1], bytes[2], bytes[3])
    }
}

impl Default for Version {
    /// The default version is [`Version::CURRENT`].
    #[inline]
    fn default() -> Self {
        Version::CURRENT
    }
}
/// Types whose in-memory representation is used directly as their encoded form.
///
/// The default methods expose `self` as a raw byte slice of `size_of::<Self>()` bytes.
///
/// # Safety
/// NOTE(review): the exact contract is not written down in this module. Presumably an
/// implementor vouches that its bytes may be written out verbatim and that any bytes read
/// back form a valid value — confirm before adding implementations.
pub unsafe trait AutoBinaryFormat: Sized {
    /// Raw read-only byte view of `self`, used when encoding.
    #[inline]
    fn as_raw_for_encode(&self) -> *const [u8] {
        let base = (self as *const Self).cast::<u8>();
        std::ptr::slice_from_raw_parts(base, std::mem::size_of::<Self>())
    }

    /// Raw writable byte view of `self`, used when decoding.
    #[inline]
    fn as_raw_for_decode(&mut self) -> *mut [u8] {
        let base = (self as *mut Self).cast::<u8>();
        std::ptr::slice_from_raw_parts_mut(base, std::mem::size_of::<Self>())
    }

    /// Number of bytes consumed when decoding; defaults to `size_of::<Self>()`.
    fn raw_format_read_size(&mut self) -> usize {
        std::mem::size_of::<Self>()
    }

    /// Number of bytes produced when encoding; defaults to `size_of::<Self>()`.
    fn raw_format_write_size(&self) -> usize {
        std::mem::size_of::<Self>()
    }
}
// Fixed-size arrays are encoded as their raw bytes.
// NOTE(review): there is no bound on `T`, so this covers arrays of *every* element type, and the
// blanket `BinaryFormat` impl reads stream bytes straight into the value. Confirm this is intended
// for element types that are not plain byte-like data.
unsafe impl<T, const N: usize> AutoBinaryFormat for [T; N] {}
/// A value that can be read from / written to a byte stream in this module's save format.
pub trait BinaryFormat {
/// Read this value from `stream`, overwriting `self`.
///
/// The implementations in this module return the number of bytes consumed.
fn read_from<S: ?Sized>(&mut self, stream: &mut S) -> io::Result<usize>
where S: io::Read;
/// Write this value to `stream`.
///
/// The implementations in this module return the number of bytes written.
fn write_to<S: ?Sized>(&self, stream: &mut S) -> io::Result<usize>
where S: io::Write;
/// Size in bytes that `read_from` will consume.
///
/// NOTE(review): presumably `None` means the size is not known up front — confirm.
fn binary_format_read_size(&mut self) -> Option<usize>;
/// Size in bytes that `write_to` will produce.
fn binary_format_write_size(&self) -> usize;
}
impl<T> BinaryFormat for T
where T: AutoBinaryFormat
{
#[inline(always)]
fn read_from<S: ?Sized>(&mut self, stream: &mut S) -> io::Result<usize>
where S: io::Read {
let ptr = self.as_raw_for_decode();
// SAFETY: The read data is guaranteed to be valid here.
Ok(unsafe {
stream.read_exact(&mut *ptr)?;
(*ptr).len()
})
}
#[inline(always)]
fn write_to<S: ?Sized>(&self, stream: &mut S) -> io::Result<usize>
where S: io::Write {
let ptr = self.as_raw_for_encode();
// SAFETY: The written data is guaranteed to be valid here.
Ok(unsafe {
stream.write_all(&*ptr)?;
(*ptr).len()
})
}
#[inline]
fn binary_format_read_size(&mut self) -> Option<usize> {
Some(self.raw_format_read_size())
}
#[inline]
fn binary_format_write_size(&self) -> usize {
self.raw_format_write_size()
}
}
/// Slices are encoded element by element, in order, with no length prefix.
impl<T> BinaryFormat for [T]
where
    T: BinaryFormat,
{
    /// Reads every element in order, stopping at the first error; returns the summed byte count.
    #[inline]
    fn read_from<S: ?Sized>(&mut self, stream: &mut S) -> io::Result<usize>
    where
        S: io::Read,
    {
        self.iter_mut().try_fold(0usize, |total, item| {
            item.read_from(&mut *stream).map(|n| total + n)
        })
    }

    /// Writes every element in order, stopping at the first error; returns the summed byte count.
    #[inline]
    fn write_to<S: ?Sized>(&self, stream: &mut S) -> io::Result<usize>
    where
        S: io::Write,
    {
        self.iter().try_fold(0usize, |total, item| {
            item.write_to(&mut *stream).map(|n| total + n)
        })
    }

    /// Sum of the elements' read sizes, or `None` as soon as one element reports `None`.
    #[inline]
    fn binary_format_read_size(&mut self) -> Option<usize> {
        let mut total = 0;
        for item in self.iter_mut() {
            total += item.binary_format_read_size()?;
        }
        Some(total)
    }

    /// Sum of the elements' write sizes.
    #[inline]
    fn binary_format_write_size(&self) -> usize {
        self.iter()
            .fold(0, |total, item| total + item.binary_format_write_size())
    }
}
impl BinaryFormat for Version {
#[inline]
fn read_from<S: ?Sized>(&mut self, stream: &mut S) -> io::Result<usize>
where S: io::Read {
let mut vi = [0u8; 4];
stream.read_exact(&mut vi[..])?;
Ok(4)
}
#[inline]
fn write_to<S: ?Sized>(&self, stream: &mut S) -> io::Result<usize>
where S: io::Write {
let vi = [self.0,self.1,self.2,self.3];
stream.write_all(&vi[..])?;
Ok(4)
}
#[inline]
fn binary_format_read_size(&mut self) -> Option<usize> {
Some(std::mem::size_of::<u8>() * 4)
}
#[inline]
fn binary_format_write_size(&self) -> usize {
std::mem::size_of::<u8>() * 4
}
}
/// Compression scheme applied to the chain payload; stored in the header as a `u32` tag.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Copy, Default)]
#[repr(u32)]
pub enum Compressed {
    /// Payload is stored uncompressed.
    #[default]
    No = 0,
    /// Payload is a Zstandard stream.
    Zstd = 1,
}

impl Compressed {
    /// The numeric tag of this variant.
    #[inline]
    const fn to_int(&self) -> u32 {
        *self as u32
    }

    /// Maps a numeric tag back to its variant; `None` for any tag without a variant.
    #[inline]
    fn try_from_int(val: u32) -> Option<Self> {
        const NO: u32 = Compressed::No as u32;
        const ZSTD: u32 = Compressed::Zstd as u32;
        match val {
            NO => Some(Self::No),
            ZSTD => Some(Self::Zstd),
            _ => None,
        }
    }
}
/// Header record of a saved chain stream.
///
/// `save_chain_to_sync` writes it right after the magic number and `load_chain_from_sync`
/// reads it back from the same position.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct FormatMetadata {
/// Format version of the stream.
pub version: Version,
/// How the payload that follows the header is compressed.
pub compressed: Compressed,
/// Stored in the header as a `u64`; `save_chain_to_sync` leaves it at its default value.
pub chain_size: usize, // NOTE: Unused
/// Stored in the header as a `u64`; `save_chain_to_sync` leaves it at its default value.
pub checksum: u64, // NOTE: Unused
}
impl FormatMetadata {
    /// Eight-byte tag written at the very start of a saved chain stream and checked on load.
    // The lifetime is spelled out because eliding it in an associated constant triggers the
    // `elided_lifetimes_in_associated_constant` future-compatibility lint.
    const MAGIC_NUM: &'static [u8; 8] = b"MARKOV\x00\xcf";
}
impl BinaryFormat for FormatMetadata
{
fn write_to<S: ?Sized>(&self, mut stream: &mut S) -> io::Result<usize>
where S: io::Write {
let sz = self.version.write_to(&mut stream)?;
let mut obuf = [0u8; std::mem::size_of::<u32>() +std::mem::size_of::<u64>() + std::mem::size_of::<u64>()];
{
let mut obuf = &mut obuf[..];
use std::convert::TryInto;
obuf.put_u32(self.compressed.to_int());
obuf.put_u64(self.chain_size.try_into().map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "Chain size attribute out-of-bounds for format size"))?);
obuf.put_u64(self.checksum);
}
stream.write_all(&obuf[..])?;
Ok(sz + obuf.len())
}
fn read_from<S: ?Sized>(&mut self, mut stream: &mut S) -> io::Result<usize>
where S: io::Read {
let sz = self.version.read_from(&mut stream)?;
if self.version > Version::CURRENT {
return Err(io::Error::new(io::ErrorKind::Unsupported, format!("Unknown format version {}", self.version)));
}
let mut ibuf = [0u8; std::mem::size_of::<u32>() +std::mem::size_of::<u64>() + std::mem::size_of::<u64>()];
stream.read_exact(&mut ibuf[..])?;
{
let mut ibuf = &ibuf[..];
use std::convert::TryInto;
self.compressed = Compressed::try_from_int(ibuf.get_u32()).ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid compression attribute"))?;
self.chain_size = ibuf.get_u64().try_into().map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Chain size attribute out-of-bounds for native size"))?;
self.checksum = ibuf.get_u64();
}
Ok(sz + ibuf.len())
}
#[inline]
fn binary_format_read_size(&mut self) -> Option<usize> {
let szm = self.version.binary_format_read_size()?;
Some(szm + std::mem::size_of::<u32>()
+ std::mem::size_of::<u64>()
+ std::mem::size_of::<u64>())
}
#[inline(always)]
fn binary_format_write_size(&self) -> usize {
self.version.binary_format_write_size()
+ std::mem::size_of::<u32>()
+ std::mem::size_of::<u64>()
+ std::mem::size_of::<u64>()
}
}
/// Load a chain from a stream
///
/// Expects the magic number, then the [`FormatMetadata`] header, then the CBOR-encoded chain
/// (wrapped in a Zstandard stream when the header says so).
///
/// # Errors
/// * `InvalidData` — the magic number does not match, or the payload cannot be decoded.
/// * `Unsupported` — the header carries a version this build does not handle.
/// * Any I/O error raised by `stream` or by the decompressor.
#[inline]
pub fn load_chain_from_sync<S>(stream: &mut S) -> io::Result<Chain<String>>
where
    S: io::Read + ?Sized,
{
    let mut stream = io::BufReader::new(stream);

    // The buffer's initial contents are irrelevant: `read_exact` either overwrites all of it
    // or fails.
    let mut magic = *FormatMetadata::MAGIC_NUM;
    stream.read_exact(&mut magic[..])?;
    if &magic != FormatMetadata::MAGIC_NUM {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid file header tag magic number"));
    }

    let mut metadata = FormatMetadata::default();
    metadata.read_from(&mut stream)?;

    match metadata.version {
        Version::CURRENT => {
            // A corrupt or truncated payload is reported to the caller instead of panicking.
            let read = |read: &mut (dyn io::Read)| -> io::Result<Chain<String>> {
                serde_cbor::from_reader::<Chain<String>, _>(read).map_err(|err| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("Failed to read chain from input stream: {}", err),
                    )
                })
            };
            match metadata.compressed {
                Compressed::No => read(&mut stream),
                Compressed::Zstd => {
                    // NOTE: no `finish()` call is required on the decoding side.
                    let mut stream = zstd::Decoder::with_buffer(stream)?;
                    read(&mut stream)
                },
            }
        },
        unsup => Err(io::Error::new(io::ErrorKind::Unsupported, format!("Unsupported payload version {}", unsup))),
    }
}
/// Save a chain to a stream with optional compression.
///
/// Writes the magic number, then the [`FormatMetadata`] header, then the CBOR-encoded chain.
/// When `compress` is `true` the chain is wrapped in a Zstandard stream at level 22.
///
/// # Errors
/// Returns any I/O error raised by `stream` or by the compressor, and an `Other` error when
/// the chain cannot be serialised.
#[inline]
pub fn save_chain_to_sync<S>(stream: &mut S, chain: &Chain<String>, compress: bool) -> io::Result<()>
where
    S: io::Write + ?Sized,
{
    let mut stream = io::BufWriter::new(stream);

    let metadata = FormatMetadata {
        compressed: if compress { Compressed::Zstd } else { Compressed::No },
        ..Default::default()
    };

    stream.write_all(FormatMetadata::MAGIC_NUM)?;
    metadata.write_to(&mut stream)?;

    // A failed write (full disk, closed pipe, ...) is reported to the caller instead of panicking.
    let write = |stream: &mut (dyn io::Write)| -> io::Result<()> {
        serde_cbor::to_writer(stream, chain).map_err(|err| {
            io::Error::new(
                io::ErrorKind::Other,
                format!("Failed to write chain to output stream: {}", err),
            )
        })
    };

    let mut stream = match metadata.compressed {
        Compressed::No => {
            write(&mut stream)?;
            stream
        },
        Compressed::Zstd => {
            let mut stream = zstd::Encoder::new(stream, 22)?;
            #[cfg(feature="threads")]
            stream.multithread(num_cpus::get() as u32)?;
            write(&mut stream)?;
            // XXX: Should we flush after write here..?
            // NOTE: Required here — `finish()` hands back the underlying buffered writer.
            stream.finish()?
        },
    };
    stream.flush()
}
//TODO: Add `tokio_uring` version of `save_chain_to_file()`/`load_chain_from_file()` that spawns the `tokio_uring` runtime internally to queue reads/writes from/to a file.

@ -31,7 +31,7 @@ pub struct Cli {
force: bool, force: bool,
/// The number of lines to output from the chain. /// The number of lines to output from the chain.
lines: Option<NonZeroUsize>, lines: Option<NonZeroUsize>, // TODO: Allow 0 for only save + load operations.
//TODO: Num of lines, etc. //TODO: Num of lines, etc.
} }
@ -57,21 +57,18 @@ fn parse_cli() -> Cli {
Cli::parse() Cli::parse()
} }
mod format;
fn load_chain<S>(stream: &mut S) -> io::Result<Chain<String>> fn load_chain<S>(stream: &mut S) -> io::Result<Chain<String>>
where S: io::Read + ?Sized where S: io::Read + ?Sized
{ {
let mut stream = io::BufReader::new(stream); format::load_chain_from_sync(stream)
Ok(serde_cbor::from_reader(&mut stream).expect("Failed to read chain from input stream")) // TODO: Error type
} }
fn save_chain<S>(stream: &mut S, chain: &Chain<String>) -> io::Result<()> fn save_chain<S>(stream: &mut S, chain: &Chain<String>) -> io::Result<()>
where S: io::Write + ?Sized where S: io::Write + ?Sized
{ {
use io::Write; format::save_chain_to_sync(stream, chain, true) //TODO: Change compression to be off for small chains...? We will need to store the chain size info somewhere else.
let mut stream = io::BufWriter::new(stream);
serde_cbor::to_writer(&mut stream, chain).expect("Failed to write chain to output stream"); // TODO: Error type
stream.flush()
} }
fn create_chain(cli: &Cli) -> Chain<String> fn create_chain(cli: &Cli) -> Chain<String>
@ -109,6 +106,7 @@ fn main() {
let mut stdin = stdin.lock(); let mut stdin = stdin.lock();
let mut chain = create_chain(&cli); let mut chain = create_chain(&cli);
//TODO: When chain is not empty (i.e. loaded,) it is okay for the input to be empty. (XXX: Should we *ignore* stdin for this? Empty stdin (CTRL+D on TTY) works. Do we actually need to change this behaviour? I don't think so.) (TODO: Add option to skip reading when loading file, maybe?)
buffered_read_all_lines(&mut stdin, |string| { buffered_read_all_lines(&mut stdin, |string| {
chain.feed(&string.split_whitespace() chain.feed(&string.split_whitespace()
.filter(|word| !word.is_empty()) .filter(|word| !word.is_empty())

Loading…
Cancel
Save