Use a p2p dbus connection during tests

This removes the requirement of having the DBus daemon running locally,
which is particularly helpful e.g. in a container or for CI.
This commit is contained in:
Matteo Settenvini 2022-09-12 00:38:01 +02:00
parent 256267c213
commit 20072cf1ea
Signed by: matteo
GPG Key ID: 8576CC1AD97D42DF
5 changed files with 101 additions and 64 deletions

View File

@ -27,11 +27,12 @@ bindgen = "0.60"
[dev-dependencies]
malcontent-nss = { path = ".", features = ["integration_test"] }
env_logger = "0.9"
event-listener = "2.5"
futures-util = "0.3"
rusty-hook = "0.11"
rusty-forkfork = "0.4"
test-cdylib = "1.1"
tokio = { version = "1", features = ["rt", "sync", "macros", "time"] }
tokio = { version = "1", features = ["rt", "sync", "macros", "net", "time"] }
[dependencies]
anyhow = "1.0"
@ -42,8 +43,8 @@ log = "0.4"
nix = { version = "0.24", features = ["socket", "user", "sched"] }
serde = "1.0"
tokio = { version = "1", features = ["rt"] }
trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"] }
trust-dns-proto = "0.21"
trust-dns-resolver = { version = "0.22", features = ["dns-over-rustls"] }
trust-dns-proto = "0.22"
zbus = { version = "3.0", default-features = false, features = ["tokio"] }
zvariant = "3.6"
zbus_names = { version = "2.2", optional = true }

View File

