Compare commits
16 Commits
master
...
sock-buffe
Author | SHA1 | Date |
---|---|---|
Avril | c3f678a81e | 3 years ago |
Avril | d82c46b12d | 3 years ago |
Avril | ed69a8f187 | 3 years ago |
Avril | 019bdee5c1 | 3 years ago |
Avril | fc2b10a306 | 3 years ago |
Avril | 5d5748b5ea | 3 years ago |
Avril | 7cad244c16 | 3 years ago |
Avril | 56306fae83 | 3 years ago |
Avril | c96d098441 | 3 years ago |
Avril | 9d927c548a | 3 years ago |
Avril | a6a25259b8 | 3 years ago |
Avril | 818659b83c | 3 years ago |
Avril | 19d1db35d6 | 3 years ago |
Avril | 90c9fce20c | 3 years ago |
Avril | 6f8d367080 | 3 years ago |
Avril | 69d546d2d1 | 3 years ago |
@ -0,0 +1,371 @@
|
|||||||
|
//! Stream buffering for sync of encrypted socked.
|
||||||
|
use super::*;
|
||||||
|
use smallvec::SmallVec;
|
||||||
|
use std::io;
|
||||||
|
use std::{
|
||||||
|
task::{
|
||||||
|
Poll, Context,
|
||||||
|
},
|
||||||
|
pin::Pin,
|
||||||
|
};
|
||||||
|
use bytes::{
|
||||||
|
Buf, BufMut,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A wrapping buffer over a writer and/or reader.
|
||||||
|
#[pin_project]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Buffered<T: ?Sized, const SIZE: usize>
|
||||||
|
{
|
||||||
|
/// Current internal buffer
|
||||||
|
/// When it's full to `SIZE`, it should be written to `stream` at once then cleared when it's been written.
|
||||||
|
buffer: SmallVec<[u8; SIZE]>, //TODO: Can we have a non-spilling stack vec?
|
||||||
|
pending: usize, w: usize,
|
||||||
|
#[pin] stream: T
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const SIZE: usize> Buffered<T, SIZE>
|
||||||
|
where [(); SIZE]: Sized, // This isn't checked?
|
||||||
|
{
|
||||||
|
/// Create a new staticly sized buffer wrapper around this stream
|
||||||
|
pub fn new(stream: T) -> Self
|
||||||
|
{
|
||||||
|
assert!(SIZE > 0, "Size of buffer cannot be 0");
|
||||||
|
Self {
|
||||||
|
buffer: SmallVec::new(),
|
||||||
|
pending: 0, w: 0,
|
||||||
|
stream,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Consume into the wrapped stream
|
||||||
|
pub fn into_inner(self) -> T
|
||||||
|
{
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
/// The inner stream
|
||||||
|
pub fn inner(&self) -> &T
|
||||||
|
{
|
||||||
|
&self.stream
|
||||||
|
}
|
||||||
|
/// A mutable reference to the backing stream
|
||||||
|
pub fn inner_mut(&mut self) -> &mut T
|
||||||
|
{
|
||||||
|
&mut self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The current buffer bytes.
|
||||||
|
pub fn current_buffer(&self) -> &[u8]
|
||||||
|
{
|
||||||
|
&self.buffer[..]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Is the current internal buffer empty?
|
||||||
|
///
|
||||||
|
/// You can flush a partially-filled buffer to the backing stream of a writer with `.flush().await`.
|
||||||
|
pub fn is_empty(&self) -> bool
|
||||||
|
{
|
||||||
|
self.buffer.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline] fn div_mod<V>(a: V, b: V) -> (V, <V as std::ops::Div>::Output, <V as std::ops::Rem>::Output)
|
||||||
|
where V: std::ops::Div + std::ops::Rem + Clone
|
||||||
|
{
|
||||||
|
(a.clone(), a.clone() / b.clone(), a % b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX: I don't think writing futures like this is safe. Expand the inline `async{}`s into actual polling.
|
||||||
|
impl<T: ?Sized + Unpin, const SIZE: usize> AsyncWrite for Buffered<T, SIZE>
|
||||||
|
where T: AsyncWrite
|
||||||
|
{
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||||
|
// TODO: Don't write poll methods like this ffs... Write it properly.
|
||||||
|
let this = self.get_mut();
|
||||||
|
let fut = async {
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
let mut written=0;
|
||||||
|
let mut err = None;
|
||||||
|
this.buffer.extend_from_slice(buf);
|
||||||
|
|
||||||
|
for chunk in this.buffer.chunks_exact(SIZE)
|
||||||
|
{
|
||||||
|
if cfg!(test) {
|
||||||
|
eprintln!("Pushing chunk: {:?}", chunk);
|
||||||
|
}
|
||||||
|
match this.stream.write_all(&chunk).await {
|
||||||
|
Ok(()) => {
|
||||||
|
written += chunk.len();
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
err = Some(e);
|
||||||
|
break;
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.buffer.drain(0..written);
|
||||||
|
if let Some(err) = err {
|
||||||
|
Err(err)
|
||||||
|
} else {
|
||||||
|
Ok(buf.len())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tokio::pin!(fut);
|
||||||
|
fut.poll(cx)
|
||||||
|
|
||||||
|
}
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
let fut = async {
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
let wres = if this.buffer.len() > 0 {
|
||||||
|
if cfg!(test) {
|
||||||
|
eprintln!("Pushing rest: {:?}", &this.buffer[..]);
|
||||||
|
}
|
||||||
|
let res = this.stream.write_all(&this.buffer[..]).await;
|
||||||
|
this.buffer.clear();
|
||||||
|
res
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
};
|
||||||
|
this.stream.flush().await?;
|
||||||
|
wres
|
||||||
|
};
|
||||||
|
tokio::pin!(fut);
|
||||||
|
fut.poll(cx)
|
||||||
|
}
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
let fut = async {
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
this.flush().await?;
|
||||||
|
this.stream.shutdown().await
|
||||||
|
};
|
||||||
|
tokio::pin!(fut);
|
||||||
|
fut.poll(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_release_buf<const SIZE: usize>(buffer: &mut SmallVec<[u8; SIZE]>, into: &mut [u8]) -> (bool, usize)
|
||||||
|
{
|
||||||
|
let sz = std::cmp::min(buffer.len(), into.len());
|
||||||
|
(&mut into[..sz]).copy_from_slice(&buffer[..sz]);
|
||||||
|
drop(buffer.drain(..sz));
|
||||||
|
(!buffer.is_empty(), sz)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead + Unpin + ?Sized, const SIZE: usize> Buffered<T, SIZE>
|
||||||
|
{
|
||||||
|
async fn fill_buffer(&mut self) -> io::Result<bool>
|
||||||
|
{
|
||||||
|
let sz = self.buffer.len();
|
||||||
|
Ok(if sz != SIZE { // < SIZE
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
// XXXX::: I think the issue is, this function comes before the await point, meaning it is ran twice after the first poll? I have no fucking idea. I hate this... I just want a god damn buffered stream. WHY IS THIS SO CANCEROUS.
|
||||||
|
self.buffer.resize(SIZE, 0);
|
||||||
|
|
||||||
|
let done = {
|
||||||
|
let mut r=0;
|
||||||
|
let mut done =sz;
|
||||||
|
while done < SIZE && {r = self.stream.read(&mut self.buffer[done..]).await?; r > 0} {
|
||||||
|
done += r;
|
||||||
|
}
|
||||||
|
done
|
||||||
|
};
|
||||||
|
println!("Filling buffer to {}", done);
|
||||||
|
if done == SIZE {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
self.buffer.resize(done, 0);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else { // == SIZE
|
||||||
|
debug_assert!(sz == SIZE);
|
||||||
|
true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_take_buffer<B: ?Sized + BufMut>(&mut self, to: &mut B) -> usize
|
||||||
|
{
|
||||||
|
if self.buffer.is_empty() {
|
||||||
|
println!("Buffer empty, skipping take");
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
let copy = std::cmp::min(self.buffer.len(), to.remaining_mut());
|
||||||
|
println!("Draining {} bytes into output", copy);
|
||||||
|
|
||||||
|
to.put_slice(&self.buffer[..copy]);
|
||||||
|
self.buffer.drain(..copy);
|
||||||
|
copy
|
||||||
|
}
|
||||||
|
|
||||||
|
// async-based impl of `read`. there as a reference for when we find out how to write `poll_read`. Sigh...
|
||||||
|
async fn read_test(&mut self, buf: &mut [u8]) -> io::Result<usize>
|
||||||
|
{
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
let mut w = 0;
|
||||||
|
|
||||||
|
while w < buf.len() {
|
||||||
|
match self.try_take_buffer(&mut &mut buf[w..]) {
|
||||||
|
0 => {
|
||||||
|
if !self.fill_buffer().await?
|
||||||
|
&& self.buffer.is_empty()
|
||||||
|
{
|
||||||
|
println!("Buffer empty");
|
||||||
|
break;
|
||||||
|
} else {
|
||||||
|
println!("Buffer filled");
|
||||||
|
}
|
||||||
|
},
|
||||||
|
x => w+=x,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("Done: {}", w);
|
||||||
|
Result::<usize, io::Error>::Ok(w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// XXX: I don't think writing futures like this is safe. Expand the inline `async{}`s into actual polling.
|
||||||
|
impl<T: ?Sized + Unpin, const SIZE: usize> AsyncRead for Buffered<T, SIZE>
|
||||||
|
where T: AsyncRead
|
||||||
|
{
|
||||||
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
|
||||||
|
|
||||||
|
|
||||||
|
let this = self.get_mut();
|
||||||
|
let res = loop {
|
||||||
|
|
||||||
|
let read = if this.buffer.len() < SIZE || this.pending > 0
|
||||||
|
{
|
||||||
|
let st = if this.pending > 0 {this.pending-1} else { this.buffer.len() };
|
||||||
|
this.buffer.resize(SIZE, 0);
|
||||||
|
|
||||||
|
let mut done=st;
|
||||||
|
let mut r=0;
|
||||||
|
//XXX: Same issue even trying to save buffer length state over Pendings... Wtf is going on here?
|
||||||
|
macro_rules! ready {
|
||||||
|
(try $poll:expr) => {
|
||||||
|
match $poll {
|
||||||
|
Poll::Pending => {
|
||||||
|
this.pending = st+1;
|
||||||
|
//this.buffer.resize(done, 0);
|
||||||
|
return Poll::Pending;
|
||||||
|
},
|
||||||
|
Poll::Ready(Ok(x)) => x,
|
||||||
|
err => {
|
||||||
|
//this.buffer.resize(done, 0);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// XXX: V Same issue, runs the above code twice when re-polling after Pending. We need to make sure we jump back to this point in the code following a Pending poll to `stream.poll_read`, but I have no fucking clue how to do this? Eh...... We'll probably need to design the code differently. There is a lot of state that gets lost here and idk how to preserve it.... I hate this.
|
||||||
|
while done < SIZE && {r = ready!(try Pin::new(&mut this.stream).poll_read(cx, &mut this.buffer[done..])); r > 0}
|
||||||
|
{
|
||||||
|
done +=r;
|
||||||
|
}
|
||||||
|
this.pending = 0;
|
||||||
|
// This causes early eof (0)
|
||||||
|
//println!("Done: {}", done);
|
||||||
|
//this.buffer.resize(done, 0);
|
||||||
|
done
|
||||||
|
} else {
|
||||||
|
this.buffer.len()
|
||||||
|
};
|
||||||
|
match this.try_take_buffer(&mut &mut buf[this.w..]) {
|
||||||
|
0 => break Ok(this.w),
|
||||||
|
x => this.w+=x,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
this.w = 0;
|
||||||
|
Poll::Ready(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests
|
||||||
|
{
|
||||||
|
use super::*;
|
||||||
|
#[tokio::test]
|
||||||
|
async fn writer() -> eyre::Result<()>
|
||||||
|
{
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
let (tx, mut rx) = tokio::io::duplex(11);
|
||||||
|
let mut ttx = Buffered::<_, 4>::new(tx);
|
||||||
|
|
||||||
|
let back = tokio::spawn(async move {
|
||||||
|
|
||||||
|
println!("Writing bytes");
|
||||||
|
ttx.write_all(b"Hello world").await?;
|
||||||
|
|
||||||
|
println!("Waiting 1 second...");
|
||||||
|
tokio::time::delay_for(tokio::time::Duration::from_secs(1)).await;
|
||||||
|
println!("Flushing stream");
|
||||||
|
ttx.flush().await?;
|
||||||
|
ttx.shutdown().await?;
|
||||||
|
Result::<_, std::io::Error>::Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut output = Vec::new();
|
||||||
|
println!("Reading full stream...");
|
||||||
|
tokio::io::copy(&mut rx, &mut output).await?;
|
||||||
|
println!("Waiting for background...");
|
||||||
|
back.await.expect("Back panick")?;
|
||||||
|
|
||||||
|
println!("Expected {:?}, got {:?}", b"Hello world", &output);
|
||||||
|
assert_eq!(&output[..], b"Hello world");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reader() -> eyre::Result<()>
|
||||||
|
{
|
||||||
|
use tokio::prelude::*;
|
||||||
|
|
||||||
|
const DATA: &'static [u8] = b"Hello world";
|
||||||
|
|
||||||
|
let (mut tx, rx) = tokio::io::duplex(11);
|
||||||
|
let mut rx = Buffered::<_, 4>::new(rx);
|
||||||
|
|
||||||
|
let back = tokio::spawn(async move {
|
||||||
|
tx.write_all(DATA).await?;
|
||||||
|
tx.write_all(DATA).await?;
|
||||||
|
tx.flush().await?;
|
||||||
|
|
||||||
|
tx.shutdown().await?;
|
||||||
|
Result::<_, std::io::Error>::Ok(())
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut output = vec![0u8; DATA.len()*2];
|
||||||
|
// Bug found! Pinning and polling that stack future in `poll_read` does NOT work!
|
||||||
|
// (we unrolled the async function to a poll based one and we're STILL losing state.... FFS!)
|
||||||
|
// The exact same works as a real async function.
|
||||||
|
|
||||||
|
/*
|
||||||
|
rx.read(&mut output[..DATA.len()]).await?;
|
||||||
|
rx.read(&mut output[DATA.len()..]).await?;
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* THIS SHIT HANGS???????????????
|
||||||
|
tokio::io::copy(&mut rx, &mut output).await?;
|
||||||
|
*/
|
||||||
|
|
||||||
|
assert_eq!(rx.read(&mut output[..DATA.len()]).await?, DATA.len());
|
||||||
|
assert_eq!(rx.read(&mut output[DATA.len()..]).await?, DATA.len());
|
||||||
|
|
||||||
|
back.await.expect("Back panick")?;
|
||||||
|
|
||||||
|
eprintln!("String: {}", String::from_utf8_lossy(&output[..]));
|
||||||
|
assert_eq!(&output[..DATA.len()], &DATA[..]);
|
||||||
|
assert_eq!(&output[DATA.len()..], &DATA[..]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in new issue