diff --git a/Cargo.toml b/Cargo.toml index 27abd7c..85584b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,11 +27,12 @@ bindgen = "0.60" [dev-dependencies] malcontent-nss = { path = ".", features = ["integration_test"] } env_logger = "0.9" +event-listener = "2.5" futures-util = "0.3" rusty-hook = "0.11" rusty-forkfork = "0.4" test-cdylib = "1.1" -tokio = { version = "1", features = ["rt", "sync", "macros", "time"] } +tokio = { version = "1", features = ["rt", "sync", "macros", "net", "time"] } [dependencies] anyhow = "1.0" @@ -42,8 +43,8 @@ log = "0.4" nix = { version = "0.24", features = ["socket", "user", "sched"] } serde = "1.0" tokio = { version = "1", features = ["rt"] } -trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"] } -trust-dns-proto = "0.21" +trust-dns-resolver = { version = "0.22", features = ["dns-over-rustls"] } +trust-dns-proto = "0.22" zbus = { version = "3.0", default-features = false, features = ["tokio"] } zvariant = "3.6" zbus_names = { version = "2.2", optional = true } \ No newline at end of file diff --git a/src/gethostbyname.rs b/src/gethostbyname.rs index 59a85b4..311ab5d 100644 --- a/src/gethostbyname.rs +++ b/src/gethostbyname.rs @@ -13,7 +13,6 @@ use { std::sync::Arc, trust_dns_proto::rr::record_type::RecordType, trust_dns_proto::rr::{RData, Record}, - trust_dns_proto::xfer::dns_request::DnsRequestOptions, trust_dns_resolver::TokioAsyncResolver, trust_dns_resolver::{lookup::Lookup, lookup_ip::LookupIp}, }; @@ -81,16 +80,8 @@ async unsafe fn resolve_hostname_through( let lookup: std::result::Result = match args.family { libc::AF_UNSPEC => resolver.lookup_ip(name).await.map(LookupIp::into), - libc::AF_INET => { - resolver - .lookup(name, RecordType::A, DnsRequestOptions::default()) - .await - } - libc::AF_INET6 => { - resolver - .lookup(name, RecordType::AAAA, DnsRequestOptions::default()) - .await - } + libc::AF_INET => resolver.lookup(name, RecordType::A).await, + libc::AF_INET6 => resolver.lookup(name, RecordType::AAAA).await, _ => return nss_status::NSS_STATUS_NOTFOUND, }; diff --git a/src/policy_checker/dbus.rs b/src/policy_checker/dbus.rs index c379894..ad9ab3a 100644 --- a/src/policy_checker/dbus.rs +++ b/src/policy_checker/dbus.rs @@ -36,14 +36,30 @@ pub async fn restrictions_for(user: Uid) -> anyhow::Result, any #[cfg(feature = "integration_test")] let proxy = { + use tokio::net::UnixStream; + // During integration testing, we want to connect to a private // bus name to avoid clashes with existing system services. - let connection = zbus::Connection::session().await?; - let dbus_name = std::env::var("TEST_DBUS_SERVICE_NAME") - .expect("The test has not set the TEST_DBUS_SERVICE_NAME environment variable to the private bus name prior to attempting name resolution"); + let socket_path = std::env::var("TEST_DBUS_SOCKET") + .expect("The test has not set the TEST_DBUS_SOCKET environment variable to the unix socket to connect to"); + + let socket = loop { + match UnixStream::connect(&socket_path).await { + Ok(stream) => break stream, + Err(e) if e.kind() == std::io::ErrorKind::ConnectionRefused => { + tokio::task::yield_now().await; + continue; + } + Err(e) => anyhow::bail!(e), + } + }; + + let connection = zbus::ConnectionBuilder::unix_stream(socket) + .p2p() + .build() + .await?; + MalcontentDnsProxy::builder(&connection) - .destination(zbus_names::UniqueName::try_from(dbus_name).unwrap()) - .unwrap() .build() .await .expect("Unable to build DBus proxy object") diff --git a/tests/common/dbus.rs b/tests/common/dbus.rs index fee255d..d332133 100644 --- a/tests/common/dbus.rs +++ b/tests/common/dbus.rs @@ -6,52 +6,48 @@ include!(concat!( "/src/policy_checker/dbus.rs" )); -use {std::collections::HashMap, zbus::dbus_interface}; +use { + std::collections::HashMap, + std::sync::atomic::{AtomicUsize, Ordering}, + zbus::dbus_interface, +}; #[derive(Debug)] pub struct MalcontentDBusMock { responses: HashMap>, - invocations_left: usize, + invocations_left: AtomicUsize, } #[dbus_interface(name = "com.endlessm.ParentalControls.Dns")] impl MalcontentDBusMock { fn get_restrictions(&mut self, user_id: u32) -> Restrictions { - let answers = self - .responses - .get_mut(&Uid::from_raw(user_id)) - .expect(&format!( - "MockError: No mocked invocations available for user with id {}", - user_id - )); - let restrictions = answers.pop().expect(&format!( - "MockError: DBus mock is saturated for user with id {}", - user_id + let uid = Uid::from_raw(user_id); + let answers = self.responses.get_mut(&uid).expect(&format!( + "MockError: No mocked invocations available for user with id {}", + uid )); - self.invocations_left -= 1; + let restrictions = answers.pop().expect(&format!( + "MockError: DBus mock is saturated for user with id {}", + uid + )); + + self.invocations_left.fetch_sub(1, Ordering::SeqCst); restrictions } } impl MalcontentDBusMock { pub fn new(mut responses: HashMap>) -> Self { - let responses_size: usize = responses - .values() - .map(|v| { - std::cmp::max( - v.len(), - 1, /* 'No restrictions' still counts as one message */ - ) - }) - .sum(); + let responses_size: usize = responses.values().map(Vec::len).sum(); + for r in responses.values_mut() { r.reverse(); // we pop responses from the back, so... } let ret = Self { responses, - invocations_left: responses_size, + invocations_left: AtomicUsize::new(responses_size), }; ret @@ -60,24 +56,57 @@ impl MalcontentDBusMock { impl Drop for MalcontentDBusMock { fn drop(&mut self) { + let invocations_left = self.invocations_left.load(Ordering::Acquire); assert_eq!( - self.invocations_left, 0, + invocations_left, 0, "MockError: During teardown, {} invocations are still left on the mock object", - self.invocations_left + invocations_left ); } } -pub async fn mock_dbus(responses: HashMap>) -> Result { - let mock = MalcontentDBusMock::new(responses); - let connection = zbus::ConnectionBuilder::session()? - .serve_at("/com/endlessm/ParentalControls/Dns", mock)? - .build() - .await?; - - std::env::set_var( - "TEST_DBUS_SERVICE_NAME", - connection.unique_name().unwrap().as_str(), - ); - Ok(connection) +pub struct DBusServerGuard { + handle: tokio::task::JoinHandle<()>, +} + +impl Drop for DBusServerGuard { + fn drop(&mut self) { + self.handle.abort(); + } +} + +pub async fn mock_dbus(responses: HashMap>) -> Result { + let guid = zbus::Guid::generate(); + let socket_path = + std::path::PathBuf::from(std::env!("CARGO_TARGET_TMPDIR")).join(guid.as_str()); + + if socket_path.exists() { + std::fs::remove_file(&socket_path)?; + } + + let socket = tokio::net::UnixListener::bind(&socket_path)?; + std::env::set_var("TEST_DBUS_SOCKET", &socket_path); + + let mock = MalcontentDBusMock::new(responses); + let handle = tokio::spawn(async move { + let (stream, _) = socket + .accept() + .await + .expect("Server socket closed unexpectedly"); + std::fs::remove_file(socket_path).unwrap(); // Once we accepted, we can already remove the socket + + let _ = zbus::ConnectionBuilder::unix_stream(stream) + .server(&guid) + .p2p() + .name("com.endlessm.ParentalControls") + .expect("Unable to serve given dbus name") + .serve_at("/com/endlessm/ParentalControls/Dns", mock) + .expect("Unable to server malcontent dbus mock object") + .build() + .await; + + std::future::pending::<()>().await; + }); + + Ok(DBusServerGuard { handle }) } diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 65cb585..7ae4889 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -122,12 +122,12 @@ fork_test! { const HOSTNAME: &str = "wikipedia.org"; tokio::runtime::Runtime::new().unwrap().block_on(async { - for family in [libc::AF_INET, libc::AF_INET6] { - let _dbus = common::mock_dbus(HashMap::from([( - getuid(), - vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])).await?; + let _dbus = common::mock_dbus(HashMap::from([( + getuid(), + vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], + )])).await?; + for family in [libc::AF_INET, libc::AF_INET6] { let system_addr = common::resolve_with_system(family, HOSTNAME)?; let our_addrs = common::resolve_with_module(family, HOSTNAME)?; assert!(our_addrs.contains(&system_addr)); @@ -150,7 +150,7 @@ fork_test! { let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?; let our_addrs = common::resolve_with_module(libc::AF_INET, HOSTNAME)?; - assert!(!our_addrs.contains(&system_addr)); + assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr); assert_eq!(our_addrs, [IpAddr::V4(Ipv4Addr::UNSPECIFIED)]); Ok(()) }) @@ -170,7 +170,7 @@ fork_test! { let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?; let our_addrs = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?; - assert!(!our_addrs.contains(&system_addr)); + assert!(!our_addrs.contains(&system_addr), "Resolver answered with {:?}, should not contain {}", our_addrs, system_addr); assert_eq!(our_addrs, [IpAddr::V6(Ipv6Addr::UNSPECIFIED)]); Ok(()) }) @@ -183,8 +183,8 @@ fork_test! { const HOSTNAME: &str = "malware.testcategory.com"; tokio::runtime::Runtime::new().unwrap().block_on(async { + let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?; for family in [libc::AF_INET, libc::AF_INET6] { - let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![NO_RESTRICTIONS.clone()])])).await?; let system_addr = common::resolve_with_system(family, HOSTNAME)?; let our_addrs = common::resolve_with_module(family, HOSTNAME)?; assert!(our_addrs.contains(&system_addr));