fix(nix-daemon): ensure Framed NARs are read exactly
This prevents framing confusion, which would otherwise lead to a trivial confused deputy attack. See issue #120. The NixFramedReader state machine has been refactored to simplify its internal logic and accurately account for EOF conditions. End-of-stream is fused, and unexpected EOF on the underlying reader is returned as UnexpectedEof, though we don't fuse those ourselves. We also ensure that the underlying reader does not swap the ReadBuf; this would otherwise supply a primitive for converting uninitialised mutable memory into `&mut [u8]` without initialisation, thus allowing undefined behaviour to be triggered from safe code. Change-Id: I05ddb7e3ca57b3363f56c0d9b43d5a641748ca36 Reviewed-on: https://cl.snix.dev/c/snix/+/30380 Reviewed-by: Brian Olsen <brian@maven-group.org> Tested-by: besadii Reviewed-by: Florian Klink <flokli@flokli.de>
This commit is contained in:
		
							parent
							
								
									4ef7c50a2d
								
							
						
					
					
						commit
						9a8a9c6b67
					
				
					 2 changed files with 107 additions and 104 deletions
				
			
		|  | @ -1,45 +1,35 @@ | ||||||
| use std::{ | use std::{ | ||||||
|     io::Result, |     num::NonZeroU64, | ||||||
|     pin::Pin, |     pin::Pin, | ||||||
|     task::{ready, Poll}, |     task::{self, ready, Poll}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| use pin_project_lite::pin_project; | use pin_project_lite::pin_project; | ||||||
| use tokio::io::{AsyncRead, ReadBuf}; | use tokio::io::{self, AsyncRead, ReadBuf}; | ||||||
| 
 | 
 | ||||||
| /// State machine for [`NixFramedReader`].
 | /// State machine for [`NixFramedReader`].
 | ||||||
| ///
 | ///
 | ||||||
| /// As the reader progresses it linearly cycles through the states.
 | /// We read length-prefixed chunks until we receive a zero-sized payload indicating EOF.
 | ||||||
| #[derive(Debug, PartialEq)] | /// Other than the zero-sized terminating chunk, chunk boundaries are not considered meaningful.
 | ||||||
| enum NixFramedReaderState { | /// Lengths are 64-bit little endian values on the wire.
 | ||||||
|     /// The reader always starts in this state.
 | #[derive(Debug, Eq, PartialEq)] | ||||||
|     ///
 | enum State { | ||||||
|     /// Before the payload, the client first sends its size.
 |     Length { buf: [u8; 8], filled: u8 }, | ||||||
|     /// The size is a u64 which is 8 bytes long, while it's likely that we will receive
 |     Chunk { remaining: NonZeroU64 }, | ||||||
|     /// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
 |     Eof, | ||||||
|     /// So in this state we read up to 8 bytes and transition to
 |  | ||||||
|     /// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
 |  | ||||||
|     /// otherwise we reset filled to 0, and read the next size value.
 |  | ||||||
|     ReadingSize { buf: [u8; 8], filled: usize }, |  | ||||||
|     /// This is where we read the actual payload that is sent to us.
 |  | ||||||
|     ///
 |  | ||||||
|     /// Once we've read the expected number of bytes, we go back to the
 |  | ||||||
|     /// [`NixFramedReaderState::ReadingSize`] state.
 |  | ||||||
|     ReadingPayload { |  | ||||||
|         /// Represents the remaining number of bytes we expect to read based on the value
 |  | ||||||
|         /// read in the previous state.
 |  | ||||||
|         remaining: u64, |  | ||||||
|     }, |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pin_project! { | pin_project! { | ||||||
|     /// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
 |     /// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23.
 | ||||||
|     ///
 |     ///
 | ||||||
|     /// See serialization.md#framed and [`NixFramedReaderState`] for details.
 |     /// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`].
 | ||||||
|  |     /// True EOF (end-of-stream) is fused.
 | ||||||
|  |     ///
 | ||||||
|  |     /// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed
 | ||||||
|     pub struct NixFramedReader<R> { |     pub struct NixFramedReader<R> { | ||||||
|         #[pin] |         #[pin] | ||||||
|         reader: R, |         reader: R, | ||||||
|         state: NixFramedReaderState, |         state: State, | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -47,100 +37,107 @@ impl<R> NixFramedReader<R> { | ||||||
|     pub fn new(reader: R) -> Self { |     pub fn new(reader: R) -> Self { | ||||||
|         Self { |         Self { | ||||||
|             reader, |             reader, | ||||||
|             state: NixFramedReaderState::ReadingSize { |             state: State::Length { | ||||||
|                 buf: [0; 8], |                 buf: [0; 8], | ||||||
|                 filled: 0, |                 filled: 0, | ||||||
|             }, |             }, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[cfg(test)] |     /// Returns `true` if the Nix Framed reader has reached EOF.
 | ||||||
|     fn is_eof(&self) -> bool { |     #[must_use] | ||||||
|         self.state |     pub fn is_eof(&self) -> bool { | ||||||
|             == NixFramedReaderState::ReadingSize { |         matches!(self.state, State::Eof) | ||||||
|                 buf: [0; 8], |  | ||||||
|                 filled: 8, |  | ||||||
|             } |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<R: AsyncRead> AsyncRead for NixFramedReader<R> { | impl<R: AsyncRead> AsyncRead for NixFramedReader<R> { | ||||||
|     fn poll_read( |     fn poll_read( | ||||||
|         mut self: Pin<&mut Self>, |         mut self: Pin<&mut Self>, | ||||||
|         cx: &mut std::task::Context<'_>, |         cx: &mut task::Context<'_>, | ||||||
|         read_buf: &mut ReadBuf<'_>, |         buf: &mut ReadBuf<'_>, | ||||||
|     ) -> Poll<Result<()>> { |     ) -> Poll<io::Result<()>> { | ||||||
|         let mut this = self.as_mut().project(); |         let mut this = self.as_mut().project(); | ||||||
|  | 
 | ||||||
|  |         // reading nothing always succeeds
 | ||||||
|  |         if buf.remaining() == 0 { | ||||||
|  |             return Ok(()).into(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         loop { | ||||||
|  |             let reader = this.reader.as_mut(); | ||||||
|             match this.state { |             match this.state { | ||||||
|             NixFramedReaderState::ReadingSize { buf, filled } => { |                 State::Eof => { | ||||||
|                 if *filled < buf.len() { |                     return Ok(()).into(); | ||||||
|                     let mut size_buf = ReadBuf::new(buf); |                 } | ||||||
|                     size_buf.advance(*filled); |                 State::Length { buf, filled } => { | ||||||
|  |                     let bytes_read = { | ||||||
|  |                         let mut b = ReadBuf::new(&mut buf[*filled as usize..]); | ||||||
|  |                         ready!(reader.poll_read(cx, &mut b))?; | ||||||
|  |                         b.filled().len() as u8 | ||||||
|  |                     }; | ||||||
| 
 | 
 | ||||||
|                     ready!(this.reader.poll_read(cx, &mut size_buf))?; |  | ||||||
|                     let bytes_read = size_buf.filled().len() - *filled; |  | ||||||
|                     if bytes_read == 0 { |                     if bytes_read == 0 { | ||||||
|                         // oef
 |                         return Err(io::ErrorKind::UnexpectedEof.into()).into(); | ||||||
|                         return Poll::Ready(Ok(())); |  | ||||||
|                     } |                     } | ||||||
|  | 
 | ||||||
|                     *filled += bytes_read; |                     *filled += bytes_read; | ||||||
|                     // Schedule ourselves to run again.
 | 
 | ||||||
|                     return self.poll_read(cx, read_buf); |                     if *filled == 8 { | ||||||
|  |                         *this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) { | ||||||
|  |                             None => State::Eof, | ||||||
|  |                             Some(remaining) => State::Chunk { remaining }, | ||||||
|  |                         }; | ||||||
|                     } |                     } | ||||||
|                 let size = u64::from_le_bytes(*buf); |  | ||||||
|                 if size == 0 { |  | ||||||
|                     // eof
 |  | ||||||
|                     *filled = 0; |  | ||||||
|                     return Poll::Ready(Ok(())); |  | ||||||
|                 } |                 } | ||||||
|                 *this.state = NixFramedReaderState::ReadingPayload { remaining: size }; |                 State::Chunk { remaining } => { | ||||||
|                 self.poll_read(cx, read_buf) |                     let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| { | ||||||
|             } |                         reader.poll_read(cx, buf).map_ok(|()| buf.filled().len()) | ||||||
|             NixFramedReaderState::ReadingPayload { remaining } => { |                     }))?; | ||||||
|                 // Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
 | 
 | ||||||
|                 let safe_remaining = if *remaining <= usize::MAX as u64 { |                     *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) { | ||||||
|                     *remaining as usize |                         None => State::Length { | ||||||
|  |                             buf: [0; 8], | ||||||
|  |                             filled: 0, | ||||||
|  |                         }, | ||||||
|  |                         Some(remaining) => State::Chunk { remaining }, | ||||||
|  |                     }; | ||||||
|  | 
 | ||||||
|  |                     return if bytes_read == 0 { | ||||||
|  |                         Err(io::ErrorKind::UnexpectedEof.into()) | ||||||
|                     } else { |                     } else { | ||||||
|                     usize::MAX |                         Ok(()) | ||||||
|                 }; |                     } | ||||||
|                 if safe_remaining > 0 { |                     .into(); | ||||||
|                     // The buffer is no larger than the amount of data that we expect.
 |                 } | ||||||
|                     // Otherwise we will trim the buffer below and come back here.
 |             } | ||||||
|                     if read_buf.remaining() <= safe_remaining { |         } | ||||||
|                         let filled_before = read_buf.filled().len(); |     } | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|                         ready!(this.reader.as_mut().poll_read(cx, read_buf))?; | /// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it.
 | ||||||
|                         let bytes_read = read_buf.filled().len() - filled_before; | /// After `f` returns, we propagate the filled cursor advancement back to `buf`.
 | ||||||
|  | // TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited
 | ||||||
|  | fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R { | ||||||
|  |     let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX)); | ||||||
|  |     let ptr = nbuf.initialized().as_ptr(); | ||||||
|  |     let ret = f(&mut nbuf); | ||||||
| 
 | 
 | ||||||
|                         *remaining -= bytes_read as u64; |     // SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`,
 | ||||||
|                         if *remaining == 0 { |     // so anything filled is new, initialized data.
 | ||||||
|                             *this.state = NixFramedReaderState::ReadingSize { |     //
 | ||||||
|                                 buf: [0; 8], |     // We verify that `nbuf` still points to the same buffer,
 | ||||||
|                                 filled: 0, |     // so we're sure it hasn't been swapped out.
 | ||||||
|                             }; |     unsafe { | ||||||
|                         } |         // ensure our buffer hasn't been swapped out
 | ||||||
|                         return Poll::Ready(Ok(())); |         assert_eq!(nbuf.initialized().as_ptr(), ptr); | ||||||
|                     } |  | ||||||
|                     // Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
 |  | ||||||
|                     // internal bookkeeping simpler.
 |  | ||||||
|                     let mut smaller_buf = read_buf.take(safe_remaining); |  | ||||||
|                     ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?; |  | ||||||
| 
 | 
 | ||||||
|                     let bytes_read = smaller_buf.filled().len(); |         let n = nbuf.filled().len(); | ||||||
|  |         buf.assume_init(n); | ||||||
|  |         buf.advance(n); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|                     // SAFETY: we just read this number of bytes into read_buf's backing slice above.
 |     ret | ||||||
|                     unsafe { read_buf.assume_init(bytes_read) }; |  | ||||||
|                     read_buf.advance(bytes_read); |  | ||||||
|                     return Poll::Ready(Ok(())); |  | ||||||
|                 } |  | ||||||
|                 *this.state = NixFramedReaderState::ReadingSize { |  | ||||||
|                     buf: [0; 8], |  | ||||||
|                     filled: 0, |  | ||||||
|                 }; |  | ||||||
|                 self.poll_read(cx, read_buf) |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
|  | @ -158,7 +155,6 @@ mod nix_framed_tests { | ||||||
|     use crate::nix_daemon::framing::NixFramedReader; |     use crate::nix_daemon::framing::NixFramedReader; | ||||||
| 
 | 
 | ||||||
|     #[tokio::test] |     #[tokio::test] | ||||||
|     #[should_panic] // broken
 |  | ||||||
|     async fn read_unexpected_eof_after_frame() { |     async fn read_unexpected_eof_after_frame() { | ||||||
|         let mut mock = Builder::new() |         let mut mock = Builder::new() | ||||||
|             // The client sends len
 |             // The client sends len
 | ||||||
|  | @ -179,7 +175,6 @@ mod nix_framed_tests { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tokio::test] |     #[tokio::test] | ||||||
|     #[should_panic] // broken
 |  | ||||||
|     async fn read_unexpected_eof_in_frame() { |     async fn read_unexpected_eof_in_frame() { | ||||||
|         let mut mock = Builder::new() |         let mut mock = Builder::new() | ||||||
|             // The client sends len
 |             // The client sends len
 | ||||||
|  | @ -200,7 +195,6 @@ mod nix_framed_tests { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tokio::test] |     #[tokio::test] | ||||||
|     #[should_panic] // broken
 |  | ||||||
|     async fn read_unexpected_eof_in_length() { |     async fn read_unexpected_eof_in_length() { | ||||||
|         let mut mock = Builder::new() |         let mut mock = Builder::new() | ||||||
|             // The client sends len
 |             // The client sends len
 | ||||||
|  | @ -219,7 +213,6 @@ mod nix_framed_tests { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     #[tokio::test] |     #[tokio::test] | ||||||
|     #[should_panic] // broken
 |  | ||||||
|     async fn read_hello_world_in_two_frames() { |     async fn read_hello_world_in_two_frames() { | ||||||
|         let mut mock = Builder::new() |         let mut mock = Builder::new() | ||||||
|             // The client sends len
 |             // The client sends len
 | ||||||
|  | @ -268,7 +261,6 @@ mod nix_framed_tests { | ||||||
|     /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
 |     /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
 | ||||||
|     /// independent of how it is spread across read calls and poll cycles.
 |     /// independent of how it is spread across read calls and poll cycles.
 | ||||||
|     #[test] |     #[test] | ||||||
|     #[should_panic] // broken
 |  | ||||||
|     fn split_verif() { |     fn split_verif() { | ||||||
|         let mut cx = task::Context::from_waker(task::Waker::noop()); |         let mut cx = task::Context::from_waker(task::Waker::noop()); | ||||||
|         let mut input = make_framed(&[b"hello", b"world", b"!", b""]); |         let mut input = make_framed(&[b"hello", b"world", b"!", b""]); | ||||||
|  |  | ||||||
|  | @ -189,16 +189,27 @@ where | ||||||
|                                         writer.deref_mut(), |                                         writer.deref_mut(), | ||||||
|                                     ); |                                     ); | ||||||
|                                     self.io.add_to_store_nar(request, &mut reader).await |                                     self.io.add_to_store_nar(request, &mut reader).await | ||||||
|  |                                     // TODO(edef): enforce framing synchronisation
 | ||||||
|                                 }) |                                 }) | ||||||
|                                 .await? |                                 .await? | ||||||
|                             } |                             } | ||||||
|                             23.. => { |                             23.. => { | ||||||
|                                 // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
 |                                 // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
 | ||||||
|                                 let mut framed = NixFramedReader::new(&mut self.reader); |                                 let mut framed = NixFramedReader::new(&mut self.reader); | ||||||
|  | 
 | ||||||
|                                 Self::handle(&self.writer, async { |                                 Self::handle(&self.writer, async { | ||||||
|                                     self.io.add_to_store_nar(request, &mut framed).await |                                     self.io.add_to_store_nar(request, &mut framed).await | ||||||
|                                 }) |                                 }) | ||||||
|                                 .await? |                                 .await?; | ||||||
|  | 
 | ||||||
|  |                                 // framing desynchronisation
 | ||||||
|  |                                 // this MUST kill the connection
 | ||||||
|  |                                 if !framed.is_eof() { | ||||||
|  |                                     return Err(std::io::Error::new( | ||||||
|  |                                         std::io::ErrorKind::InvalidData, | ||||||
|  |                                         "payload was not fully consumed", | ||||||
|  |                                     )); | ||||||
|  |                                 } | ||||||
|                             } |                             } | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue