Use a p2p dbus connection during tests
This removes the requirement of having the DBus daemon running locally, which is particularly helpful e.g. in a container or for CI.
This commit is contained in:
parent
256267c213
commit
20072cf1ea
|
@ -27,11 +27,12 @@ bindgen = "0.60"
|
|||
[dev-dependencies]
|
||||
malcontent-nss = { path = ".", features = ["integration_test"] }
|
||||
env_logger = "0.9"
|
||||
event-listener = "2.5"
|
||||
futures-util = "0.3"
|
||||
rusty-hook = "0.11"
|
||||
rusty-forkfork = "0.4"
|
||||
test-cdylib = "1.1"
|
||||
tokio = { version = "1", features = ["rt", "sync", "macros", "time"] }
|
||||
tokio = { version = "1", features = ["rt", "sync", "macros", "net", "time"] }
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
|
@ -42,8 +43,8 @@ log = "0.4"
|
|||
nix = { version = "0.24", features = ["socket", "user", "sched"] }
|
||||
serde = "1.0"
|
||||
tokio = { version = "1", features = ["rt"] }
|
||||
trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"] }
|
||||
trust-dns-proto = "0.21"
|
||||
trust-dns-resolver = { version = "0.22", features = ["dns-over-rustls"] }
|
||||
trust-dns-proto = "0.22"
|
||||
zbus = { version = "3.0", default-features = false, features = ["tokio"] }
|
||||
zvariant = "3.6"
|
||||
zbus_names = { version = "2.2", optional = true }
|
|
@ -13,7 +13,6 @@ use {
|
|||
std::sync::Arc,
|
||||
trust_dns_proto::rr::record_type::RecordType,
|
||||
trust_dns_proto::rr::{RData, Record},
|
||||
trust_dns_proto::xfer::dns_request::DnsRequestOptions,
|
||||
trust_dns_resolver::TokioAsyncResolver,
|
||||
trust_dns_resolver::{lookup::Lookup, lookup_ip::LookupIp},
|
||||
};
|
||||
|
@ -81,16 +80,8 @@ async unsafe fn resolve_hostname_through(
|
|||
|
||||
let lookup: std::result::Result<Lookup, _> = match args.family {
|
||||
libc::AF_UNSPEC => resolver.lookup_ip(name).await.map(LookupIp::into),
|
||||
libc::AF_INET => {
|
||||
resolver
|
||||
.lookup(name, RecordType::A, DnsRequestOptions::default())
|
||||
.await
|
||||
}
|
||||
libc::AF_INET6 => {
|
||||
resolver
|
||||
.lookup(name, RecordType::AAAA, DnsRequestOptions::default())
|
||||
.await
|
||||
}
|
||||
libc::AF_INET => resolver.lookup(name, RecordType::A).await,
|
||||
libc::AF_INET6 => resolver.lookup(name, RecordType::AAAA).await,
|
||||
_ => return nss_status::NSS_STATUS_NOTFOUND,
|
||||
};
|
||||
|
||||
|
|
|
@ -36,14 +36,30 @@ pub async fn restrictions_for(user: Uid) -> anyhow::Result<Vec<Restriction>, any
|
|||
|
||||
#[cfg(feature = "integration_test")]
|
||||
let proxy = {
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
// During integration testing, we want to connect to a private
|
||||
// bus name to avoid clashes with existing system services.
|
||||
let connection = zbus::Connection::session().await?;
|
||||
let dbus_name = std::env::var("TEST_DBUS_SERVICE_NAME")
|
||||
.expect("The test has not set the TEST_DBUS_SERVICE_NAME environment variable to the private bus name prior to attempting name resolution");
|
||||
let socket_path = std::env::var("TEST_DBUS_SOCKET")
|
||||
.expect("The test has not set the TEST_DBUS_SOCKET environment variable to the unix socket to connect to");
|
||||
|
||||
let socket = loop {
|
||||
match UnixStream::connect(&socket_path).await {
|
||||
Ok(stream) => break stream,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
|
||||
tokio::task::yield_now().await;
|
||||
continue;
|
||||
}
|
||||
Err(e) => anyhow::bail!(e),
|
||||
}
|
||||
};
|
||||
|
||||
let connection = zbus::ConnectionBuilder::unix_stream(socket)
|
||||
.p2p()
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
MalcontentDnsProxy::builder(&connection)
|
||||
.destination(zbus_names::UniqueName::try_from(dbus_name).unwrap())
|
||||
.unwrap()
|
||||
.build()
|
||||
.await
|
||||
.expect("Unable to build DBus proxy object")
|
||||
|
|
|
@ -6,52 +6,48 @@ include!(concat!(
|
|||
"/src/policy_checker/dbus.rs"
|
||||
));
|
||||
|
||||
use {std::collections::HashMap, zbus::dbus_interface};
|
||||
use {
|
||||
std::collections::HashMap,
|
||||
std::sync::atomic::{AtomicUsize, Ordering},
|
||||
zbus::dbus_interface,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MalcontentDBusMock {
|
||||
responses: HashMap<Uid, Vec<Restrictions>>,
|
||||
invocations_left: usize,
|
||||
invocations_left: AtomicUsize,
|
||||
}
|
||||
|
||||
#[dbus_interface(name = "com.endlessm.ParentalControls.Dns")]
|
||||
impl MalcontentDBusMock {
|
||||
fn get_restrictions(&mut self, user_id: u32) -> Restrictions {
|
||||
let answers = self
|
||||
.responses
|
||||
.get_mut(&Uid::from_raw(user_id))
|
||||
.expect(&format!(
|
||||
"MockError: No mocked invocations available for user with id {}",
|
||||
user_id
|
||||
));
|
||||
let restrictions = answers.pop().expect(&format!(
|
||||
"MockError: DBus mock is saturated for user with id {}",
|
||||
user_id
|
||||
let uid = Uid::from_raw(user_id);
|
||||
let answers = self.responses.get_mut(&uid).expect(&format!(
|
||||
"MockError: No mocked invocations available for user with id {}",
|
||||
uid
|
||||
));
|
||||
|
||||
self.invocations_left -= 1;
|
||||
let restrictions = answers.pop().expect(&format!(
|
||||
"MockError: DBus mock is saturated for user with id {}",
|
||||
uid
|
||||
));
|
||||
|
||||
self.invocations_left.fetch_sub(1, Ordering::SeqCst);
|
||||
restrictions
|
||||
}
|
||||
}
|
||||
|
||||
impl MalcontentDBusMock {
|
||||
pub fn new(mut responses: HashMap<Uid, Vec<Restrictions>>) -> Self {
|
||||
let responses_size: usize = responses
|
||||
.values()
|
||||
.map(|v| {
|
||||
std::cmp::max(
|
||||
v.len(),
|
||||
1, /* 'No restrictions' still counts as one message */
|
||||
)
|
||||
})
|
||||
.sum();
|
||||
let responses_size: usize = responses.values().map(Vec::len).sum();
|
||||
|
||||
for r in responses.values_mut() {
|
||||
r.reverse(); // we pop responses from the back, so...
|
||||
}
|
||||
|
||||
let ret = Self {
|
||||
responses,
|
||||
invocations_left: responses_size,
|
||||
invocations_left: AtomicUsize::new(responses_size),
|
||||
};
|
||||
|
||||
ret
|
||||
|
@ -60,24 +56,57 @@ impl MalcontentDBusMock {
|
|||
|
||||
impl Drop for MalcontentDBusMock {
|
||||
fn drop(&mut self) {
|
||||
let invocations_left = self.invocations_left.load(Ordering::Acquire);
|
||||
assert_eq!(
|
||||
self.invocations_left, 0,
|
||||
invocations_left, 0,
|
||||
"MockError: During teardown, {} invocations are still left on the mock object",
|
||||
self.invocations_left
|
||||
invocations_left
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<zbus::Connection> {
|
||||
let mock = MalcontentDBusMock::new(responses);
|
||||
let connection = zbus::ConnectionBuilder::session()?
|
||||
.serve_at("/com/endlessm/ParentalControls/Dns", mock)?
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
std::env::set_var(
|
||||
"TEST_DBUS_SERVICE_NAME",
|
||||
connection.unique_name().unwrap().as_str(),
|
||||
);
|
||||
Ok(connection)
|
||||
pub struct DBusServerGuard {
|
||||
handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl Drop for DBusServerGuard {
|
||||
fn drop(&mut self) {
|
||||
self.handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<DBusServerGuard> {
|
||||
let guid = zbus::Guid::generate();
|
||||
let socket_path =
|
||||
std::path::PathBuf::from(std::env!("CARGO_TARGET_TMPDIR")).join(guid.as_str());
|
||||
|
||||
if socket_path.exists() {
|
||||
std::fs::remove_file(&socket_path)?;
|
||||
}
|
||||
|
||||
let socket = tokio::net::UnixListener::bind(&socket_path)?;
|
||||
std::env::set_var("TEST_DBUS_SOCKET", &socket_path);
|
||||
|
||||
let mock = MalcontentDBusMock::new(responses);
|
||||
let handle = tokio::spawn(async move {
|
||||
let (stream, _) = socket
|
||||
.accept()
|
||||
.await
|
||||
.expect("Server socket closed unexpectedly");
|
||||
std::fs::remove_file(socket_path).unwrap(); // Once we accepted, we can already remove the socket
|
||||
|
||||
let _ = zbus::ConnectionBuilder::unix_stream(stream)
|
||||
.server(&guid)
|
||||
.p2p()
|
||||
.name("com.endlessm.ParentalControls")
|
||||
.expect("Unable to serve given dbus name")
|
||||
.serve_at("/com/endlessm/ParentalControls/Dns", mock)
|
||||
.expect("Unable to server malcontent dbus mock object")
|
||||
.build()
|
||||
.await;
|
||||
|
||||
std::future::pending::<()>().await;
|
||||
});
|
||||
|
||||
Ok(DBusServerGuard { handle })
|
||||
}
|
||||
|
|
|
@ -122,12 +122,12 @@ fork_test! {
|
|||
const HOSTNAME: &str = "wikipedia.org";
|
||||
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
for family in [libc::AF_INET, libc::AF_INET6] {
|
||||
let _dbus = common::mock_dbus(HashMap::from([(
|
||||
getuid(),
|
||||
vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()],
|
||||
)])).await?;
|
||||
let _dbus = common::mock_dbus(HashMap::from([(
|
||||
getuid(),
|
||||
vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()],
|
||||
)])).await?;
|
||||
|
||||
for family in [libc::AF_INET, libc::AF_INET6] {
|
||||
let system_addr = common::resolve_with_system(family, HOSTNAME)?;
|
||||
let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
|
||||
assert!(our_addrs.contains(&system_addr));
|
||||
|
@ -150,7 +150,7 @@ fork_test! {
|
|||
|
||||
let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?;
|
||||
let our_addrs = common::resolve_with_module(libc::AF_INET, HOSTNAME)?;
|
||||
assert!(!our_addrs.contains(&system_addr));
|
||||
assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
|
||||
assert_eq!(our_addrs, [IpAddr::V4(Ipv4Addr::UNSPECIFIED)]);
|
||||
Ok(())
|
||||
})
|
||||
|
@ -170,7 +170,7 @@ fork_test! {
|
|||
|
||||
let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?;
|
||||
let our_addrs = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?;
|
||||
assert!(!our_addrs.contains(&system_addr));
|
||||
assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
|
||||
assert_eq!(our_addrs, [IpAddr::V6(Ipv6Addr::UNSPECIFIED)]);
|
||||
Ok(())
|
||||
})
|
||||
|
@ -183,8 +183,8 @@ fork_test! {
|
|||
const HOSTNAME: &str = "malware.testcategory.com";
|
||||
|
||||
tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?;
|
||||
for family in [libc::AF_INET, libc::AF_INET6] {
|
||||
let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?;
|
||||
let system_addr = common::resolve_with_system(family, HOSTNAME)?;
|
||||
let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
|
||||
assert!(our_addrs.contains(&system_addr));
|
||||
|
|
Loading…
Reference in New Issue