// SPDX-FileCopyrightText: 2023 Matteo Settenvini // SPDX-License-Identifier: AGPL-3.0-or-later use tonic::transport::server::TcpIncoming; mod pb { tonic::include_proto!("package"); } use { futures::{Stream, StreamExt as _}, pb::service_server::{Service, ServiceServer}, pb::Message, std::cell::RefCell, std::error::Error, std::pin::Pin, std::sync::{Arc, Mutex, Weak}, tokio::sync::broadcast, tokio_stream::wrappers::BroadcastStream, tonic::{Request, Response, Status, Streaming}, }; const MESSAGE_QUEUE_SIZE: usize = 20; const RENDEZVOUS: &'static str = "[::1]:10000"; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let url = RENDEZVOUS.parse()?; let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) .expect("Cannot bind server socket"); run_server(incoming).await?; Ok(()) } fn run_server( incoming: TcpIncoming, ) -> impl std::future::Future> { use tonic::transport::Server; let service = MessageService::default(); let svc_adapter = ServiceServer::new(service); let builder = Server::builder().add_service(svc_adapter); builder.serve_with_incoming(incoming) } #[derive(Default, Debug)] struct MessageService { shared_tx: Mutex>>>, } impl MessageService { fn shared_channel( &self, ) -> ( Arc>, broadcast::Receiver, ) { let tx_guard = self.shared_tx.lock().unwrap(); let maybe_tx = tx_guard.borrow().upgrade(); match maybe_tx { Some(tx) => { let rx = tx.subscribe(); (tx, rx) } None => { let (tx, rx) = broadcast::channel(MESSAGE_QUEUE_SIZE); let tx = Arc::new(tx); *tx_guard.borrow_mut() = Arc::downgrade(&tx); (tx, rx) } } } } #[tonic::async_trait] impl Service for MessageService { type BroadcasterStream = Pin> + Send + 'static>>; async fn broadcaster( &self, requests: Request>, ) -> Result, Status> { let mut incoming = requests.into_inner(); let (tx, rx) = self.shared_channel(); let output = BroadcastStream::new(rx) .map(|result| result.map_err(|err| Status::data_loss(err.to_string()))); tokio::spawn(async move { while let Some(Ok(message)) = incoming.next().await { let _ = tx.send(message); // Ignore err if no receivers } }); Ok(Response::new(Box::pin(output))) } } #[cfg(test)] mod test { use { super::pb::service_client::ServiceClient, super::run_server, super::{Message, RENDEZVOUS}, mockall::automock, tonic::Request, }; #[automock] trait MessageChecker { fn check_contents(&self, msg: &Message); } #[tokio::test] async fn bidi_streaming() -> Result<(), Box> { let url = RENDEZVOUS.parse()?; let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) .expect("Cannot bind server socket"); tokio::spawn(run_server(incoming)); const N_MESSAGES: usize = 20; const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"]; async fn client(id: &str, mock: &MockMessageChecker) { let addr = format!("http://{}", RENDEZVOUS); let mut client = ServiceClient::connect(addr).await.unwrap(); let client_id = String::from(id); let outbound = async_stream::stream! { for i in 1..(N_MESSAGES+1) { tokio::task::yield_now().await; let message = Message { contents: format!("{}: {}", client_id, i), }; yield message; } }; let response = client.broadcaster(Request::new(outbound)).await.unwrap(); let mut inbound = response.into_inner(); while let Some(msg) = inbound.message().await.unwrap() { println!("{} received {:?}", id, msg); mock.check_contents(&msg); } } let mut mock = MockMessageChecker::new(); let mut clients = vec![]; for client_id in CLIENT_IDS.iter() { for i in 1..(N_MESSAGES + 1) { let expected = format!("{}: {}", client_id, i); mock.expect_check_contents() .times(CLIENT_IDS.len()) .withf(move |msg| msg.contents == expected) .return_const(()); } } for client_id in CLIENT_IDS { let c = client(client_id, &mock); clients.push(c); } let _ = futures::future::join_all(clients).await; println!("Exiting."); Ok(()) } }