rust-playground/grpc/broadcaster/src/main.rs

140 lines
4.1 KiB
Rust

// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: AGPL-3.0-or-later
/// Protobuf messages and gRPC service/client stubs for the proto package
/// `package`, generated at build time (presumably by `tonic-build` in
/// `build.rs` — confirm) and included from `OUT_DIR`.
mod pb {
tonic::include_proto!("package");
}
use {
futures::{Stream, StreamExt as _},
pb::service_server::{Service, ServiceServer},
pb::Message,
std::cell::RefCell,
std::error::Error,
std::pin::Pin,
std::sync::{Arc, Mutex, Weak},
tokio::sync::broadcast,
tokio::task::JoinHandle,
tokio_stream::wrappers::BroadcastStream,
tonic::{Request, Response, Status, Streaming},
};
/// Capacity of the shared broadcast channel. A subscriber that falls more than
/// this many messages behind loses the oldest ones; the service then reports
/// that to its client as a `DATA_LOSS` status.
const MESSAGE_QUEUE_SIZE: usize = 20;
/// Address the server listens on and the tests connect to (IPv6 loopback).
const RENDEZVOUS: &str = "[::1]:10000";
/// Starts the broadcaster gRPC server and waits until it stops.
///
/// # Errors
///
/// Returns an error if the server task panics or is cancelled, or if the
/// server's transport fails.
#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Box<dyn Error>> {
    let server_handle = run_server().await;
    // First `?`: a panicked/cancelled server task (`JoinError`) becomes an error
    // instead of a panic here. Second `?`: propagate transport failures.
    server_handle.await??;
    Ok(())
}
/// Spawns the gRPC server for [`MessageService`] on [`RENDEZVOUS`] and returns
/// the handle of the task running it.
///
/// The server runs in the background. When this returns, the listening socket
/// may not be bound yet, so a client connecting immediately can be refused.
///
/// # Panics
///
/// Panics if [`RENDEZVOUS`] is not a valid socket address (a programming error).
async fn run_server() -> JoinHandle<Result<(), tonic::transport::Error>> {
    use tonic::transport::Server;

    let addr: std::net::SocketAddr = RENDEZVOUS
        .parse()
        .expect("RENDEZVOUS must be a valid socket address");
    let svc_adapter = ServiceServer::new(MessageService::default());
    let router = Server::builder().add_service(svc_adapter);
    tokio::spawn(router.serve(addr))
}
/// gRPC service that relays every message received from a connected client to
/// all clients sharing the current broadcast channel.
///
/// The service itself only keeps a weak reference to the channel's sender. The
/// strong references are owned by the per-client forwarding tasks. Once all of
/// them have finished, the channel is dropped and the next client starts a
/// fresh one.
#[derive(Default, Debug)]
struct MessageService {
    /// Weak handle to the sender of the currently active channel, if any.
    /// `Mutex` alone already grants exclusive mutable access, so no `RefCell`
    /// is needed.
    shared_tx: Mutex<Weak<broadcast::Sender<Message>>>,
}

impl MessageService {
    /// Returns the sender of the active broadcast channel together with a newly
    /// subscribed receiver. If no channel is alive, a new one is created with
    /// capacity [`MESSAGE_QUEUE_SIZE`].
    ///
    /// The receiver exists before this returns, so it sees every message sent on
    /// the channel afterwards, subject to the capacity/lag limit.
    fn shared_channel(
        &self,
    ) -> (
        Arc<broadcast::Sender<Message>>,
        broadcast::Receiver<Message>,
    ) {
        // Holding the lock across check-and-create prevents two concurrent callers
        // from both seeing a dead `Weak` and creating two separate channels.
        // The guarded value is a plain `Weak` that a panic cannot leave half-updated,
        // so recovering from poisoning is safe.
        let mut current = self
            .shared_tx
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        if let Some(tx) = current.upgrade() {
            let rx = tx.subscribe();
            return (tx, rx);
        }

        let (tx, rx) = broadcast::channel(MESSAGE_QUEUE_SIZE);
        let tx = Arc::new(tx);
        *current = Arc::downgrade(&tx);
        (tx, rx)
    }
}
#[tonic::async_trait]
impl Service for MessageService {
    /// Server-to-client half of the `broadcaster` RPC. It yields every message
    /// published on the shared channel, or a `DATA_LOSS` status when this
    /// subscriber has fallen behind.
    type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;

    /// Bidirectional streaming RPC. Messages sent by the caller are published on
    /// the shared broadcast channel. The caller receives everything published on
    /// that channel after it subscribed, including its own messages.
    async fn broadcaster(
        &self,
        requests: Request<Streaming<Message>>,
    ) -> Result<Response<Self::BroadcasterStream>, Status> {
        let mut client_messages = requests.into_inner();
        let (publisher, subscription) = self.shared_channel();

        // Receive errors only signal lag; surface them to the client as a status.
        let relayed = BroadcastStream::new(subscription).map(|received| {
            received.map_err(|lagged| Status::data_loss(lagged.to_string()))
        });

        // Forward the caller's messages until its stream ends or fails. The task
        // owns `publisher`, which keeps the channel open while it runs.
        tokio::spawn(async move {
            while let Some(Ok(outgoing)) = client_messages.next().await {
                // A send error just means nobody is subscribed right now.
                let _ = publisher.send(outgoing);
            }
        });

        let reply: Self::BroadcasterStream = Box::pin(relayed);
        Ok(Response::new(reply))
    }
}
#[cfg(test)]
mod test {
    use {
        super::pb::service_client::ServiceClient,
        super::run_server,
        super::{Message, RENDEZVOUS},
        std::time::Duration,
        tonic::transport::Channel,
        tonic::Request,
    };

    /// Number of messages each test client sends.
    const MESSAGES_PER_CLIENT: usize = 20;
    /// Maximum number of connection attempts before giving up on the server.
    const CONNECT_ATTEMPTS: usize = 50;
    /// Pause between connection attempts while the server is starting up.
    const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(20);

    /// Connects to the test server, retrying while it is still starting up.
    /// This replaces sleeping for a fixed time and hoping it is enough.
    async fn connect() -> ServiceClient<Channel> {
        let endpoint = format!("http://{}", RENDEZVOUS);
        for _ in 1..CONNECT_ATTEMPTS {
            if let Ok(client) = ServiceClient::connect(endpoint.clone()).await {
                return client;
            }
            tokio::time::sleep(CONNECT_RETRY_DELAY).await;
        }
        // Last attempt: if it still fails, report the actual connection error.
        ServiceClient::connect(endpoint)
            .await
            .expect("test server did not start accepting connections")
    }

    /// Runs one client. It streams `MESSAGES_PER_CLIENT` messages tagged with
    /// `id`, prints everything broadcast back, and returns the number of
    /// messages received.
    async fn client(id: &str) -> usize {
        let mut client = connect().await;
        let client_id = String::from(id);
        let outbound = async_stream::stream! {
            for i in 1..=MESSAGES_PER_CLIENT {
                tokio::task::yield_now().await;
                let message = Message {
                    contents: format!("{}: {}", client_id, i),
                };
                yield message;
            }
        };
        let response = client.broadcaster(Request::new(outbound)).await.unwrap();
        let mut inbound = response.into_inner();
        let mut received = 0;
        while let Some(msg) = inbound.message().await.unwrap() {
            println!("{} received {:?}", id, msg);
            received += 1;
        }
        received
    }

    #[tokio::test]
    async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
        let _server_handle = run_server().await;
        let (a, b, c) = tokio::join!(
            client("client AAA"),
            client("client BBB"),
            client("client CCC")
        );
        // Each client is subscribed before its own messages are forwarded, so it
        // must at least get those back, whatever the interleaving with others.
        for &received in [a, b, c].iter() {
            assert!(
                received >= MESSAGES_PER_CLIENT,
                "client received only {} messages, expected at least {}",
                received,
                MESSAGES_PER_CLIENT
            );
        }
        println!("Exiting.");
        Ok(())
    }
}