refactor(tvix/nix-compat): reorganize wire and bytes
Move everything bytes-related into its own module, and re-export both bytes and primitive in a flat space from wire/mod.rs. Expose this if a `wire` feature flag is set. We only have `async` stuff in here. Change-Id: Ia4ce4791f13a5759901cc9d6ce6bd6bbcca587c7 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11389 Autosubmit: flokli <flokli@flokli.de> Reviewed-by: raitobezarius <tvl@lahfa.xyz> Tested-by: BuildkiteCI Reviewed-by: Brian Olsen <me@griff.name>
This commit is contained in:
parent
839c971a0f
commit
36b296609b
12 changed files with 69 additions and 96 deletions
254
tvix/nix-compat/src/wire/bytes/mod.rs
Normal file
254
tvix/nix-compat/src/wire/bytes/mod.rs
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
use std::{
|
||||
io::{Error, ErrorKind},
|
||||
ops::RangeBounds,
|
||||
};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
mod reader;
|
||||
pub use reader::BytesReader;
|
||||
mod writer;
|
||||
pub use writer::BytesWriter;
|
||||
|
||||
use super::primitive;
|
||||
|
||||
/// 8 null bytes, used to write out padding.
|
||||
const EMPTY_BYTES: &[u8; 8] = &[0u8; 8];
|
||||
|
||||
/// The length of the size field, in bytes is always 8.
|
||||
const LEN_SIZE: usize = 8;
|
||||
|
||||
#[allow(dead_code)]
|
||||
/// Read a "bytes wire packet" from the AsyncRead.
|
||||
/// Rejects reading more than `allowed_size` bytes of payload.
|
||||
///
|
||||
/// The packet is made up of three parts:
|
||||
/// - a length header, u64, LE-encoded
|
||||
/// - the payload itself
|
||||
/// - null bytes to the next 8 byte boundary
|
||||
///
|
||||
/// Ensures the payload size fits into the `allowed_size` passed,
|
||||
/// and that the padding is actual null bytes.
|
||||
///
|
||||
/// On success, the returned `Vec<u8>` only contains the payload itself.
|
||||
/// On failure (for example if a too large byte packet was sent), the reader
|
||||
/// becomes unusable.
|
||||
///
|
||||
/// This buffers the entire payload into memory, a streaming version will be
|
||||
/// added later.
|
||||
pub async fn read_bytes<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<Vec<u8>>
|
||||
where
|
||||
R: AsyncReadExt + Unpin,
|
||||
S: RangeBounds<u64>,
|
||||
{
|
||||
// read the length field
|
||||
let len = primitive::read_u64(r).await?;
|
||||
|
||||
if !allowed_size.contains(&len) {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"signalled package size not in allowed range",
|
||||
));
|
||||
}
|
||||
|
||||
// calculate the total length, including padding.
|
||||
// byte packets are padded to 8 byte blocks each.
|
||||
let padded_len = padding_len(len) as u64 + (len as u64);
|
||||
let mut limited_reader = r.take(padded_len);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let s = limited_reader.read_to_end(&mut buf).await?;
|
||||
|
||||
// make sure we got exactly the number of bytes, and not less.
|
||||
if s as u64 != padded_len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"got less bytes than expected",
|
||||
));
|
||||
}
|
||||
|
||||
let (_content, padding) = buf.split_at(len as usize);
|
||||
|
||||
// ensure the padding is all zeroes.
|
||||
if !padding.iter().all(|e| *e == b'\0') {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"padding is not all zeroes",
|
||||
));
|
||||
}
|
||||
|
||||
// return the data without the padding
|
||||
buf.truncate(len as usize);
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
/// Read a "bytes wire packet" of from the AsyncRead and tries to parse as string.
|
||||
/// Internally uses [read_bytes].
|
||||
/// Rejects reading more than `allowed_size` bytes of payload.
|
||||
pub async fn read_string<R, S>(r: &mut R, allowed_size: S) -> std::io::Result<String>
|
||||
where
|
||||
R: AsyncReadExt + Unpin,
|
||||
S: RangeBounds<u64>,
|
||||
{
|
||||
let bytes = read_bytes(r, allowed_size).await?;
|
||||
String::from_utf8(bytes).map_err(|e| Error::new(ErrorKind::InvalidData, e))
|
||||
}
|
||||
|
||||
/// Writes a "bytes wire packet" to a (hopefully buffered) [AsyncWriteExt].
|
||||
///
|
||||
/// Accepts anything implementing AsRef<[u8]> as payload.
|
||||
///
|
||||
/// See [read_bytes] for a description of the format.
|
||||
///
|
||||
/// Note: if performance matters to you, make sure your
|
||||
/// [AsyncWriteExt] handle is buffered. This function is quite
|
||||
/// write-intesive.
|
||||
pub async fn write_bytes<W: AsyncWriteExt + Unpin, B: AsRef<[u8]>>(
|
||||
w: &mut W,
|
||||
b: B,
|
||||
) -> std::io::Result<()> {
|
||||
// write the size packet.
|
||||
primitive::write_u64(w, b.as_ref().len() as u64).await?;
|
||||
|
||||
// write the payload
|
||||
w.write_all(b.as_ref()).await?;
|
||||
|
||||
// write padding if needed
|
||||
let padding_len = padding_len(b.as_ref().len() as u64) as usize;
|
||||
if padding_len != 0 {
|
||||
w.write_all(&EMPTY_BYTES[..padding_len]).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Computes the number of bytes we should add to len (a length in
|
||||
/// bytes) to be alined on 64 bits (8 bytes).
|
||||
fn padding_len(len: u64) -> u8 {
|
||||
let modulo = len % 8;
|
||||
if modulo == 0 {
|
||||
0
|
||||
} else {
|
||||
8 - modulo as u8
|
||||
}
|
||||
}
|
||||
|
||||
/// Models the position inside a "bytes wire packet" that the reader or writer
|
||||
/// is in.
|
||||
/// It can be in three different stages, inside size, payload or padding fields.
|
||||
/// The number tracks the number of bytes written inside the specific field.
|
||||
/// There shall be no ambiguous states, at the end of a stage we immediately
|
||||
/// move to the beginning of the next one:
|
||||
/// - Size(LEN_SIZE) must be expressed as Payload(0)
|
||||
/// - Payload(self.payload_len) must be expressed as Padding(0)
|
||||
/// There's one exception - Size(LEN_SIZE) in the reader represents a failure
|
||||
/// state we enter in case the allowed size doesn't match the allowed range.
|
||||
///
|
||||
/// Padding(padding_len) means we're at the end of the bytes wire packet.
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
enum BytesPacketPosition {
|
||||
Size(usize),
|
||||
Payload(u64),
|
||||
Padding(usize),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio_test::{assert_ok, io::Builder};
|
||||
|
||||
use super::*;
|
||||
use hex_literal::hex;
|
||||
|
||||
/// The maximum length of bytes packets we're willing to accept in the test
|
||||
/// cases.
|
||||
const MAX_LEN: u64 = 1024;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_8_bytes() {
|
||||
let mut mock = Builder::new()
|
||||
.read(&8u64.to_le_bytes())
|
||||
.read(&12345678u64.to_le_bytes())
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
&12345678u64.to_le_bytes(),
|
||||
read_bytes(&mut mock, 0u64..MAX_LEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_9_bytes() {
|
||||
let mut mock = Builder::new()
|
||||
.read(&9u64.to_le_bytes())
|
||||
.read(&hex!("01020304050607080900000000000000"))
|
||||
.build();
|
||||
|
||||
assert_eq!(
|
||||
hex!("010203040506070809"),
|
||||
read_bytes(&mut mock, 0u64..MAX_LEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_0_bytes() {
|
||||
// A empty byte packet is essentially just the 0 length field.
|
||||
// No data is read, and there's zero padding.
|
||||
let mut mock = Builder::new().read(&0u64.to_le_bytes()).build();
|
||||
|
||||
assert_eq!(
|
||||
hex!(""),
|
||||
read_bytes(&mut mock, 0u64..MAX_LEN)
|
||||
.await
|
||||
.unwrap()
|
||||
.as_slice()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
/// Ensure we don't read any further than the size field if the length
|
||||
/// doesn't match the range we want to accept.
|
||||
async fn test_read_reject_too_large() {
|
||||
let mut mock = Builder::new().read(&100u64.to_le_bytes()).build();
|
||||
|
||||
read_bytes(&mut mock, 10..10)
|
||||
.await
|
||||
.expect_err("expect this to fail");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_bytes_no_padding() {
|
||||
let input = hex!("6478696f34657661");
|
||||
let len = input.len() as u64;
|
||||
let mut mock = Builder::new()
|
||||
.write(&len.to_le_bytes())
|
||||
.write(&input)
|
||||
.build();
|
||||
assert_ok!(write_bytes(&mut mock, &input).await)
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn test_write_bytes_with_padding() {
|
||||
let input = hex!("322e332e3137");
|
||||
let len = input.len() as u64;
|
||||
let mut mock = Builder::new()
|
||||
.write(&len.to_le_bytes())
|
||||
.write(&hex!("322e332e31370000"))
|
||||
.build();
|
||||
assert_ok!(write_bytes(&mut mock, &input).await)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_write_string() {
|
||||
let input = "Hello, World!";
|
||||
let len = input.len() as u64;
|
||||
let mut mock = Builder::new()
|
||||
.write(&len.to_le_bytes())
|
||||
.write(&hex!("48656c6c6f2c20576f726c6421000000"))
|
||||
.build();
|
||||
assert_ok!(write_bytes(&mut mock, &input).await)
|
||||
}
|
||||
}
|
||||
463
tvix/nix-compat/src/wire/bytes/reader.rs
Normal file
463
tvix/nix-compat/src/wire/bytes/reader.rs
Normal file
|
|
@ -0,0 +1,463 @@
|
|||
use pin_project_lite::pin_project;
|
||||
use std::{
|
||||
ops::RangeBounds,
|
||||
task::{ready, Poll},
|
||||
};
|
||||
use tokio::io::AsyncRead;
|
||||
|
||||
use super::{padding_len, BytesPacketPosition, LEN_SIZE};
|
||||
|
||||
pin_project! {
|
||||
/// Reads a "bytes wire packet" from the underlying reader.
|
||||
/// The format is the same as in [crate::wire::bytes::read_bytes],
|
||||
/// however this structure provides a [AsyncRead] interface,
|
||||
/// allowing to not having to pass around the entire payload in memory.
|
||||
///
|
||||
/// After being constructed with the underlying reader and an allowed size,
|
||||
/// subsequent requests to poll_read will return payload data until the end
|
||||
/// of the packet is reached.
|
||||
///
|
||||
/// Internally, it will first read over the size packet, filling payload_size,
|
||||
/// ensuring it fits allowed_size, then return payload data.
|
||||
/// It will only signal EOF (returning `Ok(())` without filling the buffer anymore)
|
||||
/// when all padding has been successfully consumed too.
|
||||
///
|
||||
/// This also means, it's important for a user to always read to the end,
|
||||
/// and not just call read_exact - otherwise it might not skip over the
|
||||
/// padding, and return garbage when reading the next packet.
|
||||
///
|
||||
/// In case of an error due to size constraints, or in case of not reading
|
||||
/// all the way to the end (and getting a EOF), the underlying reader is no
|
||||
/// longer usable and might return garbage.
|
||||
pub struct BytesReader<R, S>
|
||||
where
|
||||
R: AsyncRead,
|
||||
S: RangeBounds<u64>,
|
||||
|
||||
{
|
||||
#[pin]
|
||||
inner: R,
|
||||
|
||||
allowed_size: S,
|
||||
payload_size: [u8; 8],
|
||||
state: BytesPacketPosition,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, S> BytesReader<R, S>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
S: RangeBounds<u64>,
|
||||
{
|
||||
/// Constructs a new BytesReader, using the underlying passed reader.
|
||||
pub fn new(r: R, allowed_size: S) -> Self {
|
||||
Self {
|
||||
inner: r,
|
||||
allowed_size,
|
||||
payload_size: [0; 8],
|
||||
state: BytesPacketPosition::Size(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Returns an error if the passed usize is 0.
|
||||
fn ensure_nonzero_bytes_read(bytes_read: usize) -> Result<usize, std::io::Error> {
|
||||
if bytes_read == 0 {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"underlying reader returned EOF",
|
||||
))
|
||||
} else {
|
||||
Ok(bytes_read)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, S> AsyncRead for BytesReader<R, S>
|
||||
where
|
||||
R: AsyncRead,
|
||||
S: RangeBounds<u64>,
|
||||
{
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let mut this = self.project();
|
||||
|
||||
// Use a loop, so we can deal with (multiple) state transitions.
|
||||
loop {
|
||||
match *this.state {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => {
|
||||
// used in case an invalid size was signalled.
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"signalled package size not in allowed range",
|
||||
))?
|
||||
}
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
// try to read more of the size field.
|
||||
// We wrap a BufRead around this.payload_size here, and set_filled.
|
||||
let mut read_buf = tokio::io::ReadBuf::new(this.payload_size);
|
||||
read_buf.advance(pos);
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut read_buf))?;
|
||||
|
||||
ensure_nonzero_bytes_read(read_buf.filled().len() - pos)?;
|
||||
|
||||
let total_size_read = read_buf.filled().len();
|
||||
if total_size_read == LEN_SIZE {
|
||||
// If the entire payload size was read, parse it
|
||||
let payload_size = u64::from_le_bytes(*this.payload_size);
|
||||
|
||||
if !this.allowed_size.contains(&payload_size) {
|
||||
// If it's not in the allowed
|
||||
// range, transition to failure mode
|
||||
// `BytesPacketPosition::Size(LEN_SIZE)`, where only
|
||||
// an error is returned.
|
||||
*this.state = BytesPacketPosition::Size(LEN_SIZE)
|
||||
} else if payload_size == 0 {
|
||||
// If the payload size is 0, move on to reading padding directly.
|
||||
*this.state = BytesPacketPosition::Padding(0)
|
||||
} else {
|
||||
// Else, transition to reading the payload.
|
||||
*this.state = BytesPacketPosition::Payload(0)
|
||||
}
|
||||
} else {
|
||||
// If we still need to read more of payload size, update
|
||||
// our position in the state.
|
||||
*this.state = BytesPacketPosition::Size(total_size_read)
|
||||
}
|
||||
}
|
||||
BytesPacketPosition::Payload(pos) => {
|
||||
let signalled_size = u64::from_le_bytes(*this.payload_size);
|
||||
// We don't enter this match arm at all if we're expecting empty payload
|
||||
debug_assert!(signalled_size > 0, "signalled size must be larger than 0");
|
||||
|
||||
// Read from the underlying reader into buf
|
||||
// We cap the ReadBuf to the size of the payload, as we
|
||||
// don't want to leak padding to the caller.
|
||||
let bytes_read = ensure_nonzero_bytes_read({
|
||||
// Reducing these two u64 to usize on 32bits is fine - we
|
||||
// only care about not reading too much, not too less.
|
||||
let mut limited_buf = buf.take((signalled_size - pos) as usize);
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut limited_buf))?;
|
||||
limited_buf.filled().len()
|
||||
})?;
|
||||
|
||||
// SAFETY: we just did populate this, but through limited_buf.
|
||||
unsafe { buf.assume_init(bytes_read) }
|
||||
buf.advance(bytes_read);
|
||||
|
||||
if pos + bytes_read as u64 == signalled_size {
|
||||
// If we now read all payload, transition to padding
|
||||
// state.
|
||||
*this.state = BytesPacketPosition::Padding(0);
|
||||
} else {
|
||||
// if we didn't read everything yet, update our position
|
||||
// in the state.
|
||||
*this.state = BytesPacketPosition::Payload(pos + bytes_read as u64);
|
||||
}
|
||||
|
||||
// We return from poll_read here.
|
||||
// This is important, as any error (or even Pending) from
|
||||
// the underlying reader on the next read (be it padding or
|
||||
// payload) would require us to roll back buf, as generally
|
||||
// a AsyncRead::poll_read may not advance the buffer in case
|
||||
// of a nonsuccessful read.
|
||||
// It can't be misinterpreted as EOF, as we definitely *did*
|
||||
// write something into buf if we come to here (we pass
|
||||
// `ensure_nonzero_bytes_read`).
|
||||
return Ok(()).into();
|
||||
}
|
||||
BytesPacketPosition::Padding(pos) => {
|
||||
// Consume whatever padding is left, ensuring it's all null
|
||||
// bytes. Only return `Ready(Ok(()))` once we're past the
|
||||
// padding (or in cases where polling the inner reader
|
||||
// returns `Poll::Pending`).
|
||||
let signalled_size = u64::from_le_bytes(*this.payload_size);
|
||||
let total_padding_len = padding_len(signalled_size) as usize;
|
||||
|
||||
let padding_len_remaining = total_padding_len - pos;
|
||||
if padding_len_remaining != 0 {
|
||||
// create a buffer only accepting the number of remaining padding bytes.
|
||||
let mut buf = [0; 8];
|
||||
let mut padding_buf = tokio::io::ReadBuf::new(&mut buf);
|
||||
let mut padding_buf = padding_buf.take(padding_len_remaining);
|
||||
|
||||
// read into padding_buf.
|
||||
ready!(this.inner.as_mut().poll_read(cx, &mut padding_buf))?;
|
||||
let bytes_read = ensure_nonzero_bytes_read(padding_buf.filled().len())?;
|
||||
|
||||
*this.state = BytesPacketPosition::Padding(pos + bytes_read);
|
||||
|
||||
// ensure the bytes are not null bytes
|
||||
if !padding_buf.filled().iter().all(|e| *e == b'\0') {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"padding is not all zeroes",
|
||||
))
|
||||
.into();
|
||||
}
|
||||
|
||||
// if we still have padding to read, run the loop again.
|
||||
continue;
|
||||
}
|
||||
// return EOF
|
||||
return Ok(()).into();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::wire::bytes::write_bytes;
|
||||
use hex_literal::hex;
|
||||
use lazy_static::lazy_static;
|
||||
use rstest::rstest;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_test::{assert_err, io::Builder};
|
||||
|
||||
use super::*;
|
||||
|
||||
/// The maximum length of bytes packets we're willing to accept in the test
|
||||
/// cases.
|
||||
const MAX_LEN: u64 = 1024;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024);
|
||||
}
|
||||
|
||||
/// Helper function, calling the (simpler) write_bytes with the payload.
|
||||
/// We use this to create data we want to read from the wire.
|
||||
async fn produce_packet_bytes(payload: &[u8]) -> Vec<u8> {
|
||||
let mut exp = vec![];
|
||||
write_bytes(&mut exp, payload).await.unwrap();
|
||||
exp
|
||||
}
|
||||
|
||||
/// Read bytes packets of various length, and ensure read_to_end returns the
|
||||
/// expected payload.
|
||||
#[rstest]
|
||||
#[case::empty(&[])] // empty bytes packet
|
||||
#[case::size_1b(&[0xff])] // 1 bytes payload
|
||||
#[case::size_8b(&hex!("0001020304050607"))] // 8 bytes payload (no padding)
|
||||
#[case::size_9b( &hex!("000102030405060708"))] // 9 bytes payload (7 bytes padding)
|
||||
#[case::size_1m(LARGE_PAYLOAD.as_slice())] // larger bytes packet
|
||||
#[tokio::test]
|
||||
async fn read_payload_correct(#[case] payload: &[u8]) {
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
|
||||
let mut buf = Vec::new();
|
||||
r.read_to_end(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(payload, &buf[..]);
|
||||
}
|
||||
|
||||
/// Fail if the bytes packet is larger than allowed
|
||||
#[tokio::test]
|
||||
async fn read_bigger_than_allowed_fail() {
|
||||
let payload = LARGE_PAYLOAD.as_slice();
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..2048);
|
||||
let mut buf = Vec::new();
|
||||
assert_err!(r.read_to_end(&mut buf).await);
|
||||
}
|
||||
|
||||
/// Fail if the bytes packet is smaller than allowed
|
||||
#[tokio::test]
|
||||
async fn read_smaller_than_allowed_fail() {
|
||||
let payload = &[0x00, 0x01, 0x02];
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[0..8]) // We stop reading after the size packet
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, 1024..2048);
|
||||
let mut buf = Vec::new();
|
||||
assert_err!(r.read_to_end(&mut buf).await);
|
||||
}
|
||||
|
||||
/// Fail if the padding is not all zeroes
|
||||
#[tokio::test]
|
||||
async fn read_fail_if_nonzero_padding() {
|
||||
let payload = &[0x00, 0x01, 0x02];
|
||||
let mut packet_bytes = produce_packet_bytes(payload).await;
|
||||
// Flip some bits in the padding
|
||||
packet_bytes[12] = 0xff;
|
||||
let mut mock = Builder::new().read(&packet_bytes).build(); // We stop reading after the faulty bit
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = Vec::new();
|
||||
|
||||
r.read_to_end(&mut buf).await.expect_err("must fail");
|
||||
}
|
||||
|
||||
/// Start a 9 bytes payload packet, but have the underlying reader return
|
||||
/// EOF in the middle of the size packet (after 4 bytes).
|
||||
/// We should get an unexpected EOF error, already when trying to read the
|
||||
/// first byte (of payload)
|
||||
#[tokio::test]
|
||||
async fn read_9b_eof_during_size() {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..4])
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = [0u8; 1];
|
||||
|
||||
assert_eq!(
|
||||
r.read_exact(&mut buf).await.expect_err("must fail").kind(),
|
||||
std::io::ErrorKind::UnexpectedEof
|
||||
);
|
||||
|
||||
assert_eq!(&[0], &buf, "buffer should stay empty");
|
||||
}
|
||||
|
||||
/// Start a 9 bytes payload packet, but have the underlying reader return
|
||||
/// EOF in the middle of the payload (4 bytes into the payload).
|
||||
/// We should get an unexpected EOF error, after reading the first 4 bytes
|
||||
/// (successfully).
|
||||
#[tokio::test]
|
||||
async fn read_9b_eof_during_payload() {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..8 + 4])
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = [0; 9];
|
||||
|
||||
r.read_exact(&mut buf[..4]).await.expect("must succeed");
|
||||
|
||||
assert_eq!(
|
||||
r.read_exact(&mut buf[4..=4])
|
||||
.await
|
||||
.expect_err("must fail")
|
||||
.kind(),
|
||||
std::io::ErrorKind::UnexpectedEof
|
||||
);
|
||||
}
|
||||
|
||||
/// Start a 9 bytes payload packet, but return an error at various stages *after* the actual payload.
|
||||
/// read_exact with a 9 bytes buffer is expected to succeed, but any further
|
||||
/// read, as well as read_to_end are expected to fail.
|
||||
#[rstest]
|
||||
#[case::before_padding(8 + 9)]
|
||||
#[case::during_padding(8 + 9 + 2)]
|
||||
#[case::after_padding(8 + 9 + padding_len(9) as usize)]
|
||||
#[tokio::test]
|
||||
async fn read_9b_eof_after_payload(#[case] offset: usize) {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..offset])
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = [0; 9];
|
||||
|
||||
// read_exact of the payload will succeed, but a subsequent read will
|
||||
// return UnexpectedEof error.
|
||||
r.read_exact(&mut buf).await.expect("should succeed");
|
||||
assert_eq!(
|
||||
r.read_exact(&mut buf[4..=4])
|
||||
.await
|
||||
.expect_err("must fail")
|
||||
.kind(),
|
||||
std::io::ErrorKind::UnexpectedEof
|
||||
);
|
||||
|
||||
// read_to_end will fail.
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..8 + payload.len()])
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = Vec::new();
|
||||
assert_eq!(
|
||||
r.read_to_end(&mut buf).await.expect_err("must fail").kind(),
|
||||
std::io::ErrorKind::UnexpectedEof
|
||||
);
|
||||
}
|
||||
|
||||
/// Start a 9 bytes payload packet, but return an error after a certain position.
|
||||
/// Ensure that error is propagated.
|
||||
#[rstest]
|
||||
#[case::during_size(4)]
|
||||
#[case::before_payload(8)]
|
||||
#[case::during_payload(8 + 4)]
|
||||
#[case::before_padding(8 + 4)]
|
||||
#[case::during_padding(8 + 9 + 2)]
|
||||
#[tokio::test]
|
||||
async fn propagate_error_from_reader(#[case] offset: usize) {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..offset])
|
||||
.read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = Vec::new();
|
||||
|
||||
let err = r.read_to_end(&mut buf).await.expect_err("must fail");
|
||||
assert_eq!(
|
||||
err.kind(),
|
||||
std::io::ErrorKind::Other,
|
||||
"error kind must match"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
err.into_inner().unwrap().to_string(),
|
||||
"foo",
|
||||
"error payload must contain foo"
|
||||
);
|
||||
}
|
||||
|
||||
/// If there's an error right after the padding, we don't propagate it, as
|
||||
/// we're done reading. We just return EOF.
|
||||
#[tokio::test]
|
||||
async fn no_error_after_eof() {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await)
|
||||
.read_error(std::io::Error::new(std::io::ErrorKind::Other, "foo"))
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..MAX_LEN);
|
||||
let mut buf = Vec::new();
|
||||
|
||||
r.read_to_end(&mut buf).await.expect("must succeed");
|
||||
assert_eq!(buf.as_slice(), payload);
|
||||
}
|
||||
|
||||
/// Introduce various stalls in various places of the packet, to ensure we
|
||||
/// handle these cases properly, too.
|
||||
#[rstest]
|
||||
#[case::beginning(0)]
|
||||
#[case::before_payload(8)]
|
||||
#[case::during_payload(8 + 4)]
|
||||
#[case::before_padding(8 + 4)]
|
||||
#[case::during_padding(8 + 9 + 2)]
|
||||
#[tokio::test]
|
||||
async fn read_payload_correct_pending(#[case] offset: usize) {
|
||||
let payload = &hex!("FF0102030405060708");
|
||||
let mut mock = Builder::new()
|
||||
.read(&produce_packet_bytes(payload).await[..offset])
|
||||
.wait(Duration::from_nanos(0))
|
||||
.read(&produce_packet_bytes(payload).await[offset..])
|
||||
.build();
|
||||
|
||||
let mut r = BytesReader::new(&mut mock, ..=LARGE_PAYLOAD.len() as u64);
|
||||
let mut buf = Vec::new();
|
||||
r.read_to_end(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(payload, &buf[..]);
|
||||
}
|
||||
}
|
||||
521
tvix/nix-compat/src/wire/bytes/writer.rs
Normal file
521
tvix/nix-compat/src/wire/bytes/writer.rs
Normal file
|
|
@ -0,0 +1,521 @@
|
|||
use pin_project_lite::pin_project;
|
||||
use std::task::{ready, Poll};
|
||||
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
use super::{padding_len, BytesPacketPosition, EMPTY_BYTES, LEN_SIZE};
|
||||
|
||||
pin_project! {
|
||||
/// Writes a "bytes wire packet" to the underlying writer.
|
||||
/// The format is the same as in [crate::wire::bytes::write_bytes],
|
||||
/// however this structure provides a [AsyncWrite] interface,
|
||||
/// allowing to not having to pass around the entire payload in memory.
|
||||
///
|
||||
/// It internally takes care of writing (non-payload) framing (size and
|
||||
/// padding).
|
||||
///
|
||||
/// During construction, the expected payload size needs to be provided.
|
||||
///
|
||||
/// After writing the payload to it, the user MUST call flush (or shutdown),
|
||||
/// which will validate the written payload size to match, and write the
|
||||
/// necessary padding.
|
||||
///
|
||||
/// In case flush is not called at the end, invalid data might be sent
|
||||
/// silently.
|
||||
///
|
||||
/// The underlying writer returning `Ok(0)` is considered an EOF situation,
|
||||
/// which is stronger than the "typically means the underlying object is no
|
||||
/// longer able to accept bytes" interpretation from the docs. If such a
|
||||
/// situation occurs, an error is returned.
|
||||
///
|
||||
/// The struct holds three fields, the underlying writer, the (expected)
|
||||
/// payload length, and an enum, tracking the state.
|
||||
pub struct BytesWriter<W>
|
||||
where
|
||||
W: AsyncWrite,
|
||||
{
|
||||
#[pin]
|
||||
inner: W,
|
||||
payload_len: u64,
|
||||
state: BytesPacketPosition,
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> BytesWriter<W>
|
||||
where
|
||||
W: AsyncWrite,
|
||||
{
|
||||
/// Constructs a new BytesWriter, using the underlying passed writer.
|
||||
pub fn new(w: W, payload_len: u64) -> Self {
|
||||
Self {
|
||||
inner: w,
|
||||
payload_len,
|
||||
state: BytesPacketPosition::Size(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an error if the passed usize is 0.
|
||||
fn ensure_nonzero_bytes_written(bytes_written: usize) -> Result<usize, std::io::Error> {
|
||||
if bytes_written == 0 {
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::WriteZero,
|
||||
"underlying writer accepted 0 bytes",
|
||||
))
|
||||
} else {
|
||||
Ok(bytes_written)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> AsyncWrite for BytesWriter<W>
|
||||
where
|
||||
W: AsyncWrite,
|
||||
{
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
// Use a loop, so we can deal with (multiple) state transitions.
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
match *this.state {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
let size_field = &this.payload_len.to_le_bytes();
|
||||
|
||||
let bytes_written = ensure_nonzero_bytes_written(ready!(this
|
||||
.inner
|
||||
.as_mut()
|
||||
.poll_write(cx, &size_field[pos..]))?)?;
|
||||
|
||||
let new_pos = pos + bytes_written;
|
||||
if new_pos == LEN_SIZE {
|
||||
*this.state = BytesPacketPosition::Payload(0);
|
||||
} else {
|
||||
*this.state = BytesPacketPosition::Size(new_pos);
|
||||
}
|
||||
}
|
||||
BytesPacketPosition::Payload(pos) => {
|
||||
// Ensure we still have space for more payload
|
||||
if pos + (buf.len() as u64) > *this.payload_len {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"tried to write excess bytes",
|
||||
)));
|
||||
}
|
||||
let bytes_written = ready!(this.inner.as_mut().poll_write(cx, buf))?;
|
||||
ensure_nonzero_bytes_written(bytes_written)?;
|
||||
let new_pos = pos + (bytes_written as u64);
|
||||
if new_pos == *this.payload_len {
|
||||
*this.state = BytesPacketPosition::Padding(0)
|
||||
} else {
|
||||
*this.state = BytesPacketPosition::Payload(new_pos)
|
||||
}
|
||||
|
||||
return Poll::Ready(Ok(bytes_written));
|
||||
}
|
||||
// If we're already in padding state, there should be no more payload left to write!
|
||||
BytesPacketPosition::Padding(_pos) => {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"tried to write excess bytes",
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let mut this = self.project();
|
||||
|
||||
loop {
|
||||
match *this.state {
|
||||
BytesPacketPosition::Size(LEN_SIZE) => unreachable!(),
|
||||
BytesPacketPosition::Size(pos) => {
|
||||
// More bytes to write in the size field
|
||||
let size_field = &this.payload_len.to_le_bytes()[..];
|
||||
let bytes_written = ensure_nonzero_bytes_written(ready!(this
|
||||
.inner
|
||||
.as_mut()
|
||||
.poll_write(cx, &size_field[pos..]))?)?;
|
||||
let new_pos = pos + bytes_written;
|
||||
if new_pos == LEN_SIZE {
|
||||
// Size field written, now ready to receive payload
|
||||
*this.state = BytesPacketPosition::Payload(0);
|
||||
} else {
|
||||
*this.state = BytesPacketPosition::Size(new_pos);
|
||||
}
|
||||
}
|
||||
BytesPacketPosition::Payload(_pos) => {
|
||||
// If we're at position 0 and want to write 0 bytes of payload
|
||||
// in total, we can transition to padding.
|
||||
// Otherwise, break, as we're expecting more payload to
|
||||
// be written.
|
||||
if *this.payload_len == 0 {
|
||||
*this.state = BytesPacketPosition::Padding(0);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
BytesPacketPosition::Padding(pos) => {
|
||||
// Write remaining padding, if there is padding to write.
|
||||
let total_padding_len = padding_len(*this.payload_len) as usize;
|
||||
|
||||
if pos != total_padding_len {
|
||||
let bytes_written = ensure_nonzero_bytes_written(ready!(this
|
||||
.inner
|
||||
.as_mut()
|
||||
.poll_write(cx, &EMPTY_BYTES[pos..total_padding_len]))?)?;
|
||||
*this.state = BytesPacketPosition::Padding(pos + bytes_written);
|
||||
} else {
|
||||
// everything written, break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Flush the underlying writer.
|
||||
this.inner.as_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
// Call flush.
|
||||
ready!(self.as_mut().poll_flush(cx))?;
|
||||
|
||||
let this = self.project();
|
||||
|
||||
// After a flush, being inside the padding state, and at the end of the padding
|
||||
// is the only way to prevent a dirty shutdown.
|
||||
if let BytesPacketPosition::Padding(pos) = *this.state {
|
||||
let padding_len = padding_len(*this.payload_len) as usize;
|
||||
if padding_len == pos {
|
||||
// Shutdown the underlying writer
|
||||
return this.inner.poll_shutdown(cx);
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the underlying writer, bubbling up any errors.
|
||||
ready!(this.inner.poll_shutdown(cx))?;
|
||||
|
||||
// return an error about unclean shutdown
|
||||
Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::BrokenPipe,
|
||||
"unclean shutdown",
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::wire::bytes::write_bytes;
|
||||
use hex_literal::hex;
|
||||
use lazy_static::lazy_static;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio_test::{assert_err, assert_ok, io::Builder};
|
||||
|
||||
use super::*;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref LARGE_PAYLOAD: Vec<u8> = (0..255).collect::<Vec<u8>>().repeat(4 * 1024);
|
||||
}
|
||||
|
||||
/// Helper function, calling the (simpler) write_bytes with the payload.
|
||||
/// We use this to create data we want to see on the wire.
|
||||
async fn produce_exp_bytes(payload: &[u8]) -> Vec<u8> {
|
||||
let mut exp = vec![];
|
||||
write_bytes(&mut exp, payload).await.unwrap();
|
||||
exp
|
||||
}
|
||||
|
||||
/// Write an empty bytes packet.
|
||||
#[tokio::test]
|
||||
async fn write_empty() {
|
||||
let payload = &[];
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, 0);
|
||||
assert_ok!(w.write_all(&[]).await, "write all data");
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Write an empty bytes packet, not calling write.
|
||||
#[tokio::test]
|
||||
async fn write_empty_only_flush() {
|
||||
let payload = &[];
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, 0);
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Write an empty bytes packet, not calling write or flush, only shutdown.
|
||||
#[tokio::test]
|
||||
async fn write_empty_only_shutdown() {
|
||||
let payload = &[];
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, 0);
|
||||
assert_ok!(w.shutdown().await, "shutdown");
|
||||
}
|
||||
|
||||
/// Write a 1 bytes packet
|
||||
#[tokio::test]
|
||||
async fn write_1b() {
|
||||
let payload = &[0xff];
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
assert_ok!(w.write_all(payload).await);
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Write a 8 bytes payload (no padding)
|
||||
#[tokio::test]
|
||||
async fn write_8b() {
|
||||
let payload = &hex!("0001020304050607");
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
assert_ok!(w.write_all(payload).await);
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Write a 9 bytes payload (7 bytes padding)
|
||||
#[tokio::test]
|
||||
async fn write_9b() {
|
||||
let payload = &hex!("000102030405060708");
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&produce_exp_bytes(payload).await)
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
assert_ok!(w.write_all(payload).await);
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Write a 9 bytes packet very granularly, with a lot of flushing in between,
|
||||
/// and a shutdown at the end.
|
||||
#[tokio::test]
|
||||
async fn write_9b_flush() {
|
||||
let payload = &hex!("000102030405060708");
|
||||
let exp_bytes = produce_exp_bytes(payload).await;
|
||||
|
||||
let mut mock = Builder::new().write(&exp_bytes).build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
assert_ok!(w.write_all(&payload[..4]).await);
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
// empty write, cause why not
|
||||
assert_ok!(w.write_all(&[]).await);
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
assert_ok!(w.write_all(&payload[4..]).await);
|
||||
assert_ok!(w.flush().await);
|
||||
assert_ok!(w.shutdown().await);
|
||||
}
|
||||
|
||||
/// Write a 9 bytes packet, but cause the sink to only accept half of the
|
||||
/// padding, ensuring we correctly write (only) the rest of the padding later.
|
||||
/// We write another 2 bytes of "bait", where a faulty implementation (pre
|
||||
/// cl/11384) would put too many null bytes.
|
||||
#[tokio::test]
|
||||
async fn write_9b_write_padding_2steps() {
|
||||
let payload = &hex!("000102030405060708");
|
||||
let exp_bytes = produce_exp_bytes(payload).await;
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&exp_bytes[0..8]) // size
|
||||
.write(&exp_bytes[8..17]) // payload
|
||||
.write(&exp_bytes[17..19]) // padding (2 of 7 bytes)
|
||||
// insert a wait to prevent Mock from merging the two writes into one
|
||||
.wait(Duration::from_nanos(1))
|
||||
.write(&hex!("0000000000ffff")) // padding (5 of 7 bytes, plus 2 bytes of "bait")
|
||||
.build();
|
||||
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
assert_ok!(w.write_all(&payload[..]).await);
|
||||
assert_ok!(w.flush().await);
|
||||
// Write bait
|
||||
assert_ok!(mock.write_all(&hex!("ffff")).await);
|
||||
}
|
||||
|
||||
/// Write a larger bytes packet
|
||||
#[tokio::test]
|
||||
async fn write_1m() {
|
||||
let payload = LARGE_PAYLOAD.as_slice();
|
||||
let exp_bytes = produce_exp_bytes(payload).await;
|
||||
|
||||
let mut mock = Builder::new().write(&exp_bytes).build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
assert_ok!(w.write_all(payload).await);
|
||||
assert_ok!(w.flush().await, "flush");
|
||||
}
|
||||
|
||||
/// Not calling flush at the end, but shutdown is also ok if we wrote all
|
||||
/// bytes we promised to write (as shutdown implies flush)
|
||||
#[tokio::test]
|
||||
async fn write_shutdown_without_flush_end() {
|
||||
let payload = &[0xf0, 0xff];
|
||||
let exp_bytes = produce_exp_bytes(payload).await;
|
||||
|
||||
let mut mock = Builder::new().write(&exp_bytes).build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
// call flush to write the size field
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
// write payload
|
||||
assert_ok!(w.write_all(payload).await);
|
||||
|
||||
// call shutdown
|
||||
assert_ok!(w.shutdown().await);
|
||||
}
|
||||
|
||||
/// Writing more bytes than previously signalled should fail.
|
||||
#[tokio::test]
|
||||
async fn write_more_than_signalled_fail() {
|
||||
let mut buf = Vec::new();
|
||||
let mut w = BytesWriter::new(&mut buf, 2);
|
||||
|
||||
assert_err!(w.write_all(&hex!("000102")).await);
|
||||
}
|
||||
/// Writing more bytes than previously signalled, but in two parts
|
||||
#[tokio::test]
|
||||
async fn write_more_than_signalled_split_fail() {
|
||||
let mut buf = Vec::new();
|
||||
let mut w = BytesWriter::new(&mut buf, 2);
|
||||
|
||||
// write two bytes
|
||||
assert_ok!(w.write_all(&hex!("0001")).await);
|
||||
|
||||
// write the excess byte.
|
||||
assert_err!(w.write_all(&hex!("02")).await);
|
||||
}
|
||||
|
||||
/// Writing more bytes than previously signalled, but flushing after the
|
||||
/// signalled amount should fail.
|
||||
#[tokio::test]
|
||||
async fn write_more_than_signalled_flush_fail() {
|
||||
let mut buf = Vec::new();
|
||||
let mut w = BytesWriter::new(&mut buf, 2);
|
||||
|
||||
// write two bytes, then flush
|
||||
assert_ok!(w.write_all(&hex!("0001")).await);
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
// write the excess byte.
|
||||
assert_err!(w.write_all(&hex!("02")).await);
|
||||
}
|
||||
|
||||
/// Calling shutdown while not having written all bytes that were promised
|
||||
/// returns an error.
|
||||
/// Note there's still cases of silent corruption if the user doesn't call
|
||||
/// shutdown explicitly (only drops).
|
||||
#[tokio::test]
|
||||
async fn premature_shutdown() {
|
||||
let payload = &[0xf0, 0xff];
|
||||
let mut buf = Vec::new();
|
||||
let mut w = BytesWriter::new(&mut buf, payload.len() as u64);
|
||||
|
||||
// call flush to write the size field
|
||||
assert_ok!(w.flush().await);
|
||||
|
||||
// write half of the payload (!)
|
||||
assert_ok!(w.write_all(&payload[0..1]).await);
|
||||
|
||||
// call shutdown, ensure it fails
|
||||
assert_err!(w.shutdown().await);
|
||||
}
|
||||
|
||||
/// Write to a Writer that fails to write during the size packet (after 4 bytes).
|
||||
/// Ensure this error gets propagated on the first call to write.
|
||||
#[tokio::test]
|
||||
async fn inner_writer_fail_during_size_firstwrite() {
|
||||
let payload = &[0xf0];
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&1u32.to_le_bytes())
|
||||
.write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
|
||||
.build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
assert_err!(w.write_all(payload).await);
|
||||
}
|
||||
|
||||
/// Write to a Writer that fails to write during the size packet (after 4 bytes).
|
||||
/// Ensure this error gets propagated during an initial flush
|
||||
#[tokio::test]
|
||||
async fn inner_writer_fail_during_size_initial_flush() {
|
||||
let payload = &[0xf0];
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&1u32.to_le_bytes())
|
||||
.write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
|
||||
.build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
assert_err!(w.flush().await);
|
||||
}
|
||||
|
||||
/// Write to a writer that fails to write during the payload (after 9 bytes).
|
||||
/// Ensure this error gets propagated when we're writing this byte.
|
||||
#[tokio::test]
|
||||
async fn inner_writer_fail_during_write() {
|
||||
let payload = &hex!("f0ff");
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&2u64.to_le_bytes())
|
||||
.write(&hex!("f0"))
|
||||
.write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
|
||||
.build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
assert_ok!(w.write(&hex!("f0")).await);
|
||||
assert_err!(w.write(&hex!("ff")).await);
|
||||
}
|
||||
|
||||
/// Write to a writer that fails to write during the padding (after 10 bytes).
|
||||
/// Ensure this error gets propagated during a flush.
|
||||
#[tokio::test]
|
||||
async fn inner_writer_fail_during_padding_flush() {
|
||||
let payload = &hex!("f0");
|
||||
|
||||
let mut mock = Builder::new()
|
||||
.write(&1u64.to_le_bytes())
|
||||
.write(&hex!("f0"))
|
||||
.write(&hex!("00"))
|
||||
.write_error(std::io::Error::new(std::io::ErrorKind::Other, "🍿"))
|
||||
.build();
|
||||
let mut w = BytesWriter::new(&mut mock, payload.len() as u64);
|
||||
|
||||
assert_ok!(w.write(&hex!("f0")).await);
|
||||
assert_err!(w.flush().await);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue