rust-playground/grpc/broadcaster/src/main.rs

171 lines
5.1 KiB
Rust

// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: AGPL-3.0-or-later
use tonic::transport::server::TcpIncoming;
/// Protobuf/gRPC code generated for the proto package named `package`.
///
/// `tonic::include_proto!` includes the code generated into `OUT_DIR`;
/// presumably a `build.rs` (not shown here) runs `tonic-build` to produce it.
mod pb {
tonic::include_proto!("package");
}
use {
futures::{Stream, StreamExt as _},
pb::service_server::{Service, ServiceServer},
pb::Message,
std::cell::RefCell,
std::error::Error,
std::pin::Pin,
std::sync::{Arc, Mutex, Weak},
tokio::sync::broadcast,
tokio_stream::wrappers::BroadcastStream,
tonic::{Request, Response, Status, Streaming},
};
/// Capacity of the shared broadcast channel.
///
/// A receiver that falls more than this many messages behind gets a lag error,
/// which the service reports to its client as `DATA_LOSS`.
const MESSAGE_QUEUE_SIZE: usize = 20;

/// Socket address the server listens on and the tests connect to
/// (IPv6 loopback, port 10000).
const RENDEZVOUS: &str = "[::1]:10000";
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
let url = RENDEZVOUS.parse()?;
let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
.expect("Cannot bind server socket");
run_server(incoming).await?;
Ok(())
}
/// Builds a tonic server exposing a fresh [`MessageService`] and returns the
/// future that serves it on the already-bound `incoming` listener.
///
/// The caller decides how to drive it: `main` awaits it, the tests spawn it.
fn run_server(
    incoming: TcpIncoming,
) -> impl std::future::Future<Output = Result<(), tonic::transport::Error>> {
    tonic::transport::Server::builder()
        .add_service(ServiceServer::new(MessageService::default()))
        .serve_with_incoming(incoming)
}
/// gRPC service that relays every message received from any connected client
/// to all clients currently subscribed to the shared channel.
///
/// All concurrent calls share one broadcast channel. The service itself keeps
/// only a [`Weak`] handle to the channel's sender, so the channel goes away once
/// every strong `Arc` handed out by [`MessageService::shared_channel`] has been
/// dropped; the next caller then starts a fresh channel.
#[derive(Default, Debug)]
struct MessageService {
    /// Weak handle to the sender of the currently shared channel (dangling when
    /// no channel is alive). The `Mutex` alone provides the interior mutability
    /// needed to replace it through `&self`.
    shared_tx: Mutex<Weak<broadcast::Sender<Message>>>,
}

impl MessageService {
    /// Returns the shared sender together with a new receiver subscribed to it.
    ///
    /// If no channel is alive, a new one with capacity [`MESSAGE_QUEUE_SIZE`] is
    /// created and remembered (weakly) for later callers. The returned receiver
    /// only observes messages sent after this call.
    fn shared_channel(
        &self,
    ) -> (
        Arc<broadcast::Sender<Message>>,
        broadcast::Receiver<Message>,
    ) {
        // The guarded `Weak` can never be left half-updated, so a poisoned lock
        // (a panic in some other request) is safe to recover from instead of
        // turning every later request into a panic as well.
        let mut shared_tx = self
            .shared_tx
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        if let Some(tx) = shared_tx.upgrade() {
            let rx = tx.subscribe();
            return (tx, rx);
        }

        let (tx, rx) = broadcast::channel(MESSAGE_QUEUE_SIZE);
        let tx = Arc::new(tx);
        *shared_tx = Arc::downgrade(&tx);
        (tx, rx)
    }
}
#[tonic::async_trait]
impl Service for MessageService {
    type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;

    /// Bidirectional streaming call.
    ///
    /// Every message the client sends is published on the shared channel. The
    /// response stream yields every message published on that channel from the
    /// moment of this call onwards, including the client's own.
    ///
    /// A receiver that lags behind by more than the channel capacity gets a
    /// `DATA_LOSS` status in its response stream.
    async fn broadcaster(
        &self,
        requests: Request<Streaming<Message>>,
    ) -> Result<Response<Self::BroadcasterStream>, Status> {
        let (sender, receiver) = self.shared_channel();
        let mut inbound = requests.into_inner();

        let outbound = BroadcastStream::new(receiver)
            .map(|item| item.map_err(|lag| Status::data_loss(lag.to_string())));

        // Forward this client's messages onto the shared channel. The task owns a
        // strong sender, keeping the channel alive until the client's request
        // stream ends or yields an error.
        tokio::spawn(async move {
            loop {
                match inbound.next().await {
                    Some(Ok(message)) => {
                        // A send error only means nobody is subscribed right now.
                        let _ = sender.send(message);
                    }
                    _ => break,
                }
            }
        });

        Ok(Response::new(Box::pin(outbound)))
    }
}
#[cfg(test)]
mod test {
    use {
        super::pb::service_client::ServiceClient,
        super::run_server,
        super::{Message, RENDEZVOUS},
        mockall::automock,
        tokio::sync::{oneshot, Barrier},
        tonic::Request,
    };

    /// Hook used, through its mockall mock, to assert which messages each
    /// client receives and how often.
    #[automock]
    trait MessageChecker {
        fn check_contents(&self, msg: &Message);
    }

    /// End-to-end test: several clients stream messages concurrently, and every
    /// client must receive every message sent by every client (its own included).
    ///
    /// NOTE: binds the fixed address `RENDEZVOUS`, so it fails if the real server
    /// (or another test run) already holds that port.
    #[tokio::test]
    async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
        let url = RENDEZVOUS.parse()?;
        let incoming = tonic::transport::server::TcpIncoming::new(url, true, None)
            .expect("Cannot bind server socket");
        tokio::spawn(run_server(incoming));

        const N_MESSAGES: usize = 20;
        const CLIENT_IDS: &[&str] = &["client AAA", "client BBB", "client CCC"];

        /// Connects one client and waits (on `all_subscribed`) until every client
        /// is subscribed. It then streams `N_MESSAGES` messages and checks
        /// everything it receives against `mock`.
        async fn client(id: &str, mock: &MockMessageChecker, all_subscribed: &Barrier) {
            let addr = format!("http://{}", RENDEZVOUS);
            let mut client = ServiceClient::connect(addr).await.unwrap();
            let client_id = String::from(id);

            // The server subscribes a caller to the shared channel before it
            // answers the call. A broadcast receiver only sees messages sent
            // *after* it subscribed, though. Hold the outbound stream back until
            // every client has its response; otherwise late joiners miss early
            // messages (or get a fresh channel) and the expectations fail
            // intermittently.
            let (start_tx, start_rx) = oneshot::channel::<()>();
            let outbound = async_stream::stream! {
                // An error only means the trigger was dropped; start anyway.
                let _ = start_rx.await;
                for i in 1..=N_MESSAGES {
                    tokio::task::yield_now().await;
                    let message = Message {
                        contents: format!("{}: {}", client_id, i),
                    };
                    yield message;
                }
            };

            let response = client.broadcaster(Request::new(outbound)).await.unwrap();
            all_subscribed.wait().await;
            let _ = start_tx.send(());

            let mut inbound = response.into_inner();
            while let Some(msg) = inbound.message().await.unwrap() {
                println!("{} received {:?}", id, msg);
                mock.check_contents(&msg);
            }
        }

        // Every message must reach every client exactly once.
        let mut mock = MockMessageChecker::new();
        for client_id in CLIENT_IDS {
            for i in 1..=N_MESSAGES {
                let expected = format!("{}: {}", client_id, i);
                mock.expect_check_contents()
                    .times(CLIENT_IDS.len())
                    .withf(move |msg| msg.contents == expected)
                    .return_const(());
            }
        }

        let all_subscribed = Barrier::new(CLIENT_IDS.len());
        let clients = CLIENT_IDS
            .iter()
            .map(|client_id| client(client_id, &mock, &all_subscribed));
        futures::future::join_all(clients).await;
        println!("Exiting.");
        Ok(())
    }
}