Use a p2p dbus connection during tests

This removes the requirement of having the DBus daemon running locally,
which is particularly helpful e.g. in a container or for CI.
This commit is contained in:
Matteo Settenvini 2022-09-12 00:38:01 +02:00
parent 256267c213
commit 20072cf1ea
Signed by: matteo
GPG Key ID: 8576CC1AD97D42DF
5 changed files with 101 additions and 64 deletions

View File

@ -27,11 +27,12 @@ bindgen = "0.60"
[dev-dependencies] [dev-dependencies]
malcontent-nss = { path = ".", features = ["integration_test"] } malcontent-nss = { path = ".", features = ["integration_test"] }
env_logger = "0.9" env_logger = "0.9"
event-listener = "2.5"
futures-util = "0.3" futures-util = "0.3"
rusty-hook = "0.11" rusty-hook = "0.11"
rusty-forkfork = "0.4" rusty-forkfork = "0.4"
test-cdylib = "1.1" test-cdylib = "1.1"
tokio = { version = "1", features = ["rt", "sync", "macros", "time"] } tokio = { version = "1", features = ["rt", "sync", "macros", "net", "time"] }
[dependencies] [dependencies]
anyhow = "1.0" anyhow = "1.0"
@ -42,8 +43,8 @@ log = "0.4"
nix = { version = "0.24", features = ["socket", "user", "sched"] } nix = { version = "0.24", features = ["socket", "user", "sched"] }
serde = "1.0" serde = "1.0"
tokio = { version = "1", features = ["rt"] } tokio = { version = "1", features = ["rt"] }
trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"] } trust-dns-resolver = { version = "0.22", features = ["dns-over-rustls"] }
trust-dns-proto = "0.21" trust-dns-proto = "0.22"
zbus = { version = "3.0", default-features = false, features = ["tokio"] } zbus = { version = "3.0", default-features = false, features = ["tokio"] }
zvariant = "3.6" zvariant = "3.6"
zbus_names = { version = "2.2", optional = true } zbus_names = { version = "2.2", optional = true }

View File

@ -13,7 +13,6 @@ use {
std::sync::Arc, std::sync::Arc,
trust_dns_proto::rr::record_type::RecordType, trust_dns_proto::rr::record_type::RecordType,
trust_dns_proto::rr::{RData, Record}, trust_dns_proto::rr::{RData, Record},
trust_dns_proto::xfer::dns_request::DnsRequestOptions,
trust_dns_resolver::TokioAsyncResolver, trust_dns_resolver::TokioAsyncResolver,
trust_dns_resolver::{lookup::Lookup, lookup_ip::LookupIp}, trust_dns_resolver::{lookup::Lookup, lookup_ip::LookupIp},
}; };
@ -81,16 +80,8 @@ async unsafe fn resolve_hostname_through(
let lookup: std::result::Result<Lookup, _> = match args.family { let lookup: std::result::Result<Lookup, _> = match args.family {
libc::AF_UNSPEC => resolver.lookup_ip(name).await.map(LookupIp::into), libc::AF_UNSPEC => resolver.lookup_ip(name).await.map(LookupIp::into),
libc::AF_INET => { libc::AF_INET => resolver.lookup(name, RecordType::A).await,
resolver libc::AF_INET6 => resolver.lookup(name, RecordType::AAAA).await,
.lookup(name, RecordType::A, DnsRequestOptions::default())
.await
}
libc::AF_INET6 => {
resolver
.lookup(name, RecordType::AAAA, DnsRequestOptions::default())
.await
}
_ => return nss_status::NSS_STATUS_NOTFOUND, _ => return nss_status::NSS_STATUS_NOTFOUND,
}; };

View File

@ -36,14 +36,30 @@ pub async fn restrictions_for(user: Uid) -> anyhow::Result<Vec<Restriction>, any
#[cfg(feature = "integration_test")] #[cfg(feature = "integration_test")]
let proxy = { let proxy = {
use tokio::net::UnixStream;
// During integration testing, we want to connect to a private // During integration testing, we want to connect to a private
// bus name to avoid clashes with existing system services. // bus name to avoid clashes with existing system services.
let connection = zbus::Connection::session().await?; let socket_path = std::env::var("TEST_DBUS_SOCKET")
let dbus_name = std::env::var("TEST_DBUS_SERVICE_NAME") .expect("The test has not set the TEST_DBUS_SOCKET environment variable to the unix socket to connect to");
.expect("The test has not set the TEST_DBUS_SERVICE_NAME environment variable to the private bus name prior to attempting name resolution");
let socket = loop {
match UnixStream::connect(&socket_path).await {
Ok(stream) => break stream,
Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
tokio::task::yield_now().await;
continue;
}
Err(e) => anyhow::bail!(e),
}
};
let connection = zbus::ConnectionBuilder::unix_stream(socket)
.p2p()
.build()
.await?;
MalcontentDnsProxy::builder(&connection) MalcontentDnsProxy::builder(&connection)
.destination(zbus_names::UniqueName::try_from(dbus_name).unwrap())
.unwrap()
.build() .build()
.await .await
.expect("Unable to build DBus proxy object") .expect("Unable to build DBus proxy object")

View File

@ -6,52 +6,48 @@ include!(concat!(
"/src/policy_checker/dbus.rs" "/src/policy_checker/dbus.rs"
)); ));
use {std::collections::HashMap, zbus::dbus_interface}; use {
std::collections::HashMap,
std::sync::atomic::{AtomicUsize, Ordering},
zbus::dbus_interface,
};
#[derive(Debug)] #[derive(Debug)]
pub struct MalcontentDBusMock { pub struct MalcontentDBusMock {
responses: HashMap<Uid, Vec<Restrictions>>, responses: HashMap<Uid, Vec<Restrictions>>,
invocations_left: usize, invocations_left: AtomicUsize,
} }
#[dbus_interface(name = "com.endlessm.ParentalControls.Dns")] #[dbus_interface(name = "com.endlessm.ParentalControls.Dns")]
impl MalcontentDBusMock { impl MalcontentDBusMock {
fn get_restrictions(&mut self, user_id: u32) -> Restrictions { fn get_restrictions(&mut self, user_id: u32) -> Restrictions {
let answers = self let uid = Uid::from_raw(user_id);
.responses let answers = self.responses.get_mut(&uid).expect(&format!(
.get_mut(&Uid::from_raw(user_id))
.expect(&format!(
"MockError: No mocked invocations available for user with id {}", "MockError: No mocked invocations available for user with id {}",
user_id uid
));
let restrictions = answers.pop().expect(&format!(
"MockError: DBus mock is saturated for user with id {}",
user_id
)); ));
self.invocations_left -= 1; let restrictions = answers.pop().expect(&format!(
"MockError: DBus mock is saturated for user with id {}",
uid
));
self.invocations_left.fetch_sub(1, Ordering::SeqCst);
restrictions restrictions
} }
} }
impl MalcontentDBusMock { impl MalcontentDBusMock {
pub fn new(mut responses: HashMap<Uid, Vec<Restrictions>>) -> Self { pub fn new(mut responses: HashMap<Uid, Vec<Restrictions>>) -> Self {
let responses_size: usize = responses let responses_size: usize = responses.values().map(Vec::len).sum();
.values()
.map(|v| {
std::cmp::max(
v.len(),
1, /* 'No restrictions' still counts as one message */
)
})
.sum();
for r in responses.values_mut() { for r in responses.values_mut() {
r.reverse(); // we pop responses from the back, so... r.reverse(); // we pop responses from the back, so...
} }
let ret = Self { let ret = Self {
responses, responses,
invocations_left: responses_size, invocations_left: AtomicUsize::new(responses_size),
}; };
ret ret
@ -60,24 +56,57 @@ impl MalcontentDBusMock {
impl Drop for MalcontentDBusMock { impl Drop for MalcontentDBusMock {
fn drop(&mut self) { fn drop(&mut self) {
let invocations_left = self.invocations_left.load(Ordering::Acquire);
assert_eq!( assert_eq!(
self.invocations_left, 0, invocations_left, 0,
"MockError: During teardown, {} invocations are still left on the mock object", "MockError: During teardown, {} invocations are still left on the mock object",
self.invocations_left invocations_left
); );
} }
} }
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<zbus::Connection> { pub struct DBusServerGuard {
handle: tokio::task::JoinHandle<()>,
}
impl Drop for DBusServerGuard {
fn drop(&mut self) {
self.handle.abort();
}
}
pub async fn mock_dbus(responses: HashMap<Uid, Vec<Restrictions>>) -> Result<DBusServerGuard> {
let guid = zbus::Guid::generate();
let socket_path =
std::path::PathBuf::from(std::env!("CARGO_TARGET_TMPDIR")).join(guid.as_str());
if socket_path.exists() {
std::fs::remove_file(&socket_path)?;
}
let socket = tokio::net::UnixListener::bind(&socket_path)?;
std::env::set_var("TEST_DBUS_SOCKET", &socket_path);
let mock = MalcontentDBusMock::new(responses); let mock = MalcontentDBusMock::new(responses);
let connection = zbus::ConnectionBuilder::session()? let handle = tokio::spawn(async move {
.serve_at("/com/endlessm/ParentalControls/Dns", mock)? let (stream, _) = socket
.build() .accept()
.await?; .await
.expect("Server socket closed unexpectedly");
std::fs::remove_file(socket_path).unwrap(); // Once we accepted, we can already remove the socket
std::env::set_var( let _ = zbus::ConnectionBuilder::unix_stream(stream)
"TEST_DBUS_SERVICE_NAME", .server(&guid)
connection.unique_name().unwrap().as_str(), .p2p()
); .name("com.endlessm.ParentalControls")
Ok(connection) .expect("Unable to serve given dbus name")
.serve_at("/com/endlessm/ParentalControls/Dns", mock)
.expect("Unable to server malcontent dbus mock object")
.build()
.await;
std::future::pending::<()>().await;
});
Ok(DBusServerGuard { handle })
} }

View File

@ -122,12 +122,12 @@ fork_test! {
const HOSTNAME: &str = "wikipedia.org"; const HOSTNAME: &str = "wikipedia.org";
tokio::runtime::Runtime::new().unwrap().block_on(async { tokio::runtime::Runtime::new().unwrap().block_on(async {
for family in [libc::AF_INET, libc::AF_INET6] {
let _dbus = common::mock_dbus(HashMap::from([( let _dbus = common::mock_dbus(HashMap::from([(
getuid(), getuid(),
vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()],
)])).await?; )])).await?;
for family in [libc::AF_INET, libc::AF_INET6] {
let system_addr = common::resolve_with_system(family, HOSTNAME)?; let system_addr = common::resolve_with_system(family, HOSTNAME)?;
let our_addrs = common::resolve_with_module(family, HOSTNAME)?; let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
assert!(our_addrs.contains(&system_addr)); assert!(our_addrs.contains(&system_addr));
@ -150,7 +150,7 @@ fork_test! {
let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?; let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?;
let our_addrs = common::resolve_with_module(libc::AF_INET, HOSTNAME)?; let our_addrs = common::resolve_with_module(libc::AF_INET, HOSTNAME)?;
assert!(!our_addrs.contains(&system_addr)); assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
assert_eq!(our_addrs, [IpAddr::V4(Ipv4Addr::UNSPECIFIED)]); assert_eq!(our_addrs, [IpAddr::V4(Ipv4Addr::UNSPECIFIED)]);
Ok(()) Ok(())
}) })
@ -170,7 +170,7 @@ fork_test! {
let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?; let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?;
let our_addrs = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?; let our_addrs = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?;
assert!(!our_addrs.contains(&system_addr)); assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr);
assert_eq!(our_addrs, [IpAddr::V6(Ipv6Addr::UNSPECIFIED)]); assert_eq!(our_addrs, [IpAddr::V6(Ipv6Addr::UNSPECIFIED)]);
Ok(()) Ok(())
}) })
@ -183,8 +183,8 @@ fork_test! {
const HOSTNAME: &str = "malware.testcategory.com"; const HOSTNAME: &str = "malware.testcategory.com";
tokio::runtime::Runtime::new().unwrap().block_on(async { tokio::runtime::Runtime::new().unwrap().block_on(async {
for family in [libc::AF_INET, libc::AF_INET6] {
let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?; let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?;
for family in [libc::AF_INET, libc::AF_INET6] {
let system_addr = common::resolve_with_system(family, HOSTNAME)?; let system_addr = common::resolve_with_system(family, HOSTNAME)?;
let our_addrs = common::resolve_with_module(family, HOSTNAME)?; let our_addrs = common::resolve_with_module(family, HOSTNAME)?;
assert!(our_addrs.contains(&system_addr)); assert!(our_addrs.contains(&system_addr));