140 lines
4.1 KiB
Rust
140 lines
4.1 KiB
Rust
// SPDX-FileCopyrightText: 2023 Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
|
|
// SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
// Protobuf/gRPC bindings for the `package` proto package, pulled in via
// `tonic::include_proto!`. They provide the `Message` type plus the
// `service_server` / `service_client` stubs used by the rest of this file.
mod pb {
    tonic::include_proto!("package");
}
|
|
|
|
use {
|
|
futures::{Stream, StreamExt as _},
|
|
pb::service_server::{Service, ServiceServer},
|
|
pb::Message,
|
|
std::cell::RefCell,
|
|
std::error::Error,
|
|
std::pin::Pin,
|
|
std::sync::{Arc, Mutex, Weak},
|
|
tokio::sync::broadcast,
|
|
tokio::task::JoinHandle,
|
|
tokio_stream::wrappers::BroadcastStream,
|
|
tonic::{Request, Response, Status, Streaming},
|
|
};
|
|
|
|
/// Capacity of the shared broadcast channel. A subscriber that falls more than
/// this many messages behind is reported a `DATA_LOSS` status by the server.
const MESSAGE_QUEUE_SIZE: usize = 20;

/// Address (IPv6 loopback, port 10000) the server listens on and the test
/// clients connect to. `&'static` is implied for `const` string slices.
const RENDEZVOUS: &str = "[::1]:10000";
|
|
|
|
#[tokio::main(flavor = "current_thread")]
|
|
async fn main() -> Result<(), Box<dyn Error>> {
|
|
let server_handle = run_server().await;
|
|
Ok(server_handle.await.unwrap()?)
|
|
}
|
|
|
|
async fn run_server() -> JoinHandle<Result<(), tonic::transport::Error>> {
|
|
use tonic::transport::Server;
|
|
|
|
let addr = RENDEZVOUS.parse().unwrap();
|
|
let service = MessageService::default();
|
|
let svc_adapter = ServiceServer::new(service);
|
|
let builder = Server::builder().add_service(svc_adapter);
|
|
tokio::spawn(builder.serve(addr))
|
|
}
|
|
|
|
#[derive(Default, Debug)]
|
|
struct MessageService {
|
|
shared_tx: Mutex<RefCell<Weak<broadcast::Sender<Message>>>>,
|
|
}
|
|
|
|
impl MessageService {
|
|
fn shared_channel(
|
|
&self,
|
|
) -> (
|
|
Arc<broadcast::Sender<Message>>,
|
|
broadcast::Receiver<Message>,
|
|
) {
|
|
let tx_guard = self.shared_tx.lock().unwrap();
|
|
let maybe_tx = tx_guard.borrow().upgrade();
|
|
match maybe_tx {
|
|
Some(tx) => (tx.clone(), tx.subscribe()),
|
|
None => {
|
|
let (tx, rx) = broadcast::channel(MESSAGE_QUEUE_SIZE);
|
|
let tx = Arc::new(tx);
|
|
*tx_guard.borrow_mut() = Arc::downgrade(&tx);
|
|
(tx, rx)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tonic::async_trait]
|
|
impl Service for MessageService {
|
|
type BroadcasterStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
|
|
|
|
async fn broadcaster(
|
|
&self,
|
|
requests: Request<Streaming<Message>>,
|
|
) -> Result<Response<Self::BroadcasterStream>, Status> {
|
|
let mut incoming = requests.into_inner();
|
|
|
|
let (tx, rx) = self.shared_channel();
|
|
let output = BroadcastStream::new(rx)
|
|
.map(|result| result.map_err(|err| Status::data_loss(err.to_string())));
|
|
|
|
tokio::spawn(async move {
|
|
while let Some(Ok(message)) = incoming.next().await {
|
|
let _ = tx.send(message); // Ignore err if no receivers
|
|
}
|
|
});
|
|
|
|
Ok(Response::new(Box::pin(output)))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use {
        super::pb::service_client::ServiceClient,
        super::run_server,
        super::{Message, RENDEZVOUS},
        std::time::Duration,
        tonic::{transport::Channel, Request},
    };

    /// Number of messages each test client sends.
    const MESSAGES_PER_CLIENT: usize = 20;

    /// Maximum number of connection attempts while the server starts up.
    const CONNECT_ATTEMPTS: u32 = 100;

    /// Delay between two connection attempts.
    const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(20);

    /// Connects to the server at `RENDEZVOUS`. It retries while the freshly
    /// spawned server is not yet listening, rather than relying on a fixed
    /// sleep that may be too short (flaky) or needlessly long.
    ///
    /// # Panics
    ///
    /// Panics if no attempt succeeds within `CONNECT_ATTEMPTS` tries.
    async fn connect() -> ServiceClient<Channel> {
        let endpoint = format!("http://{}", RENDEZVOUS);
        let mut attempt = 1;
        loop {
            match ServiceClient::connect(endpoint.clone()).await {
                Ok(client) => return client,
                Err(err) if attempt >= CONNECT_ATTEMPTS => {
                    panic!(
                        "cannot connect to {} after {} attempts: {}",
                        endpoint, attempt, err
                    )
                }
                Err(_) => {
                    attempt += 1;
                    tokio::time::sleep(CONNECT_RETRY_DELAY).await;
                }
            }
        }
    }

    /// Sends `MESSAGES_PER_CLIENT` messages tagged with `id`, prints every
    /// broadcast message received back, and returns how many were received.
    async fn client(id: &str) -> usize {
        let mut client = connect().await;

        let client_id = String::from(id);
        let outbound = async_stream::stream! {
            for i in 1..=MESSAGES_PER_CLIENT {
                // Yield so the concurrent clients' sends interleave.
                tokio::task::yield_now().await;
                let message = Message {
                    contents: format!("{}: {}", client_id, i),
                };

                yield message;
            }
        };

        let response = client.broadcaster(Request::new(outbound)).await.unwrap();
        let mut inbound = response.into_inner();

        let mut received = 0;
        while let Some(msg) = inbound.message().await.unwrap() {
            println!("{} received {:?}", id, msg);
            received += 1;
        }
        received
    }

    /// Three clients broadcast concurrently. Each is subscribed before its own
    /// messages are published, so it must see at least those echoed back
    /// before its response stream ends.
    #[tokio::test]
    async fn bidi_streaming() -> Result<(), Box<dyn std::error::Error>> {
        let _server_handle = run_server().await;

        let (a, b, c) = tokio::join!(
            client("client AAA"),
            client("client BBB"),
            client("client CCC"),
        );

        for (id, received) in [("client AAA", a), ("client BBB", b), ("client CCC", c)] {
            assert!(
                received >= MESSAGES_PER_CLIENT,
                "{} received only {} messages",
                id,
                received
            );
        }

        println!("Exiting.");
        Ok(())
    }
}
|