Add some comments to the code

This commit is contained in:
Matteo Settenvini 2023-05-03 07:29:46 +02:00
parent 8489d8bb20
commit d3f23b7956
Signed by: matteo
GPG Key ID: CCF27A3AD054D593
1 changed files with 31 additions and 15 deletions

View File

@ -1,8 +1,6 @@
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu> // SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: AGPL-3.0-or-later // SPDX-License-Identifier: AGPL-3.0-or-later
use tonic::transport::server::TcpIncoming;
mod pb { mod pb {
tonic::include_proto!("package"); tonic::include_proto!("package");
} }
@ -17,17 +15,20 @@ use {
std::sync::{Arc, Mutex, Weak}, std::sync::{Arc, Mutex, Weak},
tokio::sync::broadcast, tokio::sync::broadcast,
tokio_stream::wrappers::BroadcastStream, tokio_stream::wrappers::BroadcastStream,
tonic::transport::server::TcpIncoming,
tonic::{Request, Response, Status, Streaming}, tonic::{Request, Response, Status, Streaming},
}; };
// How many messages to hold in memory before start discarding newcomers:
const MESSAGE_QUEUE_SIZE: usize = 20; const MESSAGE_QUEUE_SIZE: usize = 20;
// gRPC server/client connect addr
const RENDEZVOUS: &'static str = "[::1]:10000"; const RENDEZVOUS: &'static str = "[::1]:10000";
#[tokio::main(flavor = "current_thread")] #[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> { async fn main() -> Result<(), Box<dyn Error>> {
let url = RENDEZVOUS.parse()?; let url = RENDEZVOUS.parse()?;
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
.expect("Cannot bind server socket");
run_server(incoming).await?; run_server(incoming).await?;
Ok(()) Ok(())
} }
@ -43,12 +44,17 @@ fn run_server(
builder.serve_with_incoming(incoming) builder.serve_with_incoming(incoming)
} }
type ProtectedWeak<T> = Mutex<RefCell<Weak<T>>>;
#[derive(Default, Debug)] #[derive(Default, Debug)]
struct MessageService { struct MessageService {
shared_tx: Mutex<RefCell<Weak<broadcast::Sender<Message>>>>, shared_tx: ProtectedWeak<broadcast::Sender<Message>>,
} }
impl MessageService { impl MessageService {
// Lazily create a channel and keep a cached reference in self.shared_tx
// for other clients to reuse. When the last client goes away, we recycle the
// channel. This allows cleanly shutting down the server in tests too.
fn shared_channel( fn shared_channel(
&self, &self,
) -> ( ) -> (
@ -76,16 +82,23 @@ impl MessageService {
impl Service for MessageService { impl Service for MessageService {
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>; type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
// This is the implementation of the gRPC method from the .proto file.
async fn broadcaster( async fn broadcaster(
&self, &self,
requests: Request<Streaming<Message>>, requests: Request<Streaming<Message>>,
) -> Result<Response<Self::BroadcasterStream>, Status> { ) -> Result<Response<Self::BroadcasterStream>, Status> {
let mut incoming = requests.into_inner(); let mut incoming = requests.into_inner();
// When the connection is first established, we create a lazy stream for output messages.
// This will be returned from the function.
let (tx, rx) = self.shared_channel(); let (tx, rx) = self.shared_channel();
let output = BroadcastStream::new(rx) let output = BroadcastStream::new(rx)
.map(|result| result.map_err(|err| Status::data_loss(err.to_string()))); .map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
// We also spawn a separate task to process incoming messages and
// forward them to all clients. This will run on the executor
// until an error is encountered; typically, until the incoming
// stream is closed.
tokio::spawn(async move { tokio::spawn(async move {
while let Some(Ok(message)) = incoming.next().await { while let Some(Ok(message)) = incoming.next().await {
let _ = tx.send(message); // Ignore err if no receivers let _ = tx.send(message); // Ignore err if no receivers
@ -103,6 +116,7 @@ mod test {
super::run_server, super::run_server,
super::{Message, RENDEZVOUS}, super::{Message, RENDEZVOUS},
mockall::automock, mockall::automock,
tonic::transport::server::TcpIncoming,
tonic::Request, tonic::Request,
}; };
@ -114,11 +128,12 @@ mod test {
#[tokio::test] #[tokio::test]
async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> { async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
let url = RENDEZVOUS.parse()?; let url = RENDEZVOUS.parse()?;
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None) let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
.expect("Cannot bind server socket");
tokio::spawn(run_server(incoming)); tokio::spawn(run_server(incoming));
const N_MESSAGES: usize = 20; // This is the number of messages to broadcast
const N_MESSAGES: usize = 100;
// We use 4 test clients connecting to the server
const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"]; const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];
async fn client(id: &str, mock: &MockMessageChecker) { async fn client(id: &str, mock: &MockMessageChecker) {
@ -145,9 +160,9 @@ mod test {
} }
} }
// We create a mock to check that we receive exactly the
// messages we want. We also check for saturation.
let mut mock = MockMessageChecker::new(); let mut mock = MockMessageChecker::new();
let mut clients = vec![];
for client_id in CLIENT_IDS.iter() { for client_id in CLIENT_IDS.iter() {
for i in 1..(N_MESSAGES + 1) { for i in 1..(N_MESSAGES + 1) {
let expected = format!("{}: {}", client_id, i); let expected = format!("{}: {}", client_id, i);
@ -158,12 +173,13 @@ mod test {
} }
} }
for client_id in CLIENT_IDS { // Time to create the clients, and await for their completion
let c = client(client_id, &mock); let clients = CLIENT_IDS
clients.push(c); .into_iter()
} .map(|&id| client(id, &mock))
.collect::<Vec<_>>();
let _ = futures::future::join_all(clients).await; let _ = futures::future::join_all(clients).await;
println!("Exiting."); println!("Exiting.");
Ok(()) Ok(())
} }