Add some comments to the code
This commit is contained in:
parent
8489d8bb20
commit
d3f23b7956
|
@ -1,8 +1,6 @@
|
|||
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
use tonic::transport::server::TcpIncoming;
|
||||
|
||||
mod pb {
|
||||
tonic::include_proto!("package");
|
||||
}
|
||||
|
@ -17,17 +15,20 @@ use {
|
|||
std::sync::{Arc, Mutex, Weak},
|
||||
tokio::sync::broadcast,
|
||||
tokio_stream::wrappers::BroadcastStream,
|
||||
tonic::transport::server::TcpIncoming,
|
||||
tonic::{Request, Response, Status, Streaming},
|
||||
};
|
||||
|
||||
// How many messages to hold in memory before start discarding newcomers:
|
||||
const MESSAGE_QUEUE_SIZE: usize = 20;
|
||||
|
||||
// gRPC server/client connect addr
|
||||
const RENDEZVOUS: &'static str = "[::1]:10000";
|
||||
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
let url = RENDEZVOUS.parse()?;
|
||||
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
|
||||
.expect("Cannot bind server socket");
|
||||
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
|
||||
run_server(incoming).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -43,12 +44,17 @@ fn run_server(
|
|||
builder.serve_with_incoming(incoming)
|
||||
}
|
||||
|
||||
type ProtectedWeak<T> = Mutex<RefCell<Weak<T>>>;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct MessageService {
|
||||
shared_tx: Mutex<RefCell<Weak<broadcast::Sender<Message>>>>,
|
||||
shared_tx: ProtectedWeak<broadcast::Sender<Message>>,
|
||||
}
|
||||
|
||||
impl MessageService {
|
||||
// Lazily create a channel and keep a cached reference in self.shared_tx
|
||||
// for other clients to reuse. When the last client goes away, we recycle the
|
||||
// channel. This allows cleanly shutting down the server in tests too.
|
||||
fn shared_channel(
|
||||
&self,
|
||||
) -> (
|
||||
|
@ -76,16 +82,23 @@ impl MessageService {
|
|||
impl Service for MessageService {
|
||||
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
|
||||
|
||||
// This is the implementation of the gRPC method from the .proto file.
|
||||
async fn broadcaster(
|
||||
&self,
|
||||
requests: Request<Streaming<Message>>,
|
||||
) -> Result<Response<Self::BroadcasterStream>, Status> {
|
||||
let mut incoming = requests.into_inner();
|
||||
|
||||
// When the connection is first established, we create a lazy stream for output messages.
|
||||
// This will be returned from the function.
|
||||
let (tx, rx) = self.shared_channel();
|
||||
let output = BroadcastStream::new(rx)
|
||||
.map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
|
||||
|
||||
// We also spawn a separate task to process incoming messages and
|
||||
// forward them to all clients. This will run on the executor
|
||||
// until an error is encountered; typically, until the incoming
|
||||
// stream is closed.
|
||||
tokio::spawn(async move {
|
||||
while let Some(Ok(message)) = incoming.next().await {
|
||||
let _ = tx.send(message); // Ignore err if no receivers
|
||||
|
@ -103,6 +116,7 @@ mod test {
|
|||
super::run_server,
|
||||
super::{Message, RENDEZVOUS},
|
||||
mockall::automock,
|
||||
tonic::transport::server::TcpIncoming,
|
||||
tonic::Request,
|
||||
};
|
||||
|
||||
|
@ -114,11 +128,12 @@ mod test {
|
|||
#[tokio::test]
|
||||
async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = RENDEZVOUS.parse()?;
|
||||
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
|
||||
.expect("Cannot bind server socket");
|
||||
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
|
||||
tokio::spawn(run_server(incoming));
|
||||
|
||||
const N_MESSAGES: usize = 20;
|
||||
// This is the number of messages to broadcast
|
||||
const N_MESSAGES: usize = 100;
|
||||
// We use 4 test clients connecting to the server
|
||||
const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];
|
||||
|
||||
async fn client(id: &str, mock: &MockMessageChecker) {
|
||||
|
@ -145,9 +160,9 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
// We create a mock to check that we receive exactly the
|
||||
// messages we want. We also check for saturation.
|
||||
let mut mock = MockMessageChecker::new();
|
||||
|
||||
let mut clients = vec![];
|
||||
for client_id in CLIENT_IDS.iter() {
|
||||
for i in 1..(N_MESSAGES + 1) {
|
||||
let expected = format!("{}: {}", client_id, i);
|
||||
|
@ -158,12 +173,13 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
for client_id in CLIENT_IDS {
|
||||
let c = client(client_id, &mock);
|
||||
clients.push(c);
|
||||
}
|
||||
|
||||
// Time to create the clients, and await for their completion
|
||||
let clients = CLIENT_IDS
|
||||
.into_iter()
|
||||
.map(|&id| client(id, &mock))
|
||||
.collect::<Vec<_>>();
|
||||
let _ = futures::future::join_all(clients).await;
|
||||
|
||||
println!("Exiting.");
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue