From d3f23b79560d18c9e06b8e1a98a2ee5e10919ae8 Mon Sep 17 00:00:00 2001 From: Matteo Settenvini Date: Wed, 3 May 2023 07:29:46 +0200 Subject: [PATCH] Add some comments to the code --- grpc/broadcaster/src/main.rs | 46 ++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/grpc/broadcaster/src/main.rs b/grpc/broadcaster/src/main.rs index 0113bb9..6156b94 100644 --- a/grpc/broadcaster/src/main.rs +++ b/grpc/broadcaster/src/main.rs @@ -1,8 +1,6 @@ // SPDX-FileCopyrightText: 2023 Matteo Settenvini // SPDX-License-Identifier: AGPL-3.0-or-later -use tonic::transport::server::TcpIncoming; - mod pb { tonic::include_proto!("package"); } @@ -17,17 +15,20 @@ use { std::sync::{Arc, Mutex, Weak}, tokio::sync::broadcast, tokio_stream::wrappers::BroadcastStream, + tonic::transport::server::TcpIncoming, tonic::{Request, Response, Status, Streaming}, }; +// How many messages to hold in memory before start discarding newcomers: const MESSAGE_QUEUE_SIZE: usize = 20; + +// gRPC server/client connect addr const RENDEZVOUS: &'static str = "[::1]:10000"; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box> { let url = RENDEZVOUS.parse()?; - let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) - .expect("Cannot bind server socket"); + let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket"); run_server(incoming).await?; Ok(()) } @@ -43,12 +44,17 @@ fn run_server( builder.serve_with_incoming(incoming) } +type ProtectedWeak = Mutex>>; + #[derive(Default, Debug)] struct MessageService { - shared_tx: Mutex>>>, + shared_tx: ProtectedWeak>, } impl MessageService { + // Lazily create a channel and keep a cached reference in self.shared_tx + // for other clients to reuse. When the last client goes away, we recycle the + // channel. This allows cleanly shutting down the server in tests too. fn shared_channel( &self, ) -> ( @@ -76,16 +82,23 @@ impl MessageService { impl Service for MessageService { type BroadcasterStream = Pin> + Send + 'static>>; + // This is the implementation of the gRPC method from the .proto file. async fn broadcaster( &self, requests: Request>, ) -> Result, Status> { let mut incoming = requests.into_inner(); + // When the connection is first established, we create a lazy stream for output messages. + // This will be returned from the function. let (tx, rx) = self.shared_channel(); let output = BroadcastStream::new(rx) .map(|result| result.map_err(|err| Status::data_loss(err.to_string()))); + // We also spawn a separate task to process incoming messages and + // forward them to all clients. This will run on the executor + // until an error is encountered; typically, until the incoming + // stream is closed. tokio::spawn(async move { while let Some(Ok(message)) = incoming.next().await { let _ = tx.send(message); // Ignore err if no receivers @@ -103,6 +116,7 @@ mod test { super::run_server, super::{Message, RENDEZVOUS}, mockall::automock, + tonic::transport::server::TcpIncoming, tonic::Request, }; @@ -114,11 +128,12 @@ mod test { #[tokio::test] async fn bidi_streaming() -> Result<(), Box> { let url = RENDEZVOUS.parse()?; - let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) - .expect("Cannot bind server socket"); + let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket"); tokio::spawn(run_server(incoming)); - const N_MESSAGES: usize = 20; + // This is the number of messages to broadcast + const N_MESSAGES: usize = 100; + // We use 4 test clients connecting to the server const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"]; async fn client(id: &str, mock: &MockMessageChecker) { @@ -145,9 +160,9 @@ mod test { } } + // We create a mock to check that we receive exactly the + // messages we want. We also check for saturation. let mut mock = MockMessageChecker::new(); - - let mut clients = vec![]; for client_id in CLIENT_IDS.iter() { for i in 1..(N_MESSAGES + 1) { let expected = format!("{}: {}", client_id, i); @@ -158,12 +173,13 @@ mod test { } } - for client_id in CLIENT_IDS { - let c = client(client_id, &mock); - clients.push(c); - } - + // Time to create the clients, and await for their completion + let clients = CLIENT_IDS + .into_iter() + .map(|&id| client(id, &mock)) + .collect::>(); let _ = futures::future::join_all(clients).await; + println!("Exiting."); Ok(()) }