diff --git a/snix/nix-compat/src/nix_daemon/framing/framed_read.rs b/snix/nix-compat/src/nix_daemon/framing/framed_read.rs index 309df82cf..af98f1881 100644 --- a/snix/nix-compat/src/nix_daemon/framing/framed_read.rs +++ b/snix/nix-compat/src/nix_daemon/framing/framed_read.rs @@ -1,45 +1,35 @@ use std::{ - io::Result, + num::NonZeroU64, pin::Pin, - task::{ready, Poll}, + task::{self, ready, Poll}, }; use pin_project_lite::pin_project; -use tokio::io::{AsyncRead, ReadBuf}; +use tokio::io::{self, AsyncRead, ReadBuf}; /// State machine for [`NixFramedReader`]. /// -/// As the reader progresses it linearly cycles through the states. -#[derive(Debug, PartialEq)] -enum NixFramedReaderState { - /// The reader always starts in this state. - /// - /// Before the payload, the client first sends its size. - /// The size is a u64 which is 8 bytes long, while it's likely that we will receive - /// the whole u64 in one read, it's possible that it will arrive in smaller chunks. - /// So in this state we read up to 8 bytes and transition to - /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero, - /// otherwise we reset filled to 0, and read the next size value. - ReadingSize { buf: [u8; 8], filled: usize }, - /// This is where we read the actual payload that is sent to us. - /// - /// Once we've read the expected number of bytes, we go back to the - /// [`NixFramedReaderState::ReadingSize`] state. - ReadingPayload { - /// Represents the remaining number of bytes we expect to read based on the value - /// read in the previous state. - remaining: u64, - }, +/// We read length-prefixed chunks until we receive a zero-sized payload indicating EOF. +/// Other than the zero-sized terminating chunk, chunk boundaries are not considered meaningful. +/// Lengths are 64-bit little endian values on the wire. +#[derive(Debug, Eq, PartialEq)] +enum State { + Length { buf: [u8; 8], filled: u8 }, + Chunk { remaining: NonZeroU64 }, + Eof, } pin_project! { - /// Implements Nix's Framed reader protocol for protocol versions >= 1.23. + /// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23. /// - /// See serialization.md#framed and [`NixFramedReaderState`] for details. + /// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`]. + /// True EOF (end-of-stream) is fused. + /// + /// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed pub struct NixFramedReader { #[pin] reader: R, - state: NixFramedReaderState, + state: State, } } @@ -47,102 +37,109 @@ impl NixFramedReader { pub fn new(reader: R) -> Self { Self { reader, - state: NixFramedReaderState::ReadingSize { + state: State::Length { buf: [0; 8], filled: 0, }, } } - #[cfg(test)] - fn is_eof(&self) -> bool { - self.state - == NixFramedReaderState::ReadingSize { - buf: [0; 8], - filled: 8, - } + /// Returns `true` if the Nix Framed reader has reached EOF. + #[must_use] + pub fn is_eof(&self) -> bool { + matches!(self.state, State::Eof) } } impl AsyncRead for NixFramedReader { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - read_buf: &mut ReadBuf<'_>, - ) -> Poll> { + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { let mut this = self.as_mut().project(); - match this.state { - NixFramedReaderState::ReadingSize { buf, filled } => { - if *filled < buf.len() { - let mut size_buf = ReadBuf::new(buf); - size_buf.advance(*filled); - ready!(this.reader.poll_read(cx, &mut size_buf))?; - let bytes_read = size_buf.filled().len() - *filled; + // reading nothing always succeeds + if buf.remaining() == 0 { + return Ok(()).into(); + } + + loop { + let reader = this.reader.as_mut(); + match this.state { + State::Eof => { + return Ok(()).into(); + } + State::Length { buf, filled } => { + let bytes_read = { + let mut b = ReadBuf::new(&mut buf[*filled as usize..]); + ready!(reader.poll_read(cx, &mut b))?; + b.filled().len() as u8 + }; + if bytes_read == 0 { - // oef - return Poll::Ready(Ok(())); + return Err(io::ErrorKind::UnexpectedEof.into()).into(); } + *filled += bytes_read; - // Schedule ourselves to run again. - return self.poll_read(cx, read_buf); - } - let size = u64::from_le_bytes(*buf); - if size == 0 { - // eof - *filled = 0; - return Poll::Ready(Ok(())); - } - *this.state = NixFramedReaderState::ReadingPayload { remaining: size }; - self.poll_read(cx, read_buf) - } - NixFramedReaderState::ReadingPayload { remaining } => { - // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms. - let safe_remaining = if *remaining <= usize::MAX as u64 { - *remaining as usize - } else { - usize::MAX - }; - if safe_remaining > 0 { - // The buffer is no larger than the amount of data that we expect. - // Otherwise we will trim the buffer below and come back here. - if read_buf.remaining() <= safe_remaining { - let filled_before = read_buf.filled().len(); - ready!(this.reader.as_mut().poll_read(cx, read_buf))?; - let bytes_read = read_buf.filled().len() - filled_before; - - *remaining -= bytes_read as u64; - if *remaining == 0 { - *this.state = NixFramedReaderState::ReadingSize { - buf: [0; 8], - filled: 0, - }; - } - return Poll::Ready(Ok(())); + if *filled == 8 { + *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) { + None => State::Eof, + Some(remaining) => State::Chunk { remaining }, + }; } - // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes - // internal bookkeeping simpler. - let mut smaller_buf = read_buf.take(safe_remaining); - ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?; - - let bytes_read = smaller_buf.filled().len(); - - // SAFETY: we just read this number of bytes into read_buf's backing slice above. - unsafe { read_buf.assume_init(bytes_read) }; - read_buf.advance(bytes_read); - return Poll::Ready(Ok(())); } - *this.state = NixFramedReaderState::ReadingSize { - buf: [0; 8], - filled: 0, - }; - self.poll_read(cx, read_buf) + State::Chunk { remaining } => { + let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| { + reader.poll_read(cx, buf).map_ok(|()| buf.filled().len()) + }))?; + + *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) { + None => State::Length { + buf: [0; 8], + filled: 0, + }, + Some(remaining) => State::Chunk { remaining }, + }; + + return if bytes_read == 0 { + Err(io::ErrorKind::UnexpectedEof.into()) + } else { + Ok(()) + } + .into(); + } } } } } +/// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it. +/// After `f` returns, we propagate the filled cursor advancement back to `buf`. +// TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited +fn with_limited(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R { + let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX)); + let ptr = nbuf.initialized().as_ptr(); + let ret = f(&mut nbuf); + + // SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`, + // so anything filled is new, initialized data. + // + // We verify that `nbuf` still points to the same buffer, + // so we're sure it hasn't been swapped out. + unsafe { + // ensure our buffer hasn't been swapped out + assert_eq!(nbuf.initialized().as_ptr(), ptr); + + let n = nbuf.filled().len(); + buf.assume_init(n); + buf.advance(n); + } + + ret +} + #[cfg(test)] mod nix_framed_tests { use std::{ @@ -158,7 +155,6 @@ mod nix_framed_tests { use crate::nix_daemon::framing::NixFramedReader; #[tokio::test] - #[should_panic] // broken async fn read_unexpected_eof_after_frame() { let mut mock = Builder::new() // The client sends len @@ -179,7 +175,6 @@ mod nix_framed_tests { } #[tokio::test] - #[should_panic] // broken async fn read_unexpected_eof_in_frame() { let mut mock = Builder::new() // The client sends len @@ -200,7 +195,6 @@ mod nix_framed_tests { } #[tokio::test] - #[should_panic] // broken async fn read_unexpected_eof_in_length() { let mut mock = Builder::new() // The client sends len @@ -219,7 +213,6 @@ mod nix_framed_tests { } #[tokio::test] - #[should_panic] // broken async fn read_hello_world_in_two_frames() { let mut mock = Builder::new() // The client sends len @@ -268,7 +261,6 @@ mod nix_framed_tests { /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input, /// independent of how it is spread across read calls and poll cycles. #[test] - #[should_panic] // broken fn split_verif() { let mut cx = task::Context::from_waker(task::Waker::noop()); let mut input = make_framed(&[b"hello", b"world", b"!", b""]); diff --git a/snix/nix-compat/src/nix_daemon/handler.rs b/snix/nix-compat/src/nix_daemon/handler.rs index 2430b9634..34fa570b9 100644 --- a/snix/nix-compat/src/nix_daemon/handler.rs +++ b/snix/nix-compat/src/nix_daemon/handler.rs @@ -189,16 +189,27 @@ where writer.deref_mut(), ); self.io.add_to_store_nar(request, &mut reader).await + // TODO(edef): enforce framing synchronisation }) .await? } 23.. => { // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed let mut framed = NixFramedReader::new(&mut self.reader); + Self::handle(&self.writer, async { self.io.add_to_store_nar(request, &mut framed).await }) - .await? + .await?; + + // framing desynchronisation + // this MUST kill the connection + if !framed.is_eof() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "payload was not fully consumed", + )); + } } } }