refactor(tvix): move castore into tvix-castore crate
This splits the pure content-addressed layers from tvix-store into a `castore` crate, and only leaves PathInfo related things, as well as the CLI entrypoint in the tvix-store crate. Notable changes: - `fixtures` and `utils` had to be moved out of the `test` cfg, so they can be imported from tvix-store. - Some ad-hoc fixtures in the test were moved to proper fixtures in the same step. - The protos are now created by a (more static) recipe in the protos/ directory. The (now two) golang targets are commented out, as it's not possible to update them properly in the same CL. This will be done by a followup CL once this is merged (and whitby deployed) Bug: https://b.tvl.fyi/issues/301 Change-Id: I8d675d4bf1fb697eb7d479747c1b1e3635718107 Reviewed-on: https://cl.tvl.fyi/c/depot/+/9370 Reviewed-by: tazjin <tazjin@tvl.su> Reviewed-by: flokli <flokli@flokli.de> Autosubmit: flokli <flokli@flokli.de> Tested-by: BuildkiteCI Reviewed-by: Connor Brewster <cbrewster@hey.com>
This commit is contained in:
parent
d8ef0cfb4a
commit
32f41458c0
89 changed files with 2308 additions and 1829 deletions
30
tvix/castore/src/blobservice/from_addr.rs
Normal file
30
tvix/castore/src/blobservice/from_addr.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use std::sync::Arc;
|
||||
use url::Url;
|
||||
|
||||
use super::{BlobService, GRPCBlobService, MemoryBlobService, SledBlobService};
|
||||
|
||||
/// Constructs a new instance of a [BlobService] from an URI.
|
||||
///
|
||||
/// The following schemes are supported by the following services:
|
||||
/// - `memory://` ([MemoryBlobService])
|
||||
/// - `sled://` ([SledBlobService])
|
||||
/// - `grpc+*://` ([GRPCBlobService])
|
||||
///
|
||||
/// See their `from_url` methods for more details about their syntax.
|
||||
pub fn from_addr(uri: &str) -> Result<Arc<dyn BlobService>, crate::Error> {
|
||||
let url = Url::parse(uri)
|
||||
.map_err(|e| crate::Error::StorageError(format!("unable to parse url: {}", e)))?;
|
||||
|
||||
Ok(if url.scheme() == "memory" {
|
||||
Arc::new(MemoryBlobService::from_url(&url)?)
|
||||
} else if url.scheme() == "sled" {
|
||||
Arc::new(SledBlobService::from_url(&url)?)
|
||||
} else if url.scheme().starts_with("grpc+") {
|
||||
Arc::new(GRPCBlobService::from_url(&url)?)
|
||||
} else {
|
||||
Err(crate::Error::StorageError(format!(
|
||||
"unknown scheme: {}",
|
||||
url.scheme()
|
||||
)))?
|
||||
})
|
||||
}
|
||||
426
tvix/castore/src/blobservice/grpc.rs
Normal file
426
tvix/castore/src/blobservice/grpc.rs
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
use super::{naive_seeker::NaiveSeeker, BlobReader, BlobService, BlobWriter};
|
||||
use crate::{proto, B3Digest};
|
||||
use futures::sink::SinkExt;
|
||||
use futures::TryFutureExt;
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
io::{self},
|
||||
pin::pin,
|
||||
task::Poll,
|
||||
};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::{net::UnixStream, task::JoinHandle};
|
||||
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
|
||||
use tokio_util::{
|
||||
io::{CopyToBytes, SinkWriter},
|
||||
sync::{PollSendError, PollSender},
|
||||
};
|
||||
use tonic::{async_trait, transport::Channel, Code, Status};
|
||||
use tracing::instrument;
|
||||
|
||||
/// Connects to a (remote) tvix-store BlobService over gRPC.
///
/// The type is [Clone]; it only holds the gRPC client handle.
|
||||
#[derive(Clone)]
|
||||
pub struct GRPCBlobService {
|
||||
/// The internal reference to a gRPC client.
|
||||
/// Cloning it is cheap, and it internally handles concurrent requests.
/// Every request method in this file clones it before issuing a call.
|
||||
grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
|
||||
}
|
||||
|
||||
impl GRPCBlobService {
|
||||
/// Constructs a [GRPCBlobService] from a [proto::blob_service_client::BlobServiceClient].
|
||||
/// The original note says this panics if called outside the context of a tokio runtime.
/// NOTE(review): this constructor only stores the client; confirm whether
/// that note refers to how the passed channel was created.
|
||||
pub fn from_client(
|
||||
grpc_client: proto::blob_service_client::BlobServiceClient<Channel>,
|
||||
) -> Self {
|
||||
Self { grpc_client }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BlobService for GRPCBlobService {
|
||||
/// Constructs a [GRPCBlobService] from the passed [url::Url]:
|
||||
/// - scheme has to match `grpc+*://`.
|
||||
/// That's normally grpc+unix for unix sockets, and grpc+http(s) for the HTTP counterparts.
|
||||
/// - In the case of unix sockets, there must be a path, but may not be a host.
|
||||
/// - In the case of non-unix sockets, there must be a host, but no path.
|
||||
fn from_url(url: &url::Url) -> Result<Self, crate::Error> {
|
||||
// Start checking for the scheme to start with grpc+.
|
||||
match url.scheme().strip_prefix("grpc+") {
|
||||
None => Err(crate::Error::StorageError("invalid scheme".to_string())),
|
||||
Some(rest) => {
|
||||
if rest == "unix" {
|
||||
if url.host_str().is_some() {
|
||||
return Err(crate::Error::StorageError(
|
||||
"host may not be set".to_string(),
|
||||
));
|
||||
}
|
||||
let path = url.path().to_string();
|
||||
let channel = tonic::transport::Endpoint::try_from("http://[::]:50051") // doesn't matter
|
||||
.unwrap()
|
||||
.connect_with_connector_lazy(tower::service_fn(
|
||||
move |_: tonic::transport::Uri| UnixStream::connect(path.clone()),
|
||||
));
|
||||
let grpc_client = proto::blob_service_client::BlobServiceClient::new(channel);
|
||||
Ok(Self::from_client(grpc_client))
|
||||
} else {
|
||||
// ensure path is empty, not supported with gRPC.
|
||||
if !url.path().is_empty() {
|
||||
return Err(crate::Error::StorageError(
|
||||
"path may not be set".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// clone the uri, and drop the grpc+ from the scheme.
|
||||
// Recreate a new uri with the `grpc+` prefix dropped from the scheme.
|
||||
// We can't use `url.set_scheme(rest)`, as it disallows
|
||||
// setting something http(s) that previously wasn't.
|
||||
let url = {
|
||||
let url_str = url.to_string();
|
||||
let s_stripped = url_str.strip_prefix("grpc+").unwrap();
|
||||
url::Url::parse(s_stripped).unwrap()
|
||||
};
|
||||
let channel = tonic::transport::Endpoint::try_from(url.to_string())
|
||||
.unwrap()
|
||||
.connect_lazy();
|
||||
|
||||
let grpc_client = proto::blob_service_client::BlobServiceClient::new(channel);
|
||||
Ok(Self::from_client(grpc_client))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self, digest), fields(blob.digest=%digest))]
|
||||
async fn has(&self, digest: &B3Digest) -> Result<bool, crate::Error> {
|
||||
let mut grpc_client = self.grpc_client.clone();
|
||||
let resp = grpc_client
|
||||
.stat(proto::StatBlobRequest {
|
||||
digest: digest.clone().into(),
|
||||
})
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(_blob_meta) => Ok(true),
|
||||
Err(e) if e.code() == Code::NotFound => Ok(false),
|
||||
Err(e) => Err(crate::Error::StorageError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
// On success, this returns a Ok(Some(io::Read)), which can be used to read
|
||||
// the contents of the Blob, identified by the digest.
|
||||
async fn open_read(
|
||||
&self,
|
||||
digest: &B3Digest,
|
||||
) -> Result<Option<Box<dyn BlobReader>>, crate::Error> {
|
||||
// Get a new handle to the gRPC client, and copy the digest.
|
||||
let mut grpc_client = self.grpc_client.clone();
|
||||
|
||||
// Get a stream of [proto::BlobChunk], or return an error if the blob
|
||||
// doesn't exist.
|
||||
let resp = grpc_client
|
||||
.read(proto::ReadBlobRequest {
|
||||
digest: digest.clone().into(),
|
||||
})
|
||||
.await;
|
||||
|
||||
// This runs the task to completion, which on success will return a stream.
|
||||
// On reading from it, we receive individual [proto::BlobChunk], so we
|
||||
// massage this to a stream of bytes,
|
||||
// then create an [AsyncRead], which we'll turn into a [io::Read],
|
||||
// that's returned from the function.
|
||||
match resp {
|
||||
Ok(stream) => {
|
||||
// map the stream of proto::BlobChunk to bytes.
|
||||
let data_stream = stream.into_inner().map(|x| {
|
||||
x.map(|x| VecDeque::from(x.data.to_vec()))
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
|
||||
});
|
||||
|
||||
// Use StreamReader::new to convert to an AsyncRead.
|
||||
let data_reader = tokio_util::io::StreamReader::new(data_stream);
|
||||
|
||||
Ok(Some(Box::new(NaiveSeeker::new(data_reader))))
|
||||
}
|
||||
Err(e) if e.code() == Code::NotFound => Ok(None),
|
||||
Err(e) => Err(crate::Error::StorageError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a BlobWriter, that'll internally wrap each write in a
|
||||
// [proto::BlobChunk], which is send to the gRPC server.
|
||||
async fn open_write(&self) -> Box<dyn BlobWriter> {
|
||||
let mut grpc_client = self.grpc_client.clone();
|
||||
|
||||
// set up an mpsc channel passing around Bytes.
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<bytes::Bytes>(10);
|
||||
|
||||
// bytes arriving on the RX side are wrapped inside a
|
||||
// [proto::BlobChunk], and a [ReceiverStream] is constructed.
|
||||
let blobchunk_stream = ReceiverStream::new(rx).map(|x| proto::BlobChunk { data: x });
|
||||
|
||||
// That receiver stream is used as a stream in the gRPC BlobService.put rpc call.
|
||||
let task: JoinHandle<Result<_, Status>> =
|
||||
tokio::spawn(async move { Ok(grpc_client.put(blobchunk_stream).await?.into_inner()) });
|
||||
|
||||
// The tx part of the channel is converted to a sink of byte chunks.
|
||||
|
||||
// We need to make this a function pointer, not a closure.
|
||||
fn convert_error(_: PollSendError<bytes::Bytes>) -> io::Error {
|
||||
io::Error::from(io::ErrorKind::BrokenPipe)
|
||||
}
|
||||
|
||||
let sink = PollSender::new(tx)
|
||||
.sink_map_err(convert_error as fn(PollSendError<bytes::Bytes>) -> io::Error);
|
||||
// We need to explicitly cast here, otherwise rustc does error with "expected fn pointer, found fn item"
|
||||
|
||||
// … which is turned into an [tokio::io::AsyncWrite].
|
||||
let writer = SinkWriter::new(CopyToBytes::new(sink));
|
||||
|
||||
Box::new(GRPCBlobWriter {
|
||||
task_and_writer: Some((task, writer)),
|
||||
digest: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer half of a blob upload over gRPC.
///
/// It pairs the spawned task running the `put` RPC with the writer feeding
/// that RPC, and remembers the digest once the upload has been closed.
pub struct GRPCBlobWriter<W: tokio::io::AsyncWrite> {
|
||||
/// The task containing the put request, and the inner writer, if we're still writing.
/// This is `None` once `close` has been called.
|
||||
task_and_writer: Option<(JoinHandle<Result<proto::PutBlobResponse, Status>>, W)>,
|
||||
|
||||
/// The digest that has been returned, if we successfully closed.
|
||||
digest: Option<B3Digest>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<W: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static> BlobWriter for GRPCBlobWriter<W> {
|
||||
async fn close(&mut self) -> Result<B3Digest, crate::Error> {
|
||||
if self.task_and_writer.is_none() {
|
||||
// if we're already closed, return the b3 digest, which must exist.
|
||||
// If it doesn't, we already closed and failed once, and didn't handle the error.
|
||||
match &self.digest {
|
||||
Some(digest) => Ok(digest.clone()),
|
||||
None => Err(crate::Error::StorageError(
|
||||
"previously closed with error".to_string(),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
let (task, mut writer) = self.task_and_writer.take().unwrap();
|
||||
|
||||
// invoke shutdown, so the inner writer closes its internal tx side of
|
||||
// the channel.
|
||||
writer
|
||||
.shutdown()
|
||||
.map_err(|e| crate::Error::StorageError(e.to_string()))
|
||||
.await?;
|
||||
|
||||
// block on the RPC call to return.
|
||||
// This ensures all chunks are sent out, and have been received by the
|
||||
// backend.
|
||||
|
||||
match task.await? {
|
||||
Ok(resp) => {
|
||||
// return the digest from the response, and store it in self.digest for subsequent closes.
|
||||
let digest: B3Digest = resp.digest.try_into().map_err(|_| {
|
||||
crate::Error::StorageError(
|
||||
"invalid root digest length in response".to_string(),
|
||||
)
|
||||
})?;
|
||||
self.digest = Some(digest.clone());
|
||||
Ok(digest)
|
||||
}
|
||||
Err(e) => Err(crate::Error::StorageError(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for GRPCBlobWriter<W> {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, io::Error>> {
|
||||
match &mut self.task_and_writer {
|
||||
None => Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
))),
|
||||
Some((_, ref mut writer)) => {
|
||||
let pinned_writer = pin!(writer);
|
||||
pinned_writer.poll_write(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
match &mut self.task_and_writer {
|
||||
None => Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
))),
|
||||
Some((_, ref mut writer)) => {
|
||||
let pinned_writer = pin!(writer);
|
||||
pinned_writer.poll_flush(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
// TODO(raitobezarius): this might not be a graceful shutdown of the
|
||||
// channel inside the gRPC connection.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;

    use tempfile::TempDir;
    use tokio::net::UnixListener;
    use tokio::time;
    use tokio_stream::wrappers::UnixListenerStream;

    use crate::blobservice::MemoryBlobService;
    use crate::fixtures;
    use crate::proto::GRPCBlobServiceWrapper;

    use super::BlobService;
    use super::GRPCBlobService;

    /// Parses `uri` and reports whether [GRPCBlobService::from_url] accepts it.
    fn accepts(uri: &str) -> bool {
        let url = url::Url::parse(uri).expect("must parse");
        GRPCBlobService::from_url(&url).is_ok()
    }

    /// A scheme without the `grpc+` prefix is rejected.
    #[test]
    fn test_invalid_scheme() {
        assert!(!accepts("http://foo.example/test"));
    }

    /// A unix socket URL with only a path is accepted.
    /// The path doesn't need to exist yet, because we connect lazily.
    #[tokio::test]
    async fn test_valid_unix_path() {
        assert!(accepts("grpc+unix:///path/to/somewhere"));
    }

    /// A unix socket URL that also sets a host is rejected.
    #[tokio::test]
    async fn test_invalid_unix_path_with_domain() {
        assert!(!accepts("grpc+unix://host.example/path/to/somewhere"));
    }

    /// A plain HTTP URL is accepted.
    /// Nothing needs to listen there, because we connect lazily.
    #[tokio::test]
    async fn test_valid_http() {
        assert!(accepts("grpc+http://localhost"));
    }

    /// An HTTPS URL is accepted.
    /// Nothing needs to listen there, because we connect lazily.
    #[tokio::test]
    async fn test_valid_https() {
        assert!(accepts("grpc+https://localhost"));
    }

    /// An HTTP(S) URL with an additional path is rejected, as paths are not
    /// supported for gRPC.
    #[tokio::test]
    async fn test_invalid_http_with_path() {
        assert!(!accepts("grpc+https://localhost/some-path"));
    }

    /// Talks to a real server over a unix socket.
    /// This is not a tokio::test, because we spawn two separate tokio
    /// runtimes and want to have explicit control.
    #[test]
    fn test_valid_unix_path_ping_pong() {
        let tmpdir = TempDir::new().unwrap();
        let socket_path = tmpdir.path().join("daemon");
        let server_socket_path = socket_path.clone();

        // Server side: a separate thread with its own tokio runtime, which
        // blocks on the serving task.
        thread::spawn(move || {
            let server_rt = tokio::runtime::Runtime::new().unwrap();

            let serve_task = server_rt.spawn(async {
                let listener = UnixListener::bind(server_socket_path).unwrap();
                let incoming = UnixListenerStream::new(listener);

                // spin up a new server, backed by an in-memory blob service
                let blob_service = GRPCBlobServiceWrapper::from(
                    Arc::new(MemoryBlobService::default()) as Arc<dyn BlobService>
                );
                let mut server = tonic::transport::Server::builder();
                let router = server.add_service(
                    crate::proto::blob_service_server::BlobServiceServer::new(blob_service),
                );
                router.serve_with_incoming(incoming).await
            });

            server_rt.block_on(serve_task).unwrap().unwrap();
        });

        // Client side: another tokio runtime, used by the main test code.
        let client_rt = tokio::runtime::Runtime::new().unwrap();

        let client_task = client_rt.spawn(async move {
            // wait for the socket to be created
            // TODO: exponential backoff urgently
            let mut socket_created = false;
            for _attempt in 1..20 {
                if socket_path.exists() {
                    socket_created = true;
                    break;
                }
                tokio::time::sleep(time::Duration::from_millis(20)).await;
            }
            assert!(
                socket_created,
                "expected socket path to eventually get created, but never happened"
            );

            // prepare a client pointing at the socket
            let mut url = url::Url::parse("grpc+unix:///path/to/somewhere").expect("must parse");
            url.set_path(socket_path.to_str().unwrap());
            let client = GRPCBlobService::from_url(&url).expect("must succeed");

            // the server starts out empty, so the blob must not be found
            let has = client
                .has(&fixtures::BLOB_A_DIGEST)
                .await
                .expect("must not be err");
            assert!(!has);
        });

        client_rt.block_on(client_task).unwrap()
    }
}
|
||||
196
tvix/castore/src/blobservice/memory.rs
Normal file
196
tvix/castore/src/blobservice/memory.rs
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
use std::io::{self, Cursor, Write};
|
||||
use std::task::Poll;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
use tonic::async_trait;
|
||||
use tracing::instrument;
|
||||
|
||||
use super::{BlobReader, BlobService, BlobWriter};
|
||||
use crate::{B3Digest, Error};
|
||||
|
||||
/// A [BlobService] keeping all blobs in an in-process [HashMap], keyed by
/// their digest. Clones share the same map through the [Arc].
#[derive(Clone, Default)]
|
||||
pub struct MemoryBlobService {
|
||||
// Maps a blob digest to the blob contents.
db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BlobService for MemoryBlobService {
|
||||
/// Constructs a [MemoryBlobService] from the passed [url::Url]:
|
||||
/// - scheme has to be `memory://`
|
||||
/// - there may not be a host.
|
||||
/// - there may not be a path.
|
||||
fn from_url(url: &url::Url) -> Result<Self, Error> {
|
||||
if url.scheme() != "memory" {
|
||||
return Err(crate::Error::StorageError("invalid scheme".to_string()));
|
||||
}
|
||||
|
||||
if url.has_host() || !url.path().is_empty() {
|
||||
return Err(crate::Error::StorageError("invalid url".to_string()));
|
||||
}
|
||||
|
||||
Ok(Self::default())
|
||||
}
|
||||
|
||||
#[instrument(skip(self, digest), fields(blob.digest=%digest))]
|
||||
async fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
|
||||
let db = self.db.read().unwrap();
|
||||
Ok(db.contains_key(digest))
|
||||
}
|
||||
|
||||
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
|
||||
let db = self.db.read().unwrap();
|
||||
|
||||
match db.get(digest).map(|x| Cursor::new(x.clone())) {
|
||||
Some(result) => Ok(Some(Box::new(result))),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn open_write(&self) -> Box<dyn BlobWriter> {
|
||||
Box::new(MemoryBlobWriter::new(self.db.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Buffers written bytes in memory while hashing them, and inserts the
/// finished blob into the shared map when closed.
pub struct MemoryBlobWriter {
|
||||
// The map the finished blob gets inserted into.
db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>,
|
||||
|
||||
/// Contains the buffer Vec and hasher, or None if already closed
|
||||
writers: Option<(Vec<u8>, blake3::Hasher)>,
|
||||
|
||||
/// The digest that has been returned, if we successfully closed.
|
||||
digest: Option<B3Digest>,
|
||||
}
|
||||
|
||||
impl MemoryBlobWriter {
|
||||
/// Creates a writer with an empty buffer and a fresh hasher, which will
/// insert into `db` on close.
fn new(db: Arc<RwLock<HashMap<B3Digest, Vec<u8>>>>) -> Self {
|
||||
Self {
|
||||
db,
|
||||
writers: Some((Vec::new(), blake3::Hasher::new())),
|
||||
digest: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl tokio::io::AsyncWrite for MemoryBlobWriter {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
b: &[u8],
|
||||
) -> std::task::Poll<Result<usize, io::Error>> {
|
||||
Poll::Ready(match &mut self.writers {
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
)),
|
||||
Some((ref mut buf, ref mut hasher)) => {
|
||||
let bytes_written = buf.write(b)?;
|
||||
hasher.write(&b[..bytes_written])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(match self.writers {
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
)),
|
||||
Some(_) => Ok(()),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
// shutdown is "instantaneous", we only write to memory.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BlobWriter for MemoryBlobWriter {
|
||||
async fn close(&mut self) -> Result<B3Digest, Error> {
|
||||
if self.writers.is_none() {
|
||||
match &self.digest {
|
||||
Some(digest) => Ok(digest.clone()),
|
||||
None => Err(crate::Error::StorageError(
|
||||
"previously closed with error".to_string(),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
let (buf, hasher) = self.writers.take().unwrap();
|
||||
|
||||
// We know self.hasher is doing blake3 hashing, so this won't fail.
|
||||
let digest: B3Digest = hasher.finalize().as_bytes().into();
|
||||
|
||||
// Only insert if the blob doesn't already exist.
|
||||
let db = self.db.read()?;
|
||||
if !db.contains_key(&digest) {
|
||||
// drop the read lock, so we can open for writing.
|
||||
drop(db);
|
||||
|
||||
// open the database for writing.
|
||||
let mut db = self.db.write()?;
|
||||
|
||||
// and put buf in there. This will move buf out.
|
||||
db.insert(digest.clone(), buf);
|
||||
}
|
||||
|
||||
self.digest = Some(digest.clone());
|
||||
|
||||
Ok(digest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::BlobService;
    use super::MemoryBlobService;

    /// Parses `uri` and reports whether [MemoryBlobService::from_url] accepts it.
    fn accepts(uri: &str) -> bool {
        let url = url::Url::parse(uri).expect("must parse");
        MemoryBlobService::from_url(&url).is_ok()
    }

    /// A scheme other than `memory` is rejected.
    #[test]
    fn test_invalid_scheme() {
        assert!(!accepts("http://foo.example/test"));
    }

    /// The bare scheme, without host or path, is accepted.
    #[test]
    fn test_valid_scheme() {
        assert!(accepts("memory://"));
    }

    /// Setting the host to `foo` is rejected.
    #[test]
    fn test_invalid_host() {
        assert!(!accepts("memory://foo"));
    }

    /// The path "/" is rejected.
    #[test]
    fn test_invalid_has_path() {
        assert!(!accepts("memory:///"));
    }

    /// The path "/foo" is rejected.
    #[test]
    fn test_invalid_path2() {
        assert!(!accepts("memory:///foo"));
    }
}
|
||||
62
tvix/castore/src/blobservice/mod.rs
Normal file
62
tvix/castore/src/blobservice/mod.rs
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
use std::io;
|
||||
use tonic::async_trait;
|
||||
|
||||
use crate::{B3Digest, Error};
|
||||
|
||||
mod from_addr;
|
||||
mod grpc;
|
||||
mod memory;
|
||||
mod naive_seeker;
|
||||
mod sled;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use self::from_addr::from_addr;
|
||||
pub use self::grpc::GRPCBlobService;
|
||||
pub use self::memory::MemoryBlobService;
|
||||
pub use self::sled::SledBlobService;
|
||||
|
||||
/// The base trait all BlobService services need to implement.
|
||||
/// It provides functions to check whether a given blob exists,
|
||||
/// a way to get a [BlobReader] to a blob, and a method to initiate writing a new
|
||||
/// Blob, which will return a [BlobWriter], i.e. something implementing
/// [tokio::io::AsyncWrite], and providing a
|
||||
/// close function, to finalize a blob and get its digest.
|
||||
#[async_trait]
|
||||
pub trait BlobService: Send + Sync {
|
||||
/// Create a new instance by passing in a connection URL.
|
||||
/// TODO: check if we want to make this async, instead of lazily connecting
|
||||
fn from_url(url: &url::Url) -> Result<Self, Error>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Check if the service has the blob, by its content hash.
|
||||
async fn has(&self, digest: &B3Digest) -> Result<bool, Error>;
|
||||
|
||||
/// Request a blob from the store, by its content hash.
/// The `Option` in the return type leaves room for a missing blob; the
/// implementations in this crate return `Ok(None)` in that case.
|
||||
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error>;
|
||||
|
||||
/// Insert a new blob into the store. Returns a [BlobWriter], which
|
||||
/// implements [tokio::io::AsyncWrite] and a [BlobWriter::close].
|
||||
async fn open_write(&self) -> Box<dyn BlobWriter>;
|
||||
}
|
||||
|
||||
/// A [tokio::io::AsyncWrite] that you need to close() afterwards, and get back
|
||||
/// the digest of the written blob.
|
||||
#[async_trait]
|
||||
pub trait BlobWriter: tokio::io::AsyncWrite + Send + Sync + Unpin + 'static {
|
||||
/// Signal there's no more data to be written, and return the digest of the
|
||||
/// contents written.
|
||||
///
|
||||
/// Closing an already-closed BlobWriter is a no-op.
|
||||
async fn close(&mut self) -> Result<B3Digest, Error>;
|
||||
}
|
||||
|
||||
/// A [tokio::io::AsyncRead] that also allows seeking.
/// It additionally requires [tokio::io::AsyncBufRead], and has no methods of
/// its own.
|
||||
pub trait BlobReader:
|
||||
tokio::io::AsyncRead + tokio::io::AsyncSeek + tokio::io::AsyncBufRead + Send + Unpin + 'static
|
||||
{
|
||||
}
|
||||
|
||||
/// A [`io::Cursor<Vec<u8>>`] can be used as a BlobReader.
/// This is what lets in-memory blob contents be handed out as a reader.
|
||||
impl BlobReader for io::Cursor<Vec<u8>> {}
|
||||
269
tvix/castore/src/blobservice/naive_seeker.rs
Normal file
269
tvix/castore/src/blobservice/naive_seeker.rs
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
use super::BlobReader;
|
||||
use pin_project_lite::pin_project;
|
||||
use std::io;
|
||||
use std::task::Poll;
|
||||
use tokio::io::AsyncRead;
|
||||
use tracing::{debug, instrument};
|
||||
|
||||
pin_project! {
|
||||
/// This implements [tokio::io::AsyncSeek] for a [tokio::io::AsyncRead] by
|
||||
/// simply skipping over some bytes, keeping track of the position.
|
||||
/// It fails whenever you try to seek backwards.
|
||||
///
|
||||
/// ## Pinning concerns:
|
||||
///
|
||||
/// [NaiveSeeker] is itself pinned by callers, and we do not need to concern
|
||||
/// ourselves regarding that.
|
||||
///
|
||||
/// Though, its fields as per
|
||||
/// <https://doc.rust-lang.org/std/pin/#pinning-is-not-structural-for-field>
|
||||
/// can be pinned or unpinned.
|
||||
///
|
||||
/// So we need to go over each field and choose our policy carefully.
|
||||
///
|
||||
/// The obvious cases are the bookkeeping integers we keep in the structure,
|
||||
/// those are private and not shared to anyone, we never build a
|
||||
/// `Pin<&mut X>` out of them at any point, therefore, we can safely never
|
||||
/// mark them as pinned. Of course, it is expected that no developer here
|
||||
/// attempts to `pin!(self.pos)` to pin them because it makes no sense. If
|
||||
/// they have to become pinned, they should be marked `#[pin]` and we need
|
||||
/// to discuss it.
|
||||
///
|
||||
/// So the bookkeeping integers are in the right state with respect to their
|
||||
/// pinning status. The projection should offer direct access.
|
||||
///
|
||||
/// On the `r` field, i.e. a `BufReader<R>`, given that
|
||||
/// <https://docs.rs/tokio/latest/tokio/io/struct.BufReader.html#impl-Unpin-for-BufReader%3CR%3E>
|
||||
/// is available, even a `Pin<&mut BufReader<R>>` can be safely moved.
|
||||
///
|
||||
/// The only care we should have regards the internal reader itself, i.e.
|
||||
/// the `R` instance, see that Tokio decided to `#[pin]` it too:
|
||||
/// <https://docs.rs/tokio/latest/src/tokio/io/util/buf_reader.rs.html#29>
|
||||
///
|
||||
/// In general, there's no `Unpin` instance for `R: tokio::io::AsyncRead`
|
||||
/// (see <https://docs.rs/tokio/latest/tokio/io/trait.AsyncRead.html>).
|
||||
///
|
||||
/// Therefore, we could keep it unpinned and pin it in every call site
|
||||
/// whenever we need to call `poll_*` which can be confusing to the non-
|
||||
/// expert developer and we have a fair amount of situations where the
|
||||
/// [tokio::io::BufReader] instance is naked, i.e. in its `&mut BufReader<R>`
|
||||
/// form, this is annoying because it could lead to exposing the naked `R`
|
||||
/// internal instance somehow and would produce a risk of making it move
|
||||
/// unexpectedly.
|
||||
///
|
||||
/// We choose the path of the least resistance as we have no reason to have
|
||||
/// access to the raw `BufReader<R>` instance, we just `#[pin]` it too and
|
||||
/// enjoy its `poll_*` safe APIs and push the unpinning concerns to the
|
||||
/// internal implementations themselves, which studied the question longer
|
||||
/// than us.
|
||||
pub struct NaiveSeeker<R: tokio::io::AsyncRead> {
|
||||
#[pin]
|
||||
r: tokio::io::BufReader<R>,
|
||||
// Current position; advanced by `poll_read` and `consume`.
pos: u64,
|
||||
// Set by `start_seek` to the distance between the requested offset and `pos`.
bytes_to_skip: u64,
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: tokio::io::AsyncRead> NaiveSeeker<R> {
|
||||
/// Wraps `r` in a [tokio::io::BufReader], starting at position 0 with
/// nothing to skip.
pub fn new(r: R) -> Self {
|
||||
NaiveSeeker {
|
||||
r: tokio::io::BufReader::new(r),
|
||||
pos: 0,
|
||||
bytes_to_skip: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: tokio::io::AsyncRead> tokio::io::AsyncRead for NaiveSeeker<R> {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
// The amount of data read can be determined by the increase
|
||||
// in the length of the slice returned by `ReadBuf::filled`.
|
||||
let filled_before = buf.filled().len();
|
||||
let this = self.project();
|
||||
let pos: &mut u64 = this.pos;
|
||||
|
||||
match this.r.poll_read(cx, buf) {
|
||||
Poll::Ready(a) => {
|
||||
let bytes_read = buf.filled().len() - filled_before;
|
||||
*pos += bytes_read as u64;
|
||||
|
||||
Poll::Ready(a)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: tokio::io::AsyncRead> tokio::io::AsyncBufRead for NaiveSeeker<R> {
|
||||
fn poll_fill_buf(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<io::Result<&[u8]>> {
|
||||
self.project().r.poll_fill_buf(cx)
|
||||
}
|
||||
|
||||
fn consume(self: std::pin::Pin<&mut Self>, amt: usize) {
|
||||
let this = self.project();
|
||||
this.r.consume(amt);
|
||||
let pos: &mut u64 = this.pos;
|
||||
*pos += amt as u64;
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: tokio::io::AsyncRead> tokio::io::AsyncSeek for NaiveSeeker<R> {
|
||||
#[instrument(skip(self))]
|
||||
fn start_seek(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
position: std::io::SeekFrom,
|
||||
) -> std::io::Result<()> {
|
||||
let absolute_offset: u64 = match position {
|
||||
io::SeekFrom::Start(start_offset) => {
|
||||
if start_offset < self.pos {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
format!("can't seek backwards ({} -> {})", self.pos, start_offset),
|
||||
));
|
||||
} else {
|
||||
start_offset
|
||||
}
|
||||
}
|
||||
// we don't know the total size, can't support this.
|
||||
io::SeekFrom::End(_end_offset) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"can't seek from end",
|
||||
));
|
||||
}
|
||||
io::SeekFrom::Current(relative_offset) => {
|
||||
if relative_offset < 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::Unsupported,
|
||||
"can't seek backwards relative to current position",
|
||||
));
|
||||
} else {
|
||||
self.pos + relative_offset as u64
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!(absolute_offset=?absolute_offset, "seek");
|
||||
|
||||
// we already know absolute_offset is larger than self.pos
|
||||
debug_assert!(
|
||||
absolute_offset >= self.pos,
|
||||
"absolute_offset {} is larger than self.pos {}",
|
||||
absolute_offset,
|
||||
self.pos
|
||||
);
|
||||
|
||||
// calculate bytes to skip
|
||||
*self.project().bytes_to_skip = absolute_offset - self.pos;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
fn poll_complete(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<std::io::Result<u64>> {
|
||||
if self.bytes_to_skip == 0 {
|
||||
// return the new position (from the start of the stream)
|
||||
return Poll::Ready(Ok(self.pos));
|
||||
}
|
||||
|
||||
// discard some bytes, until pos is where we want it to be.
|
||||
// We create a buffer that we'll discard later on.
|
||||
let mut buf = [0; 1024];
|
||||
|
||||
// Loop until we've reached the desired seek position. This is done by issuing repeated
|
||||
// `poll_read` calls. If the data is not available yet, we will yield back to the executor
|
||||
// and wait to be polled again.
|
||||
loop {
|
||||
// calculate the length we want to skip at most, which is either a max
|
||||
// buffer size, or the number of remaining bytes to read, whatever is
|
||||
// smaller.
|
||||
let bytes_to_skip = std::cmp::min(self.bytes_to_skip as usize, buf.len());
|
||||
|
||||
let mut read_buf = tokio::io::ReadBuf::new(&mut buf[..bytes_to_skip]);
|
||||
|
||||
match self.as_mut().poll_read(cx, &mut read_buf) {
|
||||
Poll::Ready(_a) => {
|
||||
let bytes_read = read_buf.filled().len() as u64;
|
||||
|
||||
if bytes_read == 0 {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::UnexpectedEof,
|
||||
format!(
|
||||
"tried to skip {} bytes, but only was able to skip {} until reaching EOF",
|
||||
bytes_to_skip, bytes_read
|
||||
),
|
||||
)));
|
||||
}
|
||||
|
||||
// calculate bytes to skip
|
||||
let bytes_to_skip = self.bytes_to_skip - bytes_read;
|
||||
|
||||
*self.as_mut().project().bytes_to_skip = bytes_to_skip;
|
||||
|
||||
if bytes_to_skip == 0 {
|
||||
return Poll::Ready(Ok(self.pos));
|
||||
}
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl (empty body): a NaiveSeeker can be used as a BlobReader, as
// long as the wrapped reader is `Send + Unpin + 'static`.
impl<R: tokio::io::AsyncRead + Send + Unpin + 'static> BlobReader for NaiveSeeker<R> {}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::NaiveSeeker;
    use std::io::{Cursor, SeekFrom};
    use tokio::io::{AsyncReadExt, AsyncSeekExt};

    /// Seeking by 4000 bytes needs more than one `poll_read`, as the seek
    /// discards data through a 1024 bytes internal buffer.
    /// This ensures we don't hang indefinitely.
    #[tokio::test]
    async fn seek() {
        let data = vec![0u8; 4096];
        let mut seeker = NaiveSeeker::new(Cursor::new(&data));

        seeker.seek(SeekFrom::Start(4000)).await.unwrap();
    }

    /// Interleaves reads and forward seeks, and checks the data read after
    /// each seek comes from the expected region.
    #[tokio::test]
    async fn seek_read() {
        // Three consecutive 2048 byte regions, filled with 0, 1 and 2.
        let mut data: Vec<u8> = Vec::with_capacity(3 * 2048);
        for fill in 0u8..3 {
            data.extend(std::iter::repeat(fill).take(2048));
        }

        let mut seeker = NaiveSeeker::new(Cursor::new(&data));
        let mut chunk = vec![0u8; 1024];

        // First half of the first region.
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[0u8; 1024]);

        // Skip the second half of the first region, landing in the second.
        seeker
            .seek(SeekFrom::Current(1024))
            .await
            .expect("must seek");
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[1u8; 1024]);

        // Jump to the start of the third region by absolute offset.
        seeker
            .seek(SeekFrom::Start(2 * 2048))
            .await
            .expect("must seek");
        seeker.read_exact(&mut chunk).await.expect("must read");
        assert_eq!(chunk.as_slice(), &[2u8; 1024]);
    }
}
|
||||
249
tvix/castore/src/blobservice/sled.rs
Normal file
249
tvix/castore/src/blobservice/sled.rs
Normal file
|
|
@ -0,0 +1,249 @@
|
|||
use super::{BlobReader, BlobService, BlobWriter};
|
||||
use crate::{B3Digest, Error};
|
||||
use std::{
|
||||
io::{self, Cursor, Write},
|
||||
path::PathBuf,
|
||||
task::Poll,
|
||||
};
|
||||
use tonic::async_trait;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SledBlobService {
|
||||
db: sled::Db,
|
||||
}
|
||||
|
||||
impl SledBlobService {
|
||||
pub fn new(p: PathBuf) -> Result<Self, sled::Error> {
|
||||
let config = sled::Config::default().use_compression(true).path(p);
|
||||
let db = config.open()?;
|
||||
|
||||
Ok(Self { db })
|
||||
}
|
||||
|
||||
pub fn new_temporary() -> Result<Self, sled::Error> {
|
||||
let config = sled::Config::default().temporary(true);
|
||||
let db = config.open()?;
|
||||
|
||||
Ok(Self { db })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BlobService for SledBlobService {
|
||||
/// Constructs a [SledBlobService] from the passed [url::Url]:
|
||||
/// - scheme has to be `sled://`
|
||||
/// - there may not be a host.
|
||||
/// - a path to the sled needs to be provided (which may not be `/`).
|
||||
fn from_url(url: &url::Url) -> Result<Self, Error> {
|
||||
if url.scheme() != "sled" {
|
||||
return Err(crate::Error::StorageError("invalid scheme".to_string()));
|
||||
}
|
||||
|
||||
if url.has_host() {
|
||||
return Err(crate::Error::StorageError(format!(
|
||||
"invalid host: {}",
|
||||
url.host().unwrap()
|
||||
)));
|
||||
}
|
||||
|
||||
// TODO: expose compression and other parameters as URL parameters, drop new and new_temporary?
|
||||
if url.path().is_empty() {
|
||||
Self::new_temporary().map_err(|e| Error::StorageError(e.to_string()))
|
||||
} else if url.path() == "/" {
|
||||
Err(crate::Error::StorageError(
|
||||
"cowardly refusing to open / with sled".to_string(),
|
||||
))
|
||||
} else {
|
||||
Self::new(url.path().into()).map_err(|e| Error::StorageError(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self), fields(blob.digest=%digest))]
|
||||
async fn has(&self, digest: &B3Digest) -> Result<bool, Error> {
|
||||
match self.db.contains_key(digest.to_vec()) {
|
||||
Ok(has) => Ok(has),
|
||||
Err(e) => Err(Error::StorageError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self), fields(blob.digest=%digest))]
|
||||
async fn open_read(&self, digest: &B3Digest) -> Result<Option<Box<dyn BlobReader>>, Error> {
|
||||
match self.db.get(digest.to_vec()) {
|
||||
Ok(None) => Ok(None),
|
||||
Ok(Some(data)) => Ok(Some(Box::new(Cursor::new(data[..].to_vec())))),
|
||||
Err(e) => Err(Error::StorageError(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(self))]
|
||||
async fn open_write(&self) -> Box<dyn BlobWriter> {
|
||||
Box::new(SledBlobWriter::new(self.db.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SledBlobWriter {
|
||||
db: sled::Db,
|
||||
|
||||
/// Contains the buffer Vec and hasher, or None if already closed
|
||||
writers: Option<(Vec<u8>, blake3::Hasher)>,
|
||||
|
||||
/// The digest that has been returned, if we successfully closed.
|
||||
digest: Option<B3Digest>,
|
||||
}
|
||||
|
||||
impl SledBlobWriter {
|
||||
pub fn new(db: sled::Db) -> Self {
|
||||
Self {
|
||||
db,
|
||||
writers: Some((Vec::new(), blake3::Hasher::new())),
|
||||
digest: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl tokio::io::AsyncWrite for SledBlobWriter {
|
||||
fn poll_write(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
b: &[u8],
|
||||
) -> std::task::Poll<Result<usize, io::Error>> {
|
||||
Poll::Ready(match &mut self.writers {
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
)),
|
||||
Some((ref mut buf, ref mut hasher)) => {
|
||||
let bytes_written = buf.write(b)?;
|
||||
hasher.write(&b[..bytes_written])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
mut self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
Poll::Ready(match &mut self.writers {
|
||||
None => Err(io::Error::new(
|
||||
io::ErrorKind::NotConnected,
|
||||
"already closed",
|
||||
)),
|
||||
Some(_) => Ok(()),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
_cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), io::Error>> {
|
||||
// shutdown is "instantaneous", we only write to a Vec<u8> as buffer.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl BlobWriter for SledBlobWriter {
|
||||
async fn close(&mut self) -> Result<B3Digest, Error> {
|
||||
if self.writers.is_none() {
|
||||
match &self.digest {
|
||||
Some(digest) => Ok(digest.clone()),
|
||||
None => Err(crate::Error::StorageError(
|
||||
"previously closed with error".to_string(),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
let (buf, hasher) = self.writers.take().unwrap();
|
||||
|
||||
let digest: B3Digest = hasher.finalize().as_bytes().into();
|
||||
|
||||
// Only insert if the blob doesn't already exist.
|
||||
if !self.db.contains_key(digest.to_vec()).map_err(|e| {
|
||||
Error::StorageError(format!("Unable to check if we have blob {}: {}", digest, e))
|
||||
})? {
|
||||
// put buf in there. This will move buf out.
|
||||
self.db
|
||||
.insert(digest.to_vec(), buf)
|
||||
.map_err(|e| Error::StorageError(format!("unable to insert blob: {}", e)))?;
|
||||
}
|
||||
|
||||
self.digest = Some(digest.clone());
|
||||
|
||||
Ok(digest)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::BlobService;
    use super::SledBlobService;

    /// This uses a wrong scheme.
    #[test]
    fn test_invalid_scheme() {
        let url = url::Url::parse("http://foo.example/test").expect("must parse");
        let res = SledBlobService::from_url(&url);

        assert!(res.is_err());
    }

    /// This uses the correct scheme, and doesn't specify a path (temporary sled).
    #[test]
    fn test_valid_scheme_temporary() {
        let url = url::Url::parse("sled://").expect("must parse");
        let res = SledBlobService::from_url(&url);

        assert!(res.is_ok());
    }

    /// This sets the path to a location that doesn't exist, which should fail (as sled doesn't mkdir -p)
    // NOTE(review): the URL below also carries the host `foo.example`, and
    // `from_url` rejects URLs with a host before looking at the path, so this
    // test would pass regardless of the path. Confirm whether that's intended.
    #[test]
    fn test_nonexistent_path() {
        let tmpdir = TempDir::new().unwrap();
        let missing = tmpdir.path().join("foo").join("bar");

        let mut url = url::Url::parse("sled://foo.example").expect("must parse");
        url.set_path(missing.to_str().unwrap());

        assert!(SledBlobService::from_url(&url).is_err());
    }

    /// This uses the correct scheme, and specifies / as path (which should fail
    /// for obvious reasons)
    #[test]
    fn test_invalid_path_root() {
        let url = url::Url::parse("sled:///").expect("must parse");
        let res = SledBlobService::from_url(&url);

        assert!(res.is_err());
    }

    /// This uses the correct scheme, and sets a tempdir as location.
    #[test]
    fn test_valid_scheme_path() {
        let tmpdir = TempDir::new().unwrap();

        let mut url = url::Url::parse("sled://").expect("must parse");
        url.set_path(tmpdir.path().to_str().unwrap());

        assert!(SledBlobService::from_url(&url).is_ok());
    }

    /// This sets a host, rather than a path, which should fail.
    #[test]
    fn test_invalid_host() {
        let url = url::Url::parse("sled://foo.example").expect("must parse");
        let res = SledBlobService::from_url(&url);

        assert!(res.is_err());
    }

    /// This sets a host AND a valid path, which should fail
    #[test]
    fn test_invalid_host_and_path() {
        let tmpdir = TempDir::new().unwrap();

        let mut url = url::Url::parse("sled://foo.example").expect("must parse");
        url.set_path(tmpdir.path().to_str().unwrap());

        assert!(SledBlobService::from_url(&url).is_err());
    }
}
|
||||
246
tvix/castore/src/blobservice/tests.rs
Normal file
246
tvix/castore/src/blobservice/tests.rs
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
use std::io;
|
||||
use std::pin::pin;
|
||||
|
||||
use test_case::test_case;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncSeekExt;
|
||||
|
||||
use super::B3Digest;
|
||||
use super::BlobService;
|
||||
use super::MemoryBlobService;
|
||||
use super::SledBlobService;
|
||||
use crate::fixtures;
|
||||
|
||||
// TODO: avoid having to define all different services we test against for all functions.
|
||||
// maybe something like rstest can be used?
|
||||
|
||||
fn gen_memory_blob_service() -> impl BlobService {
|
||||
MemoryBlobService::default()
|
||||
}
|
||||
fn gen_sled_blob_service() -> impl BlobService {
|
||||
SledBlobService::new_temporary().unwrap()
|
||||
}
|
||||
|
||||
// TODO: add GRPC blob service here.
|
||||
|
||||
/// Using [BlobService::has] on a non-existing blob should return false
|
||||
#[test_case(gen_memory_blob_service(); "memory")]
|
||||
#[test_case(gen_sled_blob_service(); "sled")]
|
||||
fn has_nonexistent_false(blob_service: impl BlobService) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
assert!(!blob_service
|
||||
.has(&fixtures::BLOB_A_DIGEST)
|
||||
.await
|
||||
.expect("must not fail"));
|
||||
})
|
||||
}
|
||||
|
||||
/// Trying to read a non-existing blob should return a None instead of a reader.
|
||||
#[test_case(gen_memory_blob_service(); "memory")]
|
||||
#[test_case(gen_sled_blob_service(); "sled")]
|
||||
fn not_found_read(blob_service: impl BlobService) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
assert!(blob_service
|
||||
.open_read(&fixtures::BLOB_A_DIGEST)
|
||||
.await
|
||||
.expect("must not fail")
|
||||
.is_none())
|
||||
})
|
||||
}
|
||||
|
||||
/// Put a blob in the store, check has, get it back.
|
||||
/// We test both with small and big blobs.
|
||||
#[test_case(gen_memory_blob_service(), &fixtures::BLOB_A, &fixtures::BLOB_A_DIGEST; "memory-small")]
|
||||
#[test_case(gen_sled_blob_service(), &fixtures::BLOB_A, &fixtures::BLOB_A_DIGEST; "sled-small")]
|
||||
#[test_case(gen_memory_blob_service(), &fixtures::BLOB_B, &fixtures::BLOB_B_DIGEST; "memory-big")]
|
||||
#[test_case(gen_sled_blob_service(), &fixtures::BLOB_B, &fixtures::BLOB_B_DIGEST; "sled-big")]
|
||||
fn put_has_get(blob_service: impl BlobService, blob_contents: &[u8], blob_digest: &B3Digest) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let mut w = blob_service.open_write().await;
|
||||
|
||||
let l = tokio::io::copy(&mut io::Cursor::new(blob_contents), &mut w)
|
||||
.await
|
||||
.expect("copy must succeed");
|
||||
assert_eq!(
|
||||
blob_contents.len(),
|
||||
l as usize,
|
||||
"written bytes must match blob length"
|
||||
);
|
||||
|
||||
let digest = w.close().await.expect("close must succeed");
|
||||
|
||||
assert_eq!(*blob_digest, digest, "returned digest must be correct");
|
||||
|
||||
assert!(
|
||||
blob_service.has(blob_digest).await.expect("must not fail"),
|
||||
"blob service should now have the blob"
|
||||
);
|
||||
|
||||
let mut r = blob_service
|
||||
.open_read(blob_digest)
|
||||
.await
|
||||
.expect("open_read must succeed")
|
||||
.expect("must be some");
|
||||
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
let mut pinned_reader = pin!(r);
|
||||
let l = tokio::io::copy(&mut pinned_reader, &mut buf)
|
||||
.await
|
||||
.expect("copy must succeed");
|
||||
// let l = io::copy(&mut r, &mut buf).expect("copy must succeed");
|
||||
|
||||
assert_eq!(
|
||||
blob_contents.len(),
|
||||
l as usize,
|
||||
"read bytes must match blob length"
|
||||
);
|
||||
|
||||
assert_eq!(blob_contents, buf, "read blob contents must match");
|
||||
})
|
||||
}
|
||||
|
||||
/// Put a blob in the store, and seek inside it a bit.
|
||||
#[test_case(gen_memory_blob_service(); "memory")]
|
||||
#[test_case(gen_sled_blob_service(); "sled")]
|
||||
fn put_seek(blob_service: impl BlobService) {
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let mut w = blob_service.open_write().await;
|
||||
|
||||
tokio::io::copy(&mut io::Cursor::new(&fixtures::BLOB_B.to_vec()), &mut w)
|
||||
.await
|
||||
.expect("copy must succeed");
|
||||
w.close().await.expect("close must succeed");
|
||||
|
||||
// open a blob for reading
|
||||
let mut r = blob_service
|
||||
.open_read(&fixtures::BLOB_B_DIGEST)
|
||||
.await
|
||||
.expect("open_read must succeed")
|
||||
.expect("must be some");
|
||||
|
||||
let mut pos: u64 = 0;
|
||||
|
||||
// read the first 10 bytes, they must match the data in the fixture.
|
||||
{
|
||||
let mut buf = [0; 10];
|
||||
r.read_exact(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(
|
||||
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
|
||||
buf,
|
||||
"expected first 10 bytes to match"
|
||||
);
|
||||
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
// seek by 0 bytes, using SeekFrom::Start.
|
||||
let p = r
|
||||
.seek(io::SeekFrom::Start(pos))
|
||||
.await
|
||||
.expect("must not fail");
|
||||
assert_eq!(pos, p);
|
||||
|
||||
// read the next 10 bytes, they must match the data in the fixture.
|
||||
{
|
||||
let mut buf = [0; 10];
|
||||
r.read_exact(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(
|
||||
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
|
||||
buf,
|
||||
"expected data to match"
|
||||
);
|
||||
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
|
||||
// seek by 5 bytes, using SeekFrom::Start.
|
||||
let p = r
|
||||
.seek(io::SeekFrom::Start(pos + 5))
|
||||
.await
|
||||
.expect("must not fail");
|
||||
pos += 5;
|
||||
assert_eq!(pos, p);
|
||||
|
||||
// read the next 10 bytes, they must match the data in the fixture.
|
||||
{
|
||||
let mut buf = [0; 10];
|
||||
r.read_exact(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(
|
||||
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
|
||||
buf,
|
||||
"expected data to match"
|
||||
);
|
||||
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
|
||||
// seek by 12345 bytes, using SeekFrom::
|
||||
let p = r
|
||||
.seek(io::SeekFrom::Current(12345))
|
||||
.await
|
||||
.expect("must not fail");
|
||||
pos += 12345;
|
||||
assert_eq!(pos, p);
|
||||
|
||||
// read the next 10 bytes, they must match the data in the fixture.
|
||||
{
|
||||
let mut buf = [0; 10];
|
||||
r.read_exact(&mut buf).await.expect("must succeed");
|
||||
|
||||
assert_eq!(
|
||||
&fixtures::BLOB_B[pos as usize..pos as usize + buf.len()],
|
||||
buf,
|
||||
"expected data to match"
|
||||
);
|
||||
|
||||
#[allow(unused_assignments)]
|
||||
{
|
||||
pos += buf.len() as u64;
|
||||
}
|
||||
}
|
||||
|
||||
// seeking to the end is okay…
|
||||
let p = r
|
||||
.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64))
|
||||
.await
|
||||
.expect("must not fail");
|
||||
pos = fixtures::BLOB_B.len() as u64;
|
||||
assert_eq!(pos, p);
|
||||
|
||||
{
|
||||
// but it returns no more data.
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
r.read_to_end(&mut buf).await.expect("must not fail");
|
||||
assert!(buf.is_empty(), "expected no more data to be read");
|
||||
}
|
||||
|
||||
// seeking past the end…
|
||||
match r
|
||||
.seek(io::SeekFrom::Start(fixtures::BLOB_B.len() as u64 + 1))
|
||||
.await
|
||||
{
|
||||
// should either be ok, but then return 0 bytes.
|
||||
// this matches the behaviour or a Cursor<Vec<u8>>.
|
||||
Ok(_pos) => {
|
||||
let mut buf: Vec<u8> = Vec::new();
|
||||
r.read_to_end(&mut buf).await.expect("must not fail");
|
||||
assert!(buf.is_empty(), "expected no more data to be read");
|
||||
}
|
||||
// or not be okay.
|
||||
Err(_) => {}
|
||||
}
|
||||
|
||||
// TODO: this is only broken for the gRPC version
|
||||
// We expect seeking backwards or relative to the end to fail.
|
||||
// r.seek(io::SeekFrom::Current(-1))
|
||||
// .expect_err("SeekFrom::Current(-1) expected to fail");
|
||||
|
||||
// r.seek(io::SeekFrom::Start(pos - 1))
|
||||
// .expect_err("SeekFrom::Start(pos-1) expected to fail");
|
||||
|
||||
// r.seek(io::SeekFrom::End(0))
|
||||
// .expect_err("SeekFrom::End(_) expected to fail");
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue