malcontent/nss/tests/common/dbus.rs

143 lines
4.0 KiB
Rust

// SPDX-FileCopyrightText: Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: GPL-3.0-or-later
// Textually include `src/policy_checker/dbus.rs`, resolved relative to this crate's
// manifest directory, so its items become part of this module.
// NOTE(review): `Restrictions` and `Result`, used throughout this file, are presumably
// provided by the included file — confirm against src/policy_checker/dbus.rs.
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/policy_checker/dbus.rs"
));
use {
std::sync::atomic::{AtomicUsize, Ordering},
tokio_util::sync::CancellationToken,
zbus::dbus_interface,
};
/// Test double exposed over D-Bus that hands out a fixed list of canned
/// `Restrictions` replies, one per `get_dns` call.
///
/// Dropping the mock asserts that every canned reply has been handed out.
#[derive(Debug)]
pub struct MalcontentDBusMock {
// Canned replies, stored reversed so that the next reply to serve sits at the back
// of the vector and can be removed with `pop()`.
responses: Vec<Restrictions>,
// Number of canned replies that have not been served yet; checked on drop.
invocations_left: AtomicUsize,
}
#[dbus_interface(name = "com.endlessm.ParentalControls.Dns")]
impl MalcontentDBusMock {
    // Serves the next canned reply and records that one fewer invocation is expected.
    //
    // Panics with a "MockError" message when called more often than replies were
    // configured.
    //
    // NOTE(review): written as a plain comment rather than `///` in case the
    // interface macro consumes doc attributes — confirm before converting.
    fn get_dns(&mut self) -> Restrictions {
        let next_reply = self.responses.pop();
        let reply = next_reply.expect("MockError: DBus mock is saturated");
        // Only count the call once we know a reply was actually available.
        self.invocations_left.fetch_sub(1, Ordering::SeqCst);
        reply
    }
}
impl MalcontentDBusMock {
    /// Builds a mock that will serve `responses` in the order given.
    ///
    /// The number of expected invocations is initialised to `responses.len()`.
    pub fn new(mut responses: Vec<Restrictions>) -> Self {
        let expected_calls = responses.len();
        // Replies are later removed from the back of the vector, so flip the list
        // to make the first configured reply the first one served.
        responses.reverse();
        Self {
            responses,
            invocations_left: AtomicUsize::new(expected_calls),
        }
    }
}
impl Drop for MalcontentDBusMock {
    /// Verifies on teardown that every configured reply was consumed.
    ///
    /// # Panics
    /// Panics with a "MockError" message if invocations are still outstanding,
    /// unless the current thread is already unwinding from another panic.
    fn drop(&mut self) {
        // A panic raised while the thread is already unwinding aborts the whole
        // process and hides the original failure, so skip the check in that case.
        if std::thread::panicking() {
            return;
        }
        let invocations_left = self.invocations_left.load(Ordering::SeqCst);
        assert_eq!(
            invocations_left, 0,
            "MockError: During teardown, {} invocations are still left on the mock object",
            invocations_left
        );
    }
}
/// Handle to a mock D-Bus server running on its own OS thread.
pub struct DBusMockServer {
// Join handle of the server thread; joining yields the server's final result.
handle: std::thread::JoinHandle<Result<()>>,
// Token cancelled by `stop` to ask the server thread to shut down.
cancellation: CancellationToken,
}
impl DBusMockServer {
    /// Starts a mock D-Bus server that answers with `responses`, in the order given.
    ///
    /// Binds an ephemeral TCP port on `127.0.0.1`, publishes the bound address in the
    /// `TEST_DBUS_SOCKET` environment variable, and serves the mock from a dedicated
    /// thread driving a current-thread tokio runtime.
    ///
    /// Note that `TEST_DBUS_SOCKET` is process-wide: a later call overwrites the value
    /// set by an earlier one.
    ///
    /// # Errors
    /// Returns an error if the listener cannot be bound or its local address cannot
    /// be read. Failures inside the server thread are reported by [`Self::stop`].
    pub fn new(responses: Vec<Restrictions>) -> Result<Self> {
        let cancellation = CancellationToken::new();
        let server_token = cancellation.clone();
        let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
        std::env::set_var("TEST_DBUS_SOCKET", listener.local_addr()?.to_string());
        let handle = std::thread::spawn(move || -> Result<()> {
            tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()?
                .block_on(Self::spawn_async(responses, listener, server_token))
        });
        Ok(Self {
            handle,
            cancellation,
        })
    }

