diff --git a/tvix/castore/src/directoryservice/grpc.rs b/tvix/castore/src/directoryservice/grpc.rs index 1d6ad2c13..c98708608 100644 --- a/tvix/castore/src/directoryservice/grpc.rs +++ b/tvix/castore/src/directoryservice/grpc.rs @@ -288,7 +288,7 @@ impl DirectoryPutter for GRPCPutter { mod tests { use core::time; use futures::StreamExt; - use std::{any::Any, sync::Arc, time::Duration}; + use std::{any::Any, time::Duration}; use tempfile::TempDir; use tokio::net::UnixListener; use tokio_retry::{strategy::ExponentialBackoff, Retry}; @@ -460,8 +460,8 @@ mod tests { let mut server = tonic::transport::Server::builder(); let router = server.add_service( crate::proto::directory_service_server::DirectoryServiceServer::new( - GRPCDirectoryServiceWrapper::from( - Arc::new(MemoryDirectoryService::default()) as Arc + GRPCDirectoryServiceWrapper::new( + Box::::default() as Box ), ), ); diff --git a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs index 097958050..b83048045 100644 --- a/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs +++ b/tvix/castore/src/proto/grpc_directoryservice_wrapper.rs @@ -2,26 +2,27 @@ use crate::proto; use crate::{directoryservice::DirectoryService, B3Digest}; use futures::StreamExt; use std::collections::HashMap; -use std::sync::Arc; -use tokio::{sync::mpsc::channel, task}; +use std::ops::Deref; +use tokio::sync::mpsc::channel; use tokio_stream::wrappers::ReceiverStream; use tonic::{async_trait, Request, Response, Status, Streaming}; use tracing::{debug, instrument, warn}; -pub struct GRPCDirectoryServiceWrapper { - directory_service: Arc, +pub struct GRPCDirectoryServiceWrapper { + directory_service: T, } -impl From> for GRPCDirectoryServiceWrapper { - fn from(value: Arc) -> Self { - Self { - directory_service: value, - } +impl GRPCDirectoryServiceWrapper { + pub fn new(directory_service: T) -> Self { + Self { directory_service } } } #[async_trait] -impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper { +impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceWrapper +where + T: Deref + Send + Sync + 'static, +{ type GetStream = ReceiverStream>; #[instrument(skip(self))] @@ -33,50 +34,43 @@ impl proto::directory_service_server::DirectoryService for GRPCDirectoryServiceW let req_inner = request.into_inner(); - let directory_service = self.directory_service.clone(); + // look at the digest in the request and put it in the top of the queue. + match &req_inner.by_what { + None => return Err(Status::invalid_argument("by_what needs to be specified")), + Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { + let digest: B3Digest = digest + .clone() + .try_into() + .map_err(|_e| Status::invalid_argument("invalid digest length"))?; - let _task = { - // look at the digest in the request and put it in the top of the queue. - match &req_inner.by_what { - None => return Err(Status::invalid_argument("by_what needs to be specified")), - Some(proto::get_directory_request::ByWhat::Digest(ref digest)) => { - let digest: B3Digest = digest - .clone() - .try_into() - .map_err(|_e| Status::invalid_argument("invalid digest length"))?; - - task::spawn(async move { - if !req_inner.recursive { - let e: Result = - match directory_service.get(&digest).await { - Ok(Some(directory)) => Ok(directory), - Ok(None) => Err(Status::not_found(format!( - "directory {} not found", - digest - ))), - Err(e) => Err(e.into()), - }; - - if tx.send(e).await.is_err() { - debug!("receiver dropped"); + if !req_inner.recursive { + let e: Result = + match self.directory_service.get(&digest).await { + Ok(Some(directory)) => Ok(directory), + Ok(None) => { + Err(Status::not_found(format!("directory {} not found", digest))) } - } else { - // If recursive was requested, traverse via get_recursive. - let mut directories_it = directory_service.get_recursive(&digest); + Err(e) => Err(e.into()), + }; - while let Some(e) = directories_it.next().await { - // map err in res from Error to Status - let res = e.map_err(|e| Status::internal(e.to_string())); - if tx.send(res).await.is_err() { - debug!("receiver dropped"); - break; - } - } + if tx.send(e).await.is_err() { + debug!("receiver dropped"); + } + } else { + // If recursive was requested, traverse via get_recursive. + let mut directories_it = self.directory_service.get_recursive(&digest); + + while let Some(e) = directories_it.next().await { + // map err in res from Error to Status + let res = e.map_err(|e| Status::internal(e.to_string())); + if tx.send(res).await.is_err() { + debug!("receiver dropped"); + break; } - }); + } } } - }; + } let receiver_stream = ReceiverStream::new(rx); Ok(Response::new(receiver_stream)) diff --git a/tvix/castore/src/utils.rs b/tvix/castore/src/utils.rs index 1b0d4c674..b24627ed9 100644 --- a/tvix/castore/src/utils.rs +++ b/tvix/castore/src/utils.rs @@ -35,7 +35,7 @@ pub(crate) async fn gen_directorysvc_grpc_client() -> DirectoryServiceClient Result<(), Box> { blob_service, ))) .add_service(DirectoryServiceServer::new( - GRPCDirectoryServiceWrapper::from(directory_service), + GRPCDirectoryServiceWrapper::new(directory_service), )) .add_service(PathInfoServiceServer::new(GRPCPathInfoServiceWrapper::new( Arc::from(path_info_service),