refactor(tvix/store/blobsvc): make BlobStore async

We previously kept the trait of a BlobService sync.

This however had some annoying consequences:

 - It became more and more complicated to track whether we're in a context
   with an async runtime available or not, producing bugs like
   https://b.tvl.fyi/issues/304
 - The sync trait shielded away async clients from async workloads,
   requiring manual block_on code inside the gRPC client code, and
   spawn_blocking calls in consumers of the trait, even if they were
   async (like the gRPC server)
 - We had to write our own custom glue code (SyncReadIntoAsyncRead)
   to convert a sync io::Read into a tokio::io::AsyncRead, which already
   existed in tokio internally, but upstream is hesitant to expose.

This now makes the BlobService trait async (via the async_trait macro,
like we already do in various gRPC parts), and replaces the sync readers
and writers with their async counterparts.

Tests interacting with a BlobService now need to have an async runtime
available, the easiest way for this is to mark the test functions
with the tokio::test macro, allowing us to directly .await in the test
function.

In places where we don't have an async runtime available from context
(like tvix-cli), we can pass one down explicitly.

Now that we don't provide a sync interface anymore, the (sync) FUSE
library now holds a pointer to a tokio runtime handle, and needs to at
least have 2 threads available when talking to a blob service (which is
why some of the tests now use the multi_thread flavor).

The FUSE tests got a bit more verbose, as we couldn't use the
setup_and_mount function accepting a callback anymore. We can hopefully
move some of the test fixture setup to rstest in the future to make this
less repetitive.

Co-Authored-By: Connor Brewster <cbrewster@hey.com>
Change-Id: Ia0501b606e32c852d0108de9c9016b21c94a3c05
Reviewed-on: https://cl.tvl.fyi/c/depot/+/9329
Reviewed-by: Connor Brewster <cbrewster@hey.com>
Tested-by: BuildkiteCI
Reviewed-by: raitobezarius <tvl@lahfa.xyz>
This commit is contained in:
Florian Klink 2023-09-13 14:20:21 +02:00 committed by flokli
parent 3de9601764
commit da6cbb4a45
25 changed files with 1700 additions and 1002 deletions

View file

@ -1,110 +0,0 @@
use std::io;
use tracing::{debug, instrument};
use super::BlobReader;
/// Wraps a plain [io::Read] and provides [io::Read] plus a forward-only
/// [io::Seek] on top of it.
///
/// Seeking is emulated by reading and discarding bytes while keeping track of
/// the current position. It fails whenever you try to seek backwards.
pub struct DumbSeeker<R>
where
    R: io::Read,
{
    /// The wrapped reader all data is pulled from.
    r: R,
    /// Current offset from the start of the stream.
    pos: u64,
}
impl<R: io::Read> DumbSeeker<R> {
pub fn new(r: R) -> Self {
DumbSeeker { r, pos: 0 }
}
}
impl<R: io::Read> io::Read for DumbSeeker<R> {
    /// Reads from the wrapped reader and advances the tracked position by the
    /// number of bytes that were actually read. Errors are passed through
    /// unchanged and leave the position untouched.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.r.read(buf).map(|n| {
            self.pos += n as u64;
            n
        })
    }
}
impl<R: io::Read> io::Seek for DumbSeeker<R> {
    /// Seeks forward by reading and discarding bytes from the wrapped reader.
    ///
    /// Returns the new position from the start of the stream.
    ///
    /// # Errors
    ///
    /// - [io::ErrorKind::Unsupported] when seeking backwards, or when using
    ///   [io::SeekFrom::End] (the total size is unknown).
    /// - [io::ErrorKind::InvalidInput] when a relative seek would overflow
    ///   the `u64` position.
    /// - [io::ErrorKind::UnexpectedEof] when the wrapped reader ends before
    ///   the target offset is reached.
    /// - Any other error returned by the wrapped reader.
    ///
    /// Even on failure, `self.pos` is advanced by the number of bytes that
    /// were already discarded, so it keeps matching the wrapped reader.
    #[instrument(skip(self))]
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let absolute_offset: u64 = match pos {
            io::SeekFrom::Start(start_offset) => {
                if start_offset < self.pos {
                    return Err(io::Error::new(
                        io::ErrorKind::Unsupported,
                        format!("can't seek backwards ({} -> {})", self.pos, start_offset),
                    ));
                }
                start_offset
            }
            // we don't know the total size, can't support this.
            io::SeekFrom::End(_end_offset) => {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "can't seek from end",
                ));
            }
            io::SeekFrom::Current(relative_offset) => {
                if relative_offset < 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::Unsupported,
                        "can't seek backwards relative to current position",
                    ));
                }
                // relative_offset is non-negative here, so the cast is lossless.
                // Guard the addition instead of wrapping/panicking on overflow.
                self.pos
                    .checked_add(relative_offset as u64)
                    .ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "seek position overflows u64",
                        )
                    })?
            }
        };

        debug!(absolute_offset=?absolute_offset, "seek");

        // All backwards seeks were rejected above.
        debug_assert!(
            absolute_offset >= self.pos,
            "absolute_offset {} is smaller than self.pos {}",
            absolute_offset,
            self.pos
        );

        // calculate bytes to skip
        let bytes_to_skip: u64 = absolute_offset - self.pos;

        // discard these bytes. We can't use take() as it requires ownership of
        // self.r, but we only have &mut self.
        let mut buf = [0; 1024];
        let mut bytes_skipped: u64 = 0;
        while bytes_skipped < bytes_to_skip {
            let len = std::cmp::min(bytes_to_skip - bytes_skipped, buf.len() as u64);
            match self.r.read(&mut buf[..len as usize]) {
                Ok(0) => break,
                Ok(n) => bytes_skipped += n as u64,
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => {
                    // Account for what was already consumed before bailing out.
                    self.pos += bytes_skipped;
                    return Err(e);
                }
            }
        }

        // Record the bytes that were consumed, whether or not the target was
        // reached, so the position never drifts from the wrapped reader.
        self.pos += bytes_skipped;

        // This will fail when seeking past the end of self.r
        if bytes_to_skip != bytes_skipped {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                format!(
                    "tried to skip {} bytes, but only was able to skip {} until reaching EOF",
                    bytes_to_skip, bytes_skipped
                ),
            ));
        }

        // return the new position from the start of the stream
        Ok(self.pos)
    }
}
/// A [DumbSeeker] wrapping any `io::Read + Send + 'static` reader can be used
/// as a [BlobReader].
impl<R: io::Read + Send + 'static> BlobReader for DumbSeeker<R> {}

View file

@ -1,22 +1,26 @@
use super::{dumb_seeker::DumbSeeker, BlobReader, BlobService, BlobWriter};
use super::{naive_seeker::NaiveSeeker, BlobReader, BlobService, BlobWriter};
use crate::{proto, B3Digest};
use futures::sink::{SinkExt, SinkMapErr};
use std::{collections::VecDeque, io};
use futures::sink::SinkExt;
use futures::TryFutureExt;
use std::{
collections::VecDeque,
io::{self},
pin::pin,
task::Poll,
};
use tokio::io::AsyncWriteExt;
use tokio::{net::UnixStream, task::JoinHandle};
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
use tokio_util::{
io::{CopyToBytes, SinkWriter, SyncIoBridge},
io::{CopyToBytes, SinkWriter},
sync::{PollSendError, PollSender},
};
use tonic::{transport::Channel, Code, Status, Streaming};
use tonic::{async_trait, transport::Channel, Code, Status};
use tracing::instrument;
/// Connects to a (remote) tvix-store BlobService over gRPC.
#[derive(Clone)]
pub struct GRPCBlobService {
/// A handle into the active tokio runtime. Necessary to spawn tasks.
tokio_handle: tokio::runtime::Handle,
/// The internal reference to a gRPC client.
/// Cloning it is cheap, and it internally handles concurrent requests.
grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
@ -28,13 +32,11 @@ impl GRPCBlobService {
pub fn from_client(
grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
) -> Self {
Self {
tokio_handle: tokio::runtime::Handle::current(),
grpc_client,
}
Self { grpc_client }
}
}
#[async_trait]
impl BlobService for GRPCBlobService {
/// Constructs a [GRPCBlobService] from the passed [url::Url]:
/// - scheme has to match `grpc+*://`.
@ -89,22 +91,16 @@ impl BlobService for GRPCBlobService {
}
#[instrument(skip(self, digest), fields(blob.digest=%digest))]
fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> {
// Get a new handle to the gRPC client, and copy the digest.
async fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> {
let mut grpc_client = self.grpc_client.clone();
let digest = digest.clone();
let resp = grpc_client
.stat(proto::StatBlobRequest {
digest: digest.clone().into(),
..Default::default()
})
.await;
let task: JoinHandle<Result<_, Status>> = self.tokio_handle.spawn(async move {
Ok(grpc_client
.stat(proto::StatBlobRequest {
digest: digest.into(),
..Default::default()
})
.await?
.into_inner())
});
match self.tokio_handle.block_on(task)? {
match resp {
Ok(_blob_meta) => Ok(true),
Err(e) if e.code() == Code::NotFound => Ok(false),
Err(e) => Err(crate::Error::StorageError(e.to_string())),
@ -113,35 +109,30 @@ impl BlobService for GRPCBlobService {
// On success, this returns a Ok(Some(io::Read)), which can be used to read
// the contents of the Blob, identified by the digest.
fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, crate::Error> {
async fn open_read(
&self,
digest: &B3Digest,
) -> Result<Option<Box<dyn BlobReader>>, crate::Error> {
// Get a new handle to the gRPC client, and copy the digest.
let mut grpc_client = self.grpc_client.clone();
let digest = digest.clone();
// Construct the task that'll send out the request and return the stream
// the gRPC client should use to send [proto::BlobChunk], or an error if
// the blob doesn't exist.
let task: JoinHandle<Result<Streaming<proto::BlobChunk>, Status>> =
self.tokio_handle.spawn(async move {
let stream = grpc_client
.read(proto::ReadBlobRequest {
digest: digest.into(),
})
.await?
.into_inner();
Ok(stream)
});
// Get a stream of [proto::BlobChunk], or return an error if the blob
// doesn't exist.
let resp = grpc_client
.read(proto::ReadBlobRequest {
digest: digest.clone().into(),
})
.await;
// This runs the task to completion, which on success will return a stream.
// On reading from it, we receive individual [proto::BlobChunk], so we
// massage this to a stream of bytes,
// then create an [AsyncRead], which we'll turn into a [io::Read],
// that's returned from the function.
match self.tokio_handle.block_on(task)? {
match resp {
Ok(stream) => {
// map the stream of proto::BlobChunk to bytes.
let data_stream = stream.map(|x| {
let data_stream = stream.into_inner().map(|x| {
x.map(|x| VecDeque::from(x.data.to_vec()))
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
});
@ -149,9 +140,7 @@ impl BlobService for GRPCBlobService {
// Use StreamReader::new to convert to an AsyncRead.
let data_reader = tokio_util::io::StreamReader::new(data_stream);
// Use SyncIoBridge to turn it into a sync Read.
let sync_reader = tokio_util::io::SyncIoBridge::new(data_reader);
Ok(Some(Box::new(DumbSeeker::new(sync_reader))))
Ok(Some(Box::new(NaiveSeeker::new(data_reader))))
}
Err(e) if e.code() == Code::NotFound => Ok(None),
Err(e) => Err(crate::Error::StorageError(e.to_string())),
@ -160,7 +149,7 @@ impl BlobService for GRPCBlobService {
/// Returns a BlobWriter, that'll internally wrap each write in a
/// [proto::BlobChunk], which is sent to the gRPC server.
fn open_write(&self) -> Box<dyn BlobWriter> {
async fn open_write(&self) -> Box<dyn BlobWriter> {
let mut grpc_client = self.grpc_client.clone();
// set up an mpsc channel passing around Bytes.
@ -171,9 +160,8 @@ impl BlobService for GRPCBlobService {
let blobchunk_stream = ReceiverStream::new(rx).map(|x| proto::BlobChunk { data: x });
// That receiver stream is used as a stream in the gRPC BlobService.put rpc call.
let task: JoinHandle<Result<_, Status>> = self
.tokio_handle
.spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) });
let task: JoinHandle<Result<_, Status>> =
tokio::spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) });
// The tx part of the channel is converted to a sink of byte chunks.
@ -187,43 +175,26 @@ impl BlobService for GRPCBlobService {
// We need to explicitly cast here, otherwise rustc does error with "expected fn pointer, found fn item"
// … which is turned into an [tokio::io::AsyncWrite].
let async_writer = SinkWriter::new(CopyToBytes::new(sink));
// … which is then turned into a [io::Write].
let writer = SyncIoBridge::new(async_writer);
let writer = SinkWriter::new(CopyToBytes::new(sink));
Box::new(GRPCBlobWriter {
tokio_handle: self.tokio_handle.clone(),
task_and_writer: Some((task, writer)),
digest: None,
})
}
}
type BridgedWriter = SyncIoBridge<
SinkWriter<
CopyToBytes<
SinkMapErr<PollSender<bytes::Bytes>, fn(PollSendError<bytes::Bytes>) -> io::Error>,
>,
>,
>;
pub struct GRPCBlobWriter {
/// A handle into the active tokio runtime. Necessary to block on the task
/// containing the put request.
tokio_handle: tokio::runtime::Handle,
pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
/// The task containing the put request, and the inner writer, if we're still writing.
task_and_writer: Option<(
JoinHandle<Result<proto::PutBlobResponse, Status>>,
BridgedWriter,
)>,
task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,
/// The digest that has been returned, if we successfully closed.
digest: Option<B3Digest>,
}
impl BlobWriter for GRPCBlobWriter {
fn close(&mut self) -> Result<B3Digest, crate::Error> {
#[async_trait]
impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
async fn close(&mut self) -> Result<B3Digest, crate::Error> {
if self.task_and_writer.is_none() {
// if we're already closed, return the b3 digest, which must exist.
// If it doesn't, we already closed and failed once, and didn't handle the error.
@ -240,12 +211,14 @@ impl BlobWriter for GRPCBlobWriter {
// the channel.
writer
.shutdown()
.map_err(|e| crate::Error::StorageError(e.to_string()))?;
.map_err(|e| crate::Error::StorageError(e.to_string()))
.await?;
// block on the RPC call to return.
// This ensures all chunks are sent out, and have been received by the
// backend.
match self.tokio_handle.block_on(task)? {
match task.await? {
Ok(resp) => {
// return the digest from the response, and store it in self.digest for subsequent closes.
let digest: B3Digest = resp.digest.try_into().map_err(|_| {
@ -262,26 +235,48 @@ impl BlobWriter for GRPCBlobWriter {
}
}
impl io::Write for GRPCBlobWriter {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
match &mut self.task_and_writer {
None => Err(io::Error::new(
None => Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
)),
Some((_, ref mut writer)) => writer.write(buf),
))),
Some((_, ref mut writer)) => {
let pinned_writer = pin!(writer);
pinned_writer.poll_write(cx, buf)
}
}
}
fn flush(&mut self) -> io::Result<()> {
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
match &mut self.task_and_writer {
None => Err(io::Error::new(
None => Poll::Ready(Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
)),
Some((_, ref mut writer)) => writer.flush(),
))),
Some((_, ref mut writer)) => {
let pinned_writer = pin!(writer);
pinned_writer.poll_flush(cx)
}
}
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
// TODO(raitobezarius): this might not be a graceful shutdown of the
// channel inside the gRPC connection.
Poll::Ready(Ok(()))
}
}
#[cfg(test)]
@ -291,7 +286,6 @@ mod tests {
use tempfile::TempDir;
use tokio::net::UnixListener;
use tokio::task;
use tokio::time;
use tokio_stream::wrappers::UnixListenerStream;
@ -358,32 +352,23 @@ mod tests {
}
/// This uses the correct scheme for a unix socket, and provides a server on the other side.
#[tokio::test]
async fn test_valid_unix_path_ping_pong() {
/// This is not a tokio::test, because we spawn two separate tokio runtimes and
/// want to have explicit control.
#[test]
fn test_valid_unix_path_ping_pong() {
let tmpdir = TempDir::new().unwrap();
let path = tmpdir.path().join("daemon");
// let mut join_set = JoinSet::new();
// prepare a client
let client = {
let mut url = url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse");
url.set_path(path.to_str().unwrap());
GRPCBlobService::from_url(&url).expect("must succeed")
};
let path_copy = path.clone();
let path_clone = path.clone();
// Spin up a server, in a thread far away, which spawns its own tokio runtime,
// and blocks on the task.
thread::spawn(move || {
// Create the runtime
let rt = tokio::runtime::Runtime::new().unwrap();
// Get a handle from this runtime
let handle = rt.handle();
let task = handle.spawn(async {
let uds = UnixListener::bind(path_copy).unwrap();
let task = rt.spawn(async {
let uds = UnixListener::bind(path_clone).unwrap();
let uds_stream = UnixListenerStream::new(uds);
// spin up a new server
@ -397,33 +382,46 @@ mod tests {
router.serve_with_incoming(uds_stream).await
});
handle.block_on(task)
rt.block_on(task).unwrap().unwrap();
});
// wait for the socket to be created
{
let mut socket_created = false;
for _try in 1..20 {
if path.exists() {
socket_created = true;
break;
// Now create another tokio runtime which we'll use in the main test code.
let rt = tokio::runtime::Runtime::new().unwrap();
let task = rt.spawn(async move {
// wait for the socket to be created
{
let mut socket_created = false;
// TODO: exponential backoff urgently
for _try in 1..20 {
if path.exists() {
socket_created = true;
break;
}
tokio::time::sleep(time::Duration::from_millis(20)).await;
}
tokio::time::sleep(time::Duration::from_millis(20)).await;
assert!(
socket_created,
"expected socket path to eventually get created, but never happened"
);
}
assert!(
socket_created,
"expected socket path to eventually get created, but never happened"
);
}
// prepare a client
let client = {
let mut url =
url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse");
url.set_path(path.to_str().unwrap());
GRPCBlobService::from_url(&url).expect("must succeed")
};
let has = task::spawn_blocking(move || {
client
let has = client
.has(&fixtures::BLOB_A_DIGEST)
.expect("must not be err")
})
.await
.expect("must not be err");
assert!(!has);
.await
.expect("must not be err");
assert!(!has);
});
rt.block_on(task).unwrap()
}
}

View file

@ -1,9 +1,11 @@
use std::io::{self, Cursor};
use std::io::{self, Cursor, Write};
use std::task::Poll;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tracing::{instrument, warn};
use tonic::async_trait;
use tracing::instrument;
use super::{BlobReader, BlobService, BlobWriter};
use crate::{B3Digest, Error};
@ -13,6 +15,7 @@ pub struct MemoryBlobService {
db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
}
#[async_trait]
impl BlobService for MemoryBlobService {
/// Constructs a [MemoryBlobService] from the passed [url::Url]:
/// - scheme has to be `memory://`
@ -31,12 +34,12 @@ impl BlobService for MemoryBlobService {
}
#[instrument(skip(self, digest), fields(blob.digest=%digest))]
fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
async fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
let db = self.db.read().unwrap();
Ok(db.contains_key(digest))
}
fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
let db = self.db.read().unwrap();
match db.get(digest).map(|x| Cursor::new(x.clone())) {
@ -46,7 +49,7 @@ impl BlobService for MemoryBlobService {
}
#[instrument(skip(self))]
fn open_write(&self) -> Box<dyn BlobWriter> {
async fn open_write(&self) -> Box<dyn BlobWriter> {
Box::new(MemoryBlobWriter::new(self.db.clone()))
}
}
@ -70,9 +73,13 @@ impl MemoryBlobWriter {
}
}
}
impl std::io::Write for MemoryBlobWriter {
fn write(&mut self, b: &[u8]) -> std::io::Result<usize> {
match &mut self.writers {
impl tokio::io::AsyncWrite for MemoryBlobWriter {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
b: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
Poll::Ready(match &mut self.writers {
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
@ -81,22 +88,34 @@ impl std::io::Write for MemoryBlobWriter {
let bytes_written = buf.write(b)?;
hasher.write(&b[..bytes_written])
}
}
})
}
fn flush(&mut self) -> std::io::Result<()> {
match &mut self.writers {
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
Poll::Ready(match self.writers {
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
)),
Some(_) => Ok(()),
}
})
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
// shutdown is "instantaneous", we only write to memory.
Poll::Ready(Ok(()))
}
}
#[async_trait]
impl BlobWriter for MemoryBlobWriter {
fn close(&mut self) -> Result<B3Digest, Error> {
async fn close(&mut self) -> Result<B3Digest, Error> {
if self.writers.is_none() {
match &self.digest {
Some(digest) => Ok(digest.clone()),

View file

@ -1,11 +1,12 @@
use std::io;
use tonic::async_trait;
use crate::{B3Digest, Error};
mod dumb_seeker;
mod from_addr;
mod grpc;
mod memory;
mod naive_seeker;
mod sled;
#[cfg(test)]
@ -21,35 +22,41 @@ pub use self::sled::SledBlobService;
/// a way to get a [io::Read] to a blob, and a method to initiate writing a new
/// Blob, which will return something implementing io::Write, and providing a
/// close function, to finalize a blob and get its digest.
#[async_trait]
pub trait BlobService: Send + Sync {
/// Create a new instance by passing in a connection URL.
/// TODO: check if we want to make this async, instead of lazily connecting
fn from_url(url: &url::Url) -> Result<Self, Error>
where
Self: Sized;
/// Check if the service has the blob, by its content hash.
fn has(&self, digest: &B3Digest) -> Result<bool, Error>;
async fn has(&self, digest: &B3Digest) -> Result<bool, Error>;
/// Request a blob from the store, by its content hash.
fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error>;
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error>;
/// Insert a new blob into the store. Returns a [BlobWriter], which
/// implements [io::Write] and a [BlobWriter::close].
fn open_write(&self) -> Box<dyn BlobWriter>;
async fn open_write(&self) -> Box<dyn BlobWriter>;
}
/// A [io::Write] that you need to close() afterwards, and get back the digest
/// of the written blob.
pub trait BlobWriter: io::Write + Send + Sync + 'static {
/// A [tokio::io::AsyncWrite] that you need to close() afterwards, and get back
/// the digest of the written blob.
#[async_trait]
pub trait BlobWriter: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static {
/// Signal there's no more data to be written, and return the digest of the
/// contents written.
///
/// Closing a already-closed BlobWriter is a no-op.
fn close(&mut self) -> Result<B3Digest, Error>;
async fn close(&mut self) -> Result<B3Digest, Error>;
}
/// A [io::Read] that also allows seeking.
pub trait BlobReader: io::Read + io::Seek + Send + 'static {}
/// A [tokio::io::AsyncRead] that also allows seeking.
pub trait BlobReader:
tokio::io::AsyncRead + tokio::io::AsyncSeek + tokio::io::AsyncBufRead + Send + Unpin + 'static
{
}
/// A [`io::Cursor<Vec<u8>>`] can be used as a BlobReader.
impl BlobReader for io::Cursor<Vec<u8>> {}

View file

@ -0,0 +1,269 @@
use super::BlobReader;
use pin_project_lite::pin_project;
use std::io;
use std::task::Poll;
use tokio::io::AsyncRead;
use tracing::{debug, instrument};
pin_project! {
/// This implements [tokio::io::AsyncSeek] on top of a [tokio::io::AsyncRead]
/// by simply reading and discarding some bytes, keeping track of the position.
/// It fails whenever you try to seek backwards.
///
/// ## Pinning concerns:
///
/// [NaiveSeeker] is itself pinned by callers, and we do not need to concern
/// ourselves regarding that.
///
/// Though, its fields as per
/// <https://doc.rust-lang.org/std/pin/#pinning-is-not-structural-for-field>
/// can be pinned or unpinned.
///
/// So we need to go over each field and choose our policy carefully.
///
/// The obvious cases are the bookkeeping integers we keep in the structure.
/// Those are private and not shared with anyone, and we never build a
/// `Pin<&mut X>` out of them at any point; therefore, we can safely never
/// mark them as pinned. Of course, it is expected that no developer here
/// attempts to `pin!(self.pos)` to pin them, because it makes no sense. If
/// they have to become pinned, they should be marked `#[pin]` and we need
/// to discuss it.
///
/// So the bookkeeping integers are in the right state with respect to their
/// pinning status. The projection should offer direct access.
///
/// On the `r` field, i.e. a `BufReader<R>`, given that
/// <https://docs.rs/tokio/latest/tokio/io/struct.BufReader.html#impl-Unpin-for-BufReader%3CR%3E>
/// is available, even a `Pin<&mut BufReader<R>>` can be safely moved.
///
/// The only care we should have regards the internal reader itself, i.e.
/// the `R` instance; see that Tokio decided to `#[pin]` it too:
/// <https://docs.rs/tokio/latest/src/tokio/io/util/buf_reader.rs.html#29>
///
/// In general, there's no `Unpin` instance for `R: tokio::io::AsyncRead`
/// (see <https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html>).
///
/// Therefore, we could keep it unpinned and pin it at every call site
/// whenever we need to call `poll_*`, which can be confusing to the
/// non-expert developer. We also have a fair amount of situations where the
/// [BufReader] instance is naked, i.e. in its `&mut BufReader<R>` form. This
/// is annoying because it could lead to exposing the naked `R` internal
/// instance somehow, and would produce a risk of making it move
/// unexpectedly.
///
/// We choose the path of least resistance: as we have no reason to have
/// access to the raw `BufReader<R>` instance, we just `#[pin]` it too,
/// enjoy its safe `poll_*` APIs, and push the unpinning concerns to the
/// internal implementations themselves, which studied the question longer
/// than us.
pub struct NaiveSeeker<R: tokio::io::AsyncRead> {
// The wrapped reader, behind a tokio BufReader. Pinned, see above.
#[pin]
r: tokio::io::BufReader<R>,
// Current position, counted in bytes from the start of the stream.
pos: u64,
// Bytes a pending seek still has to discard; 0 means no seek is pending.
bytes_to_skip: u64,
}
}
impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
    /// Wraps `r` in a [tokio::io::BufReader], starting at position 0 with no
    /// seek pending.
    pub fn new(r: R) -> Self {
        let r = tokio::io::BufReader::new(r);
        Self {
            r,
            pos: 0,
            bytes_to_skip: 0,
        }
    }
}
impl<R: tokio::io::AsyncRead> tokio::io::AsyncRead for NaiveSeeker<R> {
    /// Forwards the read to the inner buffered reader and, once it is ready,
    /// advances the tracked position by however many bytes were added to
    /// `buf`.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let this = self.project();

        // The number of bytes read is the growth of the filled part of `buf`.
        let len_before = buf.filled().len();
        let outcome = this.r.poll_read(cx, buf);

        if outcome.is_ready() {
            *this.pos += (buf.filled().len() - len_before) as u64;
        }

        outcome
    }
}
impl<R: tokio::io::AsyncRead> tokio::io::AsyncBufRead for NaiveSeeker<R> {
    /// Delegates to the inner buffered reader. The position is not touched
    /// here; it only advances once bytes are consumed.
    fn poll_fill_buf(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<io::Result<&[u8]>> {
        let this = self.project();
        tokio::io::AsyncBufRead::poll_fill_buf(this.r, cx)
    }

    /// Marks `amt` bytes of the inner buffer as consumed and advances the
    /// tracked position by the same amount.
    fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
        let this = self.project();
        *this.pos += amt as u64;
        tokio::io::AsyncBufRead::consume(this.r, amt);
    }
}
impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
    /// Validates the requested seek and records how many bytes need to be
    /// discarded to get there. No data is read here; the actual skipping
    /// happens in `poll_complete`.
    ///
    /// # Errors
    ///
    /// - [io::ErrorKind::Unsupported] when seeking backwards, or when using
    ///   [io::SeekFrom::End] (the total size is unknown).
    /// - [io::ErrorKind::InvalidInput] when a relative seek would overflow
    ///   the `u64` position.
    #[instrument(skip(self))]
    fn start_seek(
        self: std::pin::Pin<&mut Self>,
        position: std::io::SeekFrom,
    ) -> std::io::Result<()> {
        let absolute_offset: u64 = match position {
            io::SeekFrom::Start(start_offset) => {
                if start_offset < self.pos {
                    return Err(io::Error::new(
                        io::ErrorKind::Unsupported,
                        format!("can't seek backwards ({} -> {})", self.pos, start_offset),
                    ));
                }
                start_offset
            }
            // we don't know the total size, can't support this.
            io::SeekFrom::End(_end_offset) => {
                return Err(io::Error::new(
                    io::ErrorKind::Unsupported,
                    "can't seek from end",
                ));
            }
            io::SeekFrom::Current(relative_offset) => {
                if relative_offset < 0 {
                    return Err(io::Error::new(
                        io::ErrorKind::Unsupported,
                        "can't seek backwards relative to current position",
                    ));
                }
                // relative_offset is non-negative here, so the cast is lossless.
                // Guard the addition instead of wrapping/panicking on overflow.
                self.pos
                    .checked_add(relative_offset as u64)
                    .ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            "seek position overflows u64",
                        )
                    })?
            }
        };

        debug!(absolute_offset=?absolute_offset, "seek");

        // All backwards seeks were rejected above.
        debug_assert!(
            absolute_offset >= self.pos,
            "absolute_offset {} is smaller than self.pos {}",
            absolute_offset,
            self.pos
        );

        // calculate bytes to skip
        let bytes_to_skip = absolute_offset - self.pos;
        *self.project().bytes_to_skip = bytes_to_skip;

        Ok(())
    }

    /// Drives a pending seek by reading and discarding bytes until none are
    /// left to skip, then returns the new position from the start of the
    /// stream.
    ///
    /// # Errors
    ///
    /// - Any error from the underlying reader is returned as-is.
    /// - [io::ErrorKind::UnexpectedEof] when the reader ends before the seek
    ///   target is reached.
    #[instrument(skip(self))]
    fn poll_complete(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<u64>> {
        if self.bytes_to_skip == 0 {
            // return the new position (from the start of the stream)
            return Poll::Ready(Ok(self.pos));
        }

        // discard some bytes, until pos is where we want it to be.
        // We create a buffer that we'll discard later on.
        let mut buf = [0; 1024];

        // Loop until we've reached the desired seek position. This is done by issuing repeated
        // `poll_read` calls. If the data is not available yet, we will yield back to the executor
        // and wait to be polled again.
        loop {
            // Skip at most one buffer's worth, or whatever is left if that is
            // smaller. Taking the min in u64 before casting means the cast
            // can never truncate, even where usize is narrower than u64.
            let chunk_len = std::cmp::min(self.bytes_to_skip, buf.len() as u64) as usize;
            let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..chunk_len]);

            match self.as_mut().poll_read(cx, &mut read_buf) {
                Poll::Ready(Ok(())) => {}
                // Surface real I/O errors instead of misreporting them as EOF.
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }

            let bytes_read = read_buf.filled().len() as u64;
            if bytes_read == 0 {
                return Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    format!(
                        "reached EOF with {} bytes still left to skip",
                        self.bytes_to_skip
                    ),
                )));
            }

            // poll_read already advanced self.pos; only the remaining skip
            // count needs updating here.
            let remaining = self.bytes_to_skip - bytes_read;
            *self.as_mut().project().bytes_to_skip = remaining;

            if remaining == 0 {
                return Poll::Ready(Ok(self.pos));
            }
        }
    }
}
/// A [NaiveSeeker] wrapping any `tokio::io::AsyncRead + Send + Unpin + 'static`
/// reader can be used as a [BlobReader].
impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeker<R> {}
#[cfg(test)]
mod tests {
    use super::NaiveSeeker;
    use std::io::{Cursor, SeekFrom};
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    /// Seeking 4000 bytes ahead is more than the 1024 byte internal discard
    /// buffer can cover at once, so several `poll_read` rounds are needed.
    /// This makes sure such a seek finishes instead of hanging indefinitely.
    #[tokio::test]
    async fn seek() {
        let data = vec![0u8; 4096];
        let mut seeker = NaiveSeeker::new(Cursor::new(&data));

        seeker.seek(SeekFrom::Start(4000)).await.unwrap();
    }

    /// Interleaves reads with relative and absolute forward seeks over three
    /// 2048 byte regions filled with 0, 1 and 2, checking each read lands in
    /// the expected region.
    #[tokio::test]
    async fn seek_read() {
        let data = [vec![0u8; 2048], vec![1u8; 2048], vec![2u8; 2048]].concat();
        let mut seeker = NaiveSeeker::new(Cursor::new(&data));
        let mut chunk = vec![0u8; 1024];

        // First KiB of the zero region.
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[0u8; 1024]);

        // Skip the rest of the zero region, relative to where we are.
        seeker
            .seek(SeekFrom::Current(1024))
            .await
            .expect("must seek");
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[1u8; 1024]);

        // Jump to the start of the last region by absolute offset.
        seeker
            .seek(SeekFrom::Start(2 * 2048))
            .await
            .expect("must seek");
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[2u8; 1024]);
    }
}

View file

@ -1,9 +1,11 @@
use super::{BlobReader, BlobService, BlobWriter};
use crate::{B3Digest, Error};
use std::{
io::{self, Cursor},
io::{self, Cursor, Write},
path::PathBuf,
task::Poll,
};
use tonic::async_trait;
use tracing::instrument;
#[derive(Clone)]
@ -27,6 +29,7 @@ impl SledBlobService {
}
}
#[async_trait]
impl BlobService for SledBlobService {
/// Constructs a [SledBlobService] from the passed [url::Url]:
/// - scheme has to be `sled://`
@ -57,7 +60,7 @@ impl BlobService for SledBlobService {
}
#[instrument(skip(self), fields(blob.digest=%digest))]
fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
async fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
match self.db.contains_key(digest.to_vec()) {
Ok(has) => Ok(has),
Err(e) => Err(Error::StorageError(e.to_string())),
@ -65,7 +68,7 @@ impl BlobService for SledBlobService {
}
#[instrument(skip(self), fields(blob.digest=%digest))]
fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
match self.db.get(digest.to_vec()) {
Ok(None) => Ok(None),
Ok(Some(data)) => Ok(Some(Box::new(Cursor::new(data[..].to_vec())))),
@ -74,7 +77,7 @@ impl BlobService for SledBlobService {
}
#[instrument(skip(self))]
fn open_write(&self) -> Box<dyn BlobWriter> {
async fn open_write(&self) -> Box<dyn BlobWriter> {
Box::new(SledBlobWriter::new(self.db.clone()))
}
}
@ -99,9 +102,13 @@ impl SledBlobWriter {
}
}
impl io::Write for SledBlobWriter {
fn write(&mut self, b: &[u8]) -> io::Result<usize> {
match &mut self.writers {
impl tokio::io::AsyncWrite for SledBlobWriter {
fn poll_write(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
b: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
Poll::Ready(match &mut self.writers {
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
@ -110,22 +117,34 @@ impl io::Write for SledBlobWriter {
let bytes_written = buf.write(b)?;
hasher.write(&b[..bytes_written])
}
}
})
}
fn flush(&mut self) -> io::Result<()> {
match &mut self.writers {
fn poll_flush(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
Poll::Ready(match &mut self.writers {
None => Err(io::Error::new(
io::ErrorKind::NotConnected,
"already closed",
)),
Some(_) => Ok(()),
}
})
}
/// Reports shutdown as complete right away, always yielding `Ok(())`.
///
/// This method never returns `Poll::Pending` and does not look at or modify
/// `self.writers`; unlike `poll_write` and `poll_flush` it therefore does not
/// report `NotConnected` once the writer has been closed.
// NOTE(review): producing the digest / finalizing the blob presumably happens
// in `BlobWriter::close`, not here — confirm against the `close` implementation.
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
// shutdown is "instantaneous", we only write to a Vec<u8> as buffer.
Poll::Ready(Ok(()))
}
}
#[async_trait]
impl BlobWriter for SledBlobWriter {
fn close(&mut self) -> Result<B3Digest, Error> {
async fn close(&mut self) -> Result<B3Digest, Error> {
if self.writers.is_none() {
match &self.digest {
Some(digest) => Ok(digest.clone()),

View file

@ -1,6 +1,9 @@
use std::io;
use std::pin::pin;
use test_case::test_case;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncSeekExt;
use super::B3Digest;
use super::BlobService;
@ -24,19 +27,25 @@ fn gen_sled_blob_service() -> impl BlobService {
#[test_case(gen_memory_blob_service(); "memory")]
#[test_case(gen_sled_blob_service(); "sled")]
fn has_nonexistent_false(blob_service: impl BlobService) {
assert!(!blob_service
.has(&fixtures::BLOB_A_DIGEST)
.expect("must not fail"));
tokio::runtime::Runtime::new().unwrap().block_on(async {
assert!(!blob_service
.has(&fixtures::BLOB_A_DIGEST)
.await
.expect("must not fail"));
})
}
/// Trying to read a non-existing blob should return a None instead of a reader.
#[test_case(gen_memory_blob_service(); "memory")]
#[test_case(gen_sled_blob_service(); "sled")]
fn not_found_read(blob_service: impl BlobService) {
assert!(blob_service
.open_read(&fixtures::BLOB_A_DIGEST)
.expect("must not fail")
.is_none())
tokio::runtime::Runtime::new().unwrap().block_on(async {
assert!(blob_service
.open_read(&fixtures::BLOB_A_DIGEST)
.await
.expect("must not fail")
.is_none())
})
}
/// Put a blob in the store, check has, get it back.
@ -46,165 +55,192 @@ fn not_found_read(blob_service: impl BlobService) {
#[test_case(gen_memory_blob_service(), &fixtures::BLOB_B, &fixtures::BLOB_B_DIGEST; "memory-big")]
#[test_case(gen_sled_blob_service(), &fixtures::BLOB_B, &fixtures::BLOB_B_DIGEST; "sled-big")]
fn put_has_get(blob_service: impl BlobService, blob_contents: &[u8], blob_digest: &B3Digest) {
let mut w = blob_service.open_write();
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut w = blob_service.open_write().await;
let l = io::copy(&mut io::Cursor::new(blob_contents), &mut w).expect("copy must succeed");
assert_eq!(
blob_contents.len(),
l as usize,
"written bytes must match blob length"
);
let l = tokio::io::copy(&mut io::Cursor::new(blob_contents), &mut w)
.await
.expect("copy must succeed");
assert_eq!(
blob_contents.len(),
l as usize,
"written bytes must match blob length"
);
let digest = w.close().expect("close must succeed");
let digest = w.close().await.expect("close must succeed");
assert_eq!(*blob_digest, digest, "returned digest must be correct");
assert_eq!(*blob_digest, digest, "returned digest must be correct");
assert!(
blob_service.has(blob_digest).expect("must not fail"),
"blob service should now have the blob"
);
assert!(
blob_service.has(blob_digest).await.expect("must not fail"),
"blob service should now have the blob"
);
let mut r = blob_service
.open_read(blob_digest)
.expect("open_read must succeed")
.expect("must be some");
let mut r = blob_service
.open_read(blob_digest)
.await
.expect("open_read must succeed")
.expect("must be some");
let mut buf: Vec<u8> = Vec::new();
let l = io::copy(&mut r, &mut buf).expect("copy must succeed");
let mut buf: Vec<u8> = Vec::new();
let mut pinned_reader = pin!(r);
let l = tokio::io::copy(&mut pinned_reader, &mut buf)
.await
.expect("copy must succeed");
// let l = io::copy(&mut r, &mut buf).expect("copy must succeed");
assert_eq!(
blob_contents.len(),
l as usize,
"read bytes must match blob length"
);
assert_eq!(
blob_contents.len(),
l as usize,
"read bytes must match blob length"
);
assert_eq!(blob_contents, buf, "read blob contents must match");
assert_eq!(blob_contents, buf, "read blob contents must match");
})
}
/// Put a blob in the store, and seek inside it a bit.
#[test_case(gen_memory_blob_service(); "memory")]
#[test_case(gen_sled_blob_service(); "sled")]
fn put_seek(blob_service: impl BlobService) {
let mut w = blob_service.open_write();
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut w = blob_service.open_write().await;
io::copy(&mut io::Cursor::new(&fixtures::BLOB_B.to_vec()), &mut w).expect("copy must succeed");
w.close().expect("close must succeed");
tokio::io::copy(&mut io::Cursor::new(&fixtures::BLOB_B.to_vec()), &mut w)
.await
.expect("copy must succeed");
w.close().await.expect("close must succeed");
// open a blob for reading
let mut r = blob_service
.open_read(&fixtures::BLOB_B_DIGEST)
.expect("open_read must succeed")
.expect("must be some");
// open a blob for reading
let mut r = blob_service
.open_read(&fixtures::BLOB_B_DIGEST)
.await
.expect("open_read must succeed")
.expect("must be some");
let mut pos: u64 = 0;
let mut pos: u64 = 0;
// read the first 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected first 10 bytes to match"
);
pos += buf.len() as u64;
}
// seek by 0 bytes, using SeekFrom::Start.
let p = r.seek(io::SeekFrom::Start(pos)).expect("must not fail");
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
pos += buf.len() as u64;
}
// seek by 5 bytes, using SeekFrom::Start.
let p = r.seek(io::SeekFrom::Start(pos + 5)).expect("must not fail");
pos += 5;
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
pos += buf.len() as u64;
}
// seek by 12345 bytes, using SeekFrom::Current.
let p = r.seek(io::SeekFrom::Current(12345)).expect("must not fail");
pos += 12345;
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
#[allow(unused_assignments)]
// read the first 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).await.expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected first 10 bytes to match"
);
pos += buf.len() as u64;
}
}
// seek by 0 bytes, using SeekFrom::Start.
let p = r
.seek(io::SeekFrom::Start(pos))
.await
.expect("must not fail");
assert_eq!(pos, p);
// seeking to the end is okay…
let p = r
.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64))
.expect("must not fail");
pos = fixtures::BLOB_B.len() as u64;
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).await.expect("must succeed");
{
// but it returns no more data.
let mut buf: Vec<u8> = Vec::new();
r.read_to_end(&mut buf).expect("must not fail");
assert!(buf.is_empty(), "expected no more data to be read");
}
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
// seeking past the end…
match r.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64 + 1)) {
// should either be ok, but then return 0 bytes.
// this matches the behaviour of a Cursor<Vec<u8>>.
Ok(_pos) => {
pos += buf.len() as u64;
}
// seek by 5 bytes, using SeekFrom::Start.
let p = r
.seek(io::SeekFrom::Start(pos + 5))
.await
.expect("must not fail");
pos += 5;
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).await.expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
pos += buf.len() as u64;
}
// seek by 12345 bytes, using SeekFrom::Current.
let p = r
.seek(io::SeekFrom::Current(12345))
.await
.expect("must not fail");
pos += 12345;
assert_eq!(pos, p);
// read the next 10 bytes, they must match the data in the fixture.
{
let mut buf = [0; 10];
r.read_exact(&mut buf).await.expect("must succeed");
assert_eq!(
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
buf,
"expected data to match"
);
#[allow(unused_assignments)]
{
pos += buf.len() as u64;
}
}
// seeking to the end is okay…
let p = r
.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64))
.await
.expect("must not fail");
pos = fixtures::BLOB_B.len() as u64;
assert_eq!(pos, p);
{
// but it returns no more data.
let mut buf: Vec<u8> = Vec::new();
r.read_to_end(&mut buf).expect("must not fail");
r.read_to_end(&mut buf).await.expect("must not fail");
assert!(buf.is_empty(), "expected no more data to be read");
}
// or not be okay.
Err(_) => {}
}
// TODO: this is only broken for the gRPC version
// We expect seeking backwards or relative to the end to fail.
// r.seek(io::SeekFrom::Current(-1))
// .expect_err("SeekFrom::Current(-1) expected to fail");
// seeking past the end…
match r
.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64 + 1))
.await
{
// should either be ok, but then return 0 bytes.
// this matches the behaviour of a Cursor<Vec<u8>>.
Ok(_pos) => {
let mut buf: Vec<u8> = Vec::new();
r.read_to_end(&mut buf).await.expect("must not fail");
assert!(buf.is_empty(), "expected no more data to be read");
}
// or not be okay.
Err(_) => {}
}
// r.seek(io::SeekFrom::Start(pos - 1))
// .expect_err("SeekFrom::Start(pos-1) expected to fail");
// TODO: this is only broken for the gRPC version
// We expect seeking backwards or relative to the end to fail.
// r.seek(io::SeekFrom::Current(-1))
// .expect_err("SeekFrom::Current(-1) expected to fail");
// r.seek(io::SeekFrom::End(0))
// .expect_err("SeekFrom::End(_) expected to fail");
// r.seek(io::SeekFrom::Start(pos - 1))
// .expect_err("SeekFrom::Start(pos-1) expected to fail");
// r.seek(io::SeekFrom::End(0))
// .expect_err("SeekFrom::End(_) expected to fail");
})
}