@ -13,7 +13,6 @@ use {
std::sync::Arc,
trust_dns_proto::rr::record_type::RecordType,
trust_dns_proto::rr::{RData, Record},
trust_dns_proto::xfer::dns_request::DnsRequestOptions,
trust_dns_resolver::TokioAsyncResolver,
trust_dns_resolver::{lookup::Lookup, lookup_ip::LookupIp},
};
@ -81,16 +80,8 @@ async unsafe fn resolve_hostname_through(
let lookup: std::result::Result<Lookup, _> = match args.family {
libc::AF_UNSPEC => resolver.lookup_ip(name).await.map(LookupIp::into),
libc::AF_INET => {
resolver
.lookup(name, RecordType::A, DnsRequestOptions::default())
.await
}
libc::AF_INET6 => {
resolver
.lookup(name, RecordType::AAAA, DnsRequestOptions::default())
.await
}
libc::AF_INET => resolver.lookup(name, RecordType::A).await,
libc::AF_INET6 => resolver.lookup(name, RecordType::AAAA).await,
_ => return nss_status::NSS_STATUS_NOTFOUND,
};

View File

@ -36,14 +36,30 @@ pub async fn restrictions_for(user: Uid) -> anyhow::Result<Vec<Restriction>, any
#[cfg(feature = "integration_test")]
let proxy = {
use tokio::net::UnixStream;
// During integration testing, we want to connect to a private
// bus name to avoid clashes with existing system services.
let connection = zbus::Connection::session().await?;
let dbus_name = std::env::var("TEST_DBUS_SERVICE_NAME")
.expect("The test has not set the TEST_DBUS_SERVICE_NAME environment variable to the private bus name prior to attempting name resolution");
let socket_path = std::env::var("TEST_DBUS_SOCKET")
.expect("The test has not set the TEST_DBUS_SOCKET environment variable to the unix socket to connect to");
let socket = loop {
match UnixStream::connect(&socket_path).await {
Ok(stream) => break stream,
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
tokio::task::yield_now().await;
continue;
}
Err(e) => anyhow::bail!(e),
}
};
let connection = zbus::ConnectionBuilder::unix_stream(socket)
.p2p()
.build()
.await?;
MalcontentDnsProxy::builder(&connection)
.destination(zbus_names::UniqueName::try_from(dbus_name).unwrap())
.unwrap()
.build()
.await
.expect("Unable to build DBus proxy object")

View File

@ -6,52 +6,48 @@ include!(concat!(
"/src/policy_checker/dbus.rs"
));
use {std::collections::HashMap, zbus::dbus_interface};
use {
std::collections::HashMap,
std::sync::atomic::{AtomicUsize, Ordering},
zbus::dbus_interface,
};
#[derive(Debug)]
pub struct MalcontentDBusMock {
responses: HashMap<Uid, Vec<Restrictions>>,
invocations_left: usize,
invocations_left: AtomicUsize,
}
#[dbus_interface(name = "com.endlessm.ParentalControls.Dns")]
impl MalcontentDBusMock {
fn get_restrictions(&mut self, user_id: u32) -> Restrictions {
let answers = self
.responses
.get_mut(&Uid::from_raw(user_id))
.expect(&format!(
"MockError: No mocked invocations available for user with id {}",
user_id
));
let restrictions = answers.pop().expect(&format!(
"MockError: DBus mock is saturated for user with id {}",
user_id
let uid = Uid::from_raw(user_id);
let answers = self.responses.get_mut(&uid).expect(&format!(
"MockError: No mocked invocations available for user with id {}",
uid
));
self.invocations_left -= 1;
let restrictions = answers.pop().expect(&format!(
"MockError: DBus mock is saturated for user with id {}",
uid
));
self.invocations_left.fetch_sub(1, Ordering::SeqCst);
restrictions
}
}
impl MalcontentDBusMock {
pub fn new(mut responses: HashMap<Uid, Vec<Restrictions>>) -> Self {
let responses_size: usize = responses
.values()
.map(|v| {
std::cmp::max(
v.len(),
1, /* 'No restrictions' still counts as one message */
)
})
.sum();
let responses_size: usize = responses.values().map(Vec::len).sum();
for r in responses.values_mut() {
r.reverse(); // we pop responses from the back, so...
}
let ret = Self {
responses,
invocations_left: responses_size,
invocations_left: AtomicUsize::new(responses_size),
};
ret
@ -60,24 +56,57 @@ impl MalcontentDBusMock {
impl Drop for MalcontentDBusMock {
fn drop(&mut self) {
let invocations_left = self.invocations_left.load(Ordering::Acquire);
assert_eq!(
self.invocations_left, 0,
invocations_left, 0,
"MockError: During teardown, {} invocations are still left on the mock object",
self.invocations_left
invocations_left
);
}
}
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<zbus::Connection> {
let mock = MalcontentDBusMock::new(responses);
let connection = zbus::ConnectionBuilder::session()?
.serve_at("/com/endlessm/ParentalControls/Dns", mock)?
.build()
.await?;
std::env::set_var(
"TEST_DBUS_SERVICE_NAME",
connection.unique_name().unwrap().as_str(),
);
Ok(connection)
pub struct DBusServerGuard {
handle: tokio::task::JoinHandle<()>,
}
impl Drop for DBusServerGuard {
fn drop(&mut self) {
self.handle.abort();
}
}
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<DBusServerGuard> {
let guid = zbus::Guid::generate();
let socket_path =
std::path::PathBuf::from(std::env!("CARGO_TARGET_TMPDIR")).join(guid.as_str());
if socket_path.exists() {
std::fs::remove_file(&socket_path)?;
}
let socket = tokio::net::UnixListener::bind(&socket_path)?;
std::env::set_var("TEST_DBUS_SOCKET", &socket_path);
let mock = MalcontentDBusMock::new(responses);
let handle = tokio::spawn(async move {
let (stream, _) = socket
.accept()
.await
.expect("Server socket closed unexpectedly");
std::fs::remove_file(socket_path).unwrap(); // Once we accepted, we can already remove the socket
let _ = zbus::ConnectionBuilder::unix_stream(stream)
.server(&guid)
.p2p()
.name("com.endlessm.ParentalControls")
.expect("Unable to serve given dbus name")
.serve_at("/com/endlessm/ParentalControls/Dns", mock)
.expect("Unable to server malcontent dbus mock object")
.build()
.await;
std::future::pending::<()>().await;
});
Ok(DBusServerGuard { handle })
}

View File

@ -122,12 +122,12 @@ fork_test! {
const HOSTNAME: &str = "wikipedia.org";
tokio::runtime::Runtime::new().unwrap().block_on(async {
for family in [libc::AF_INET, libc::AF_INET6] {
let _dbus = common::mock_dbus(HashMap::from([(
getuid(),
vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()],
)])).await?;
let _dbus = common::mock_dbus(HashMap::from([(
getuid(),
vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()],
)])).await?;
for family in [libc::AF_INET, libc::AF_INET6] {
let system_addr = common::resolve_with_system(family, HOSTNAME)?;
let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
assert!(our_addrs.contains(&system_addr));
@ -150,7 +150,7 @@ fork_test! {
let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?;
let our_addrs = common::resolve_with_module(libc::AF_INET, HOSTNAME)?;
assert!(!our_addrs.contains(&system_addr));
assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
assert_eq!(our_addrs, [IpAddr::V4(Ipv4Addr::UNSPECIFIED)]);
Ok(())
})
@ -170,7 +170,7 @@ fork_test! {
let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?;
let our_addrs = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?;
assert!(!our_addrs.contains(&system_addr));
assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
assert_eq!(our_addrs, [IpAddr::V6(Ipv6Addr::UNSPECIFIED)]);
Ok(())
})
@ -183,8 +183,8 @@ fork_test! {
const HOSTNAME: &str = "malware.testcategory.com";
tokio::runtime::Runtime::new().unwrap().block_on(async {
let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?;
for family in [libc::AF_INET, libc::AF_INET6] {
let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?;
let system_addr = common::resolve_with_system(family, HOSTNAME)?;
let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
assert!(our_addrs.contains(&system_addr));