    /// Requests shutdown and waits for the server thread to finish.
    ///
    /// # Errors
    /// Returns whatever error the server thread finished with.
    ///
    /// # Panics
    /// If the server thread panicked, that panic is re-raised here with its original
    /// payload so the real failure message is not lost.
    pub fn stop(self) -> Result<()> {
        self.cancellation.cancel();
        match self.handle.join() {
            Ok(outcome) => outcome,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Server body: accepts a single client, serves the mock object on that
    /// connection, and keeps it alive until `cancellation_token` is cancelled.
    ///
    /// If cancellation arrives before any client has connected, the function returns
    /// `Ok(())` without serving; dropping the unused mock then runs its teardown
    /// check for unconsumed replies.
    async fn spawn_async(
        responses: Vec<Restrictions>,
        listener: std::net::TcpListener,
        cancellation_token: CancellationToken,
    ) -> Result<()> {
        // The std listener must be non-blocking before tokio adopts it.
        listener.set_nonblocking(true)?;
        let listener = tokio::net::TcpListener::from_std(listener)?;
        let guid = zbus::Guid::generate();
        let mock = MalcontentDBusMock::new(responses);

        // Race the accept against cancellation so that `stop` cannot block forever
        // when no client ever connects.
        let stream = tokio::select! {
            accepted = listener.accept() => {
                let (stream, _) = accepted.expect("Server socket closed unexpectedly");
                stream
            }
            _ = cancellation_token.cancelled() => {
                return Ok(());
            }
        };
        log::trace!("dbus mock server accepted client connection");

        let _connection = zbus::ConnectionBuilder::tcp_stream(stream)
            .server(&guid)
            .p2p()
            .auth_mechanisms(&[zbus::AuthMechanism::Anonymous])
            .name("com.endlessm.ParentalControls")
            .expect("Unable to serve given dbus name")
            .serve_at("/com/endlessm/ParentalControls/Dns", mock)
            .expect("Unable to server malcontent dbus mock object")
            .build()
            .await?;

        // Hold the connection open until the owner asks us to stop.
        cancellation_token.cancelled().await;
        Ok(())
    }
}
/// Scope guard that owns a `DBusMockServer` and stops it when dropped.
pub struct DBusMockServerGuard {
// Holds the server until `drop` takes it out. Wrapped in `Option` because
// `DBusMockServer::stop` consumes the server by value while `drop` only has `&mut self`.
mock: Option<DBusMockServer>,
}
impl DBusMockServerGuard {
    /// Starts a `DBusMockServer` configured with `responses` and wraps it in a guard.
    ///
    /// # Errors
    /// Propagates any error returned by `DBusMockServer::new`.
    pub fn new(responses: Vec<Restrictions>) -> Result<Self> {
        let server = DBusMockServer::new(responses)?;
        let guard = Self { mock: Some(server) };
        Ok(guard)
    }
}
impl Drop for DBusMockServerGuard {
    /// Stops the guarded server, if it is still present.
    ///
    /// # Panics
    /// Panics with "cannot stop dbus server mock" if stopping fails, unless the
    /// current thread is already unwinding; in that case the error is only printed.
    fn drop(&mut self) {
        // `mock` is only emptied here, but a destructor should tolerate `None`
        // rather than panic on it.
        if let Some(server) = self.mock.take() {
            let outcome = server.stop();
            if std::thread::panicking() {
                // A second panic during unwinding would abort the process and hide
                // the original failure, so just report the shutdown error.
                if let Err(error) = outcome {
                    eprintln!("cannot stop dbus server mock: {:?}", error);
                }
            } else {
                outcome.expect("cannot stop dbus server mock");
            }
        }
    }
}