Add some comments to the code

This commit is contained in:
Matteo Settenvini 2023-05-03 07:29:46 +02:00
parent 8489d8bb20
commit d3f23b7956
Signed by: matteo
GPG Key ID: CCF27A3AD054D593
1 changed files with 31 additions and 15 deletions

View File

@ -1,8 +1,6 @@
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: AGPL-3.0-or-later
use tonic::transport::server::TcpIncoming;
mod pb {
tonic::include_proto!("package");
}
@ -17,17 +15,20 @@ use {
std::sync::{Arc, Mutex, Weak},
tokio::sync::broadcast,
tokio_stream::wrappers::BroadcastStream,
tonic::transport::server::TcpIncoming,
tonic::{Request, Response, Status, Streaming},
};
// How many messages to hold in memory before start discarding newcomers:
const MESSAGE_QUEUE_SIZE: usize = 20;
// gRPC server/client connect addr
const RENDEZVOUS: &'static str = "[::1]:10000";
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
let url = RENDEZVOUS.parse()?;
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
.expect("Cannot bind server socket");
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
run_server(incoming).await?;
Ok(())
}
@ -43,12 +44,17 @@ fn run_server(
builder.serve_with_incoming(incoming)
}
type ProtectedWeak<T> = Mutex<RefCell<Weak<T>>>;
#[derive(Default, Debug)]
struct MessageService {
shared_tx: Mutex<RefCell<Weak<broadcast::Sender<Message>>>>,
shared_tx: ProtectedWeak<broadcast::Sender<Message>>,
}
impl MessageService {
// Lazily create a channel and keep a cached reference in self.shared_tx
// for other clients to reuse. When the last client goes away, we recycle the
// channel. This allows cleanly shutting down the server in tests too.
fn shared_channel(
&self,
) -> (
@ -76,16 +82,23 @@ impl MessageService {
impl Service for MessageService {
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
// This is the implementation of the gRPC method from the .proto file.
async fn broadcaster(
&self,
requests: Request<Streaming<Message>>,
) -> Result<Response<Self::BroadcasterStream>, Status> {
let mut incoming = requests.into_inner();
// When the connection is first established, we create a lazy stream for output messages.
// This will be returned from the function.
let (tx, rx) = self.shared_channel();
let output = BroadcastStream::new(rx)
.map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
// We also spawn a separate task to process incoming messages and
// forward them to all clients. This will run on the executor
// until an error is encountered; typically, until the incoming
// stream is closed.
tokio::spawn(async move {
while let Some(Ok(message)) = incoming.next().await {
let _ = tx.send(message); // Ignore err if no receivers
@ -103,6 +116,7 @@ mod test {
super::run_server,
super::{Message, RENDEZVOUS},
mockall::automock,
tonic::transport::server::TcpIncoming,
tonic::Request,
};
@ -114,11 +128,12 @@ mod test {
#[tokio::test]
async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
let url = RENDEZVOUS.parse()?;
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
.expect("Cannot bind server socket");
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
tokio::spawn(run_server(incoming));
const N_MESSAGES: usize = 20;
// This is the number of messages to broadcast
const N_MESSAGES: usize = 100;
// We use 4 test clients connecting to the server
const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];
async fn client(id: &str, mock: &MockMessageChecker) {
@ -145,9 +160,9 @@ mod test {
}
}
// We create a mock to check that we receive exactly the
// messages we want. We also check for saturation.
let mut mock = MockMessageChecker::new();
let mut clients = vec![];
for client_id in CLIENT_IDS.iter() {
for i in 1..(N_MESSAGES + 1) {
let expected = format!("{}: {}", client_id, i);
@ -158,12 +173,13 @@ mod test {
}
}
for client_id in CLIENT_IDS {
let c = client(client_id, &mock);
clients.push(c);
}
// Time to create the clients, and await for their completion
let clients = CLIENT_IDS
.into_iter()
.map(|&id| client(id, &mock))
.collect::<Vec<_>>();
let _ = futures::future::join_all(clients).await;
println!("Exiting.");
Ok(())
}