// SPDX-FileCopyrightText: Matteo Settenvini // SPDX-License-Identifier: GPL-3.0-or-later include!(concat!( env!("CARGO_MANIFEST_DIR"), "/src/policy_checker/dbus.rs" )); use { std::sync::atomic::{AtomicUsize, Ordering}, tokio_util::sync::CancellationToken, zbus::dbus_interface, }; #[derive(Debug)] pub struct MalcontentDBusMock { responses: Vec, invocations_left: AtomicUsize, } #[dbus_interface(name = "com.endlessm.ParentalControls.Dns")] impl MalcontentDBusMock { fn get_dns(&mut self) -> Restrictions { let restrictions = self .responses .pop() .expect("MockError: DBus mock is saturated"); self.invocations_left.fetch_sub(1, Ordering::SeqCst); restrictions } } impl MalcontentDBusMock { pub fn new(mut responses: Vec) -> Self { responses.reverse(); // We pop responses from the back, so... Self { invocations_left: AtomicUsize::new(responses.len()), responses, } } } impl Drop for MalcontentDBusMock { fn drop(&mut self) { let invocations_left = self.invocations_left.load(Ordering::Acquire); assert_eq!( invocations_left, 0, "MockError: During teardown, {} invocations are still left on the mock object", invocations_left ); } } pub struct DBusMockServer { handle: std::thread::JoinHandle>, cancellation: CancellationToken, } impl DBusMockServer { pub fn new(responses: Vec) -> Result { let token = CancellationToken::new(); let cloned_token = token.clone(); let listener = std::net::TcpListener::bind("127.0.0.1:0")?; std::env::set_var("TEST_DBUS_SOCKET", format!("{}", listener.local_addr()?)); let handle = std::thread::spawn(move || { tokio::runtime::Builder::new_current_thread() .enable_all() .build() .unwrap() .block_on(Self::spawn_async(responses, listener, cloned_token)) }); Ok(Self { handle: handle, cancellation: token, }) } pub fn stop(self) -> Result<()> { self.cancellation.cancel(); self.handle.join().unwrap() } async fn spawn_async( responses: Vec, listener: std::net::TcpListener, cancellation_token: CancellationToken, ) -> Result<()> { listener.set_nonblocking(true)?; let listener = tokio::net::TcpListener::from_std(listener)?; let guid = zbus::Guid::generate(); let mock = MalcontentDBusMock::new(responses); let (stream, _) = listener .accept() .await .expect("Server socket closed unexpectedly"); log::trace!("dbus mock server accepted client connection"); let _connection = zbus::ConnectionBuilder::tcp_stream(stream) .server(&guid) .p2p() .auth_mechanisms(&[zbus::AuthMechanism::Anonymous]) .name("com.endlessm.ParentalControls") .expect("Unable to serve given dbus name") .serve_at("/com/endlessm/ParentalControls/Dns", mock) .expect("Unable to server malcontent dbus mock object") .build() .await?; tokio::select! { _ = cancellation_token.cancelled() => { Ok(()) } _ = std::future::pending::<()>() => { unreachable!() } } } } pub struct DBusMockServerGuard { mock: Option, } impl DBusMockServerGuard { pub fn new(responses: Vec) -> Result { Ok(Self { mock: Some(DBusMockServer::new(responses)?), }) } } impl Drop for DBusMockServerGuard { fn drop(&mut self) { self.mock .take() .unwrap() .stop() .expect("cannot stop dbus server mock"); } }