Add some comments to the code
This commit is contained in:
parent
8489d8bb20
commit
d3f23b7956
|
@ -1,8 +1,6 @@
|
||||||
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
|
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
|
||||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
|
|
||||||
use tonic::transport::server::TcpIncoming;
|
|
||||||
|
|
||||||
mod pb {
|
mod pb {
|
||||||
tonic::include_proto!("package");
|
tonic::include_proto!("package");
|
||||||
}
|
}
|
||||||
|
@ -17,17 +15,20 @@ use {
|
||||||
std::sync::{Arc, Mutex, Weak},
|
std::sync::{Arc, Mutex, Weak},
|
||||||
tokio::sync::broadcast,
|
tokio::sync::broadcast,
|
||||||
tokio_stream::wrappers::BroadcastStream,
|
tokio_stream::wrappers::BroadcastStream,
|
||||||
|
tonic::transport::server::TcpIncoming,
|
||||||
tonic::{Request, Response, Status, Streaming},
|
tonic::{Request, Response, Status, Streaming},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// How many messages to hold in memory before start discarding newcomers:
|
||||||
const MESSAGE_QUEUE_SIZE: usize = 20;
|
const MESSAGE_QUEUE_SIZE: usize = 20;
|
||||||
|
|
||||||
|
// gRPC server/client connect addr
|
||||||
const RENDEZVOUS: &'static str = "[::1]:10000";
|
const RENDEZVOUS: &'static str = "[::1]:10000";
|
||||||
|
|
||||||
#[tokio::main(flavor = "current_thread")]
|
#[tokio::main(flavor = "current_thread")]
|
||||||
async fn main() -> Result<(), Box<dyn Error>> {
|
async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
let url = RENDEZVOUS.parse()?;
|
let url = RENDEZVOUS.parse()?;
|
||||||
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
|
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
|
||||||
.expect("Cannot bind server socket");
|
|
||||||
run_server(incoming).await?;
|
run_server(incoming).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -43,12 +44,17 @@ fn run_server(
|
||||||
builder.serve_with_incoming(incoming)
|
builder.serve_with_incoming(incoming)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ProtectedWeak<T> = Mutex<RefCell<Weak<T>>>;
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
struct MessageService {
|
struct MessageService {
|
||||||
shared_tx: Mutex<RefCell<Weak<broadcast::Sender<Message>>>>,
|
shared_tx: ProtectedWeak<broadcast::Sender<Message>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MessageService {
|
impl MessageService {
|
||||||
|
// Lazily create a channel and keep a cached reference in self.shared_tx
|
||||||
|
// for other clients to reuse. When the last client goes away, we recycle the
|
||||||
|
// channel. This allows cleanly shutting down the server in tests too.
|
||||||
fn shared_channel(
|
fn shared_channel(
|
||||||
&self,
|
&self,
|
||||||
) -> (
|
) -> (
|
||||||
|
@ -76,16 +82,23 @@ impl MessageService {
|
||||||
impl Service for MessageService {
|
impl Service for MessageService {
|
||||||
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
|
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
|
||||||
|
|
||||||
|
// This is the implementation of the gRPC method from the .proto file.
|
||||||
async fn broadcaster(
|
async fn broadcaster(
|
||||||
&self,
|
&self,
|
||||||
requests: Request<Streaming<Message>>,
|
requests: Request<Streaming<Message>>,
|
||||||
) -> Result<Response<Self::BroadcasterStream>, Status> {
|
) -> Result<Response<Self::BroadcasterStream>, Status> {
|
||||||
let mut incoming = requests.into_inner();
|
let mut incoming = requests.into_inner();
|
||||||
|
|
||||||
|
// When the connection is first established, we create a lazy stream for output messages.
|
||||||
|
// This will be returned from the function.
|
||||||
let (tx, rx) = self.shared_channel();
|
let (tx, rx) = self.shared_channel();
|
||||||
let output = BroadcastStream::new(rx)
|
let output = BroadcastStream::new(rx)
|
||||||
.map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
|
.map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
|
||||||
|
|
||||||
|
// We also spawn a separate task to process incoming messages and
|
||||||
|
// forward them to all clients. This will run on the executor
|
||||||
|
// until an error is encountered; typically, until the incoming
|
||||||
|
// stream is closed.
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(Ok(message)) = incoming.next().await {
|
while let Some(Ok(message)) = incoming.next().await {
|
||||||
let _ = tx.send(message); // Ignore err if no receivers
|
let _ = tx.send(message); // Ignore err if no receivers
|
||||||
|
@ -103,6 +116,7 @@ mod test {
|
||||||
super::run_server,
|
super::run_server,
|
||||||
super::{Message, RENDEZVOUS},
|
super::{Message, RENDEZVOUS},
|
||||||
mockall::automock,
|
mockall::automock,
|
||||||
|
tonic::transport::server::TcpIncoming,
|
||||||
tonic::Request,
|
tonic::Request,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -114,11 +128,12 @@ mod test {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
|
async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let url = RENDEZVOUS.parse()?;
|
let url = RENDEZVOUS.parse()?;
|
||||||
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
|
let incoming = TcpIncoming::new(url, true, None).expect("Cannot bind server socket");
|
||||||
.expect("Cannot bind server socket");
|
|
||||||
tokio::spawn(run_server(incoming));
|
tokio::spawn(run_server(incoming));
|
||||||
|
|
||||||
const N_MESSAGES: usize = 20;
|
// This is the number of messages to broadcast
|
||||||
|
const N_MESSAGES: usize = 100;
|
||||||
|
// We use 4 test clients connecting to the server
|
||||||
const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];
|
const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];
|
||||||
|
|
||||||
async fn client(id: &str, mock: &MockMessageChecker) {
|
async fn client(id: &str, mock: &MockMessageChecker) {
|
||||||
|
@ -145,9 +160,9 @@ mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We create a mock to check that we receive exactly the
|
||||||
|
// messages we want. We also check for saturation.
|
||||||
let mut mock = MockMessageChecker::new();
|
let mut mock = MockMessageChecker::new();
|
||||||
|
|
||||||
let mut clients = vec![];
|
|
||||||
for client_id in CLIENT_IDS.iter() {
|
for client_id in CLIENT_IDS.iter() {
|
||||||
for i in 1..(N_MESSAGES + 1) {
|
for i in 1..(N_MESSAGES + 1) {
|
||||||
let expected = format!("{}: {}", client_id, i);
|
let expected = format!("{}: {}", client_id, i);
|
||||||
|
@ -158,12 +173,13 @@ mod test {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for client_id in CLIENT_IDS {
|
// Time to create the clients, and await for their completion
|
||||||
let c = client(client_id, &mock);
|
let clients = CLIENT_IDS
|
||||||
clients.push(c);
|
.into_iter()
|
||||||
}
|
.map(|&id| client(id, &mock))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
let _ = futures::future::join_all(clients).await;
|
let _ = futures::future::join_all(clients).await;
|
||||||
|
|
||||||
println!("Exiting.");
|
println!("Exiting.");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue