diff --git a/Cargo.toml b/Cargo.toml index b5ad907..df07ab8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,10 +18,14 @@ panic = "unwind" # We rely on this crate-type = ["cdylib"] name = "nss_malcontent" +[features] +integration_test = ["dep:zbus_names"] + [build-dependencies] bindgen = "0.60" [dev-dependencies] +malcontent-nss = { path = ".", features = ["integration_test"] } event-listener = "2.5" futures-util = "0.3" rusty-hook = "0.11" @@ -41,4 +45,5 @@ tokio = { version = "1", features = ["rt"] } trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"] } trust-dns-proto = "0.21" zbus = { version = "3.0", default-features = false, features = ["tokio"] } -zvariant = "3.6" \ No newline at end of file +zvariant = "3.6" +zbus_names = { version = "2.2", optional = true } \ No newline at end of file diff --git a/src/policy_checker/mod.rs b/src/policy_checker/mod.rs index e33b94a..e96af2c 100644 --- a/src/policy_checker/mod.rs +++ b/src/policy_checker/mod.rs @@ -11,6 +11,7 @@ use { std::collections::HashMap, std::net::{SocketAddr, TcpStream}, std::sync::{Arc, RwLock}, + std::time::Duration, trust_dns_proto::rr::domain::Name as DomainName, trust_dns_resolver::config as dns_config, trust_dns_resolver::TokioAsyncResolver, @@ -35,14 +36,36 @@ impl PolicyChecker { } } - async fn restrictions<'a>(&'a self, user: Uid) -> Result { + async fn restrictions(&self, user: Uid) -> Result { if user.is_root() { return Ok(vec![]); }; let connection = zbus::Connection::session().await?; + + #[cfg(not(feature = "integration_test"))] let proxy = MalcontentDnsProxy::new(&connection).await?; - Ok(proxy.get_restrictions(user.as_raw()).await?) + + #[cfg(feature = "integration_test")] + let proxy = { + let dbus_name = std::env::var("TEST_DBUS_SERVICE_NAME").map_err(|_| { + anyhow::anyhow!("The test hasn't set the TEST_DBUS_SERVICE_NAME environment var") + })?; + MalcontentDnsProxy::builder(&connection) + .destination(zbus_names::UniqueName::try_from(dbus_name).unwrap()) + .unwrap() + .build() + .await + .expect("Unable to build DBus proxy object") + }; + + let restrictions = proxy.get_restrictions(user.as_raw()).await; + log::trace!( + "malcontent-nss: user {} restrictions are {:?}", + user, + &restrictions + ); + Ok(restrictions?) } pub async fn resolver(&self, user: Option) -> Result>> { @@ -89,12 +112,17 @@ fn resolver_config_for(restrictions: Vec) -> dns_config::ResolverCo restrictions .into_iter() .fold(NsConfig::new(), |mut config, restr| { - let new_config = - if TcpStream::connect(SocketAddr::new(restr.ip, DNS_TLS_PORT)).is_ok() { - NsConfig::from_ips_tls(&[restr.ip], DNS_TLS_PORT, restr.hostname, true) - } else { - NsConfig::from_ips_clear(&[restr.ip], DNS_UDP_PORT, true) - }; + let supports_tls = TcpStream::connect_timeout( + &SocketAddr::new(restr.ip, DNS_TLS_PORT), + Duration::from_secs(1), + ) + .is_ok(); + + let new_config = if supports_tls { + NsConfig::from_ips_tls(&[restr.ip], DNS_TLS_PORT, restr.hostname, true) + } else { + NsConfig::from_ips_clear(&[restr.ip], DNS_UDP_PORT, true) + }; config.merge(new_config); config diff --git a/tests/common/dbus.rs b/tests/common/dbus.rs index 9a30c20..1e88816 100644 --- a/tests/common/dbus.rs +++ b/tests/common/dbus.rs @@ -73,5 +73,21 @@ impl Drop for MalcontentDBusMock { "MockError: During teardown, {} invocations are still left on the mock object", self.invocations_left ); + + self.finished.notify(1); } } + +pub async fn mock_dbus(responses: HashMap>) -> Result { + let mock = MalcontentDBusMock::new(responses); + let connection = zbus::ConnectionBuilder::session()? + .serve_at("/com/endlessm/ParentalControls/Dns", mock)? + .build() + .await?; + + std::env::set_var( + "TEST_DBUS_SERVICE_NAME", + connection.unique_name().unwrap().as_str(), + ); + Ok(connection) +} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 856ad9e..ff9f2a8 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -11,12 +11,9 @@ include!(concat!(env!("OUT_DIR"), "/bindings.rs")); mod dbus; use { - self::dbus::MalcontentDBusMock, anyhow::{anyhow, bail, ensure, Result}, libc::{freeaddrinfo, gai_strerror, getaddrinfo}, nix::sys::socket::{SockaddrLike as _, SockaddrStorage}, - nix::unistd::Uid, - std::collections::HashMap, std::env, std::ffi::{CStr, CString}, std::net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -24,11 +21,9 @@ use { std::process::Command, std::str::FromStr, std::sync::Once, - tokio::task, - tokio::task::JoinHandle, }; -pub use self::dbus::{Restriction, Restrictions}; +pub use self::dbus::{mock_dbus, Restriction, Restrictions}; // Adapted from rusty_forkfork (which inherits it from rusty_fork) // to allow a custom pre-fork function @@ -181,19 +176,3 @@ fn convert_addrinfo(sa: &SockaddrStorage) -> Result { bail!("addrinfo is not either an IPv4 or IPv6 address") } } - -pub fn mock_dbus(responses: HashMap>) -> JoinHandle> { - async fn dbus_loop(responses: HashMap>) -> Result<()> { - let mock = MalcontentDBusMock::new(responses); - let waiter = mock.waiter(); - let _connection = zbus::ConnectionBuilder::session()? - .serve_at("/com/endlessm/ParentalControls/Dns", mock)? - .build() - .await?; - - waiter.wait(); - Ok(()) - } - - task::spawn(async { dbus_loop(responses).await }) -} diff --git a/tests/integration_test.rs b/tests/integration_test.rs index 45a303f..78c0ef9 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -11,8 +11,6 @@ use { once_cell::sync::Lazy, std::collections::HashMap, std::net::{IpAddr, Ipv4Addr, Ipv6Addr}, - std::time::Duration, - tokio::time::timeout, }; static CLOUDFLARE_PARENTALCONTROL_ADDRS: Lazy = Lazy::new(|| { @@ -48,10 +46,10 @@ fork_test! { fn application_dns_is_nxdomain() -> Result<()> { common::setup()?; tokio::runtime::Runtime::new().unwrap().block_on(async { - let dbus = common::mock_dbus(HashMap::from([( + let _dbus = common::mock_dbus(HashMap::from([( getuid(), vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])); + )])).await?; let hostname = std::ffi::CString::new("use-application-dns.net").unwrap(); unsafe { @@ -73,7 +71,7 @@ fork_test! { freeaddrinfo(addr); }; - timeout(Duration::from_secs(1), dbus).await?? + Ok(()) }) } @@ -84,10 +82,10 @@ fork_test! { const HOSTNAME: &str = "gnome.org"; tokio::runtime::Runtime::new().unwrap().block_on(async { - let dbus = common::mock_dbus(HashMap::from([( + let _dbus = common::mock_dbus(HashMap::from([( getuid(), vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])); + )])).await?; unsafe { let mut addr = std::ptr::null_mut(); @@ -108,11 +106,10 @@ fork_test! { freeaddrinfo(addr); }; - timeout(Duration::from_secs(1), dbus).await?? + Ok(()) }) } - #[test] fn wikipedia_is_unrestricted() -> Result<()> { common::setup()?; @@ -121,16 +118,14 @@ fork_test! { tokio::runtime::Runtime::new().unwrap().block_on(async { for family in [libc::AF_INET, libc::AF_INET6] { - let dbus = common::mock_dbus(HashMap::from([( + let _dbus = common::mock_dbus(HashMap::from([( getuid(), vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])); + )])).await?; let system_addr = common::resolve_with_system(family, HOSTNAME)?; let our_addr = common::resolve_with_module(family, HOSTNAME)?; assert_eq!(system_addr, our_addr); - - timeout(Duration::from_secs(1), dbus).await???; } Ok(()) }) @@ -143,17 +138,17 @@ fork_test! { const HOSTNAME: &str = "nudity.testcategory.com"; tokio::runtime::Runtime::new().unwrap().block_on(async { - let dbus = common::mock_dbus(HashMap::from([( + let _dbus = common::mock_dbus(HashMap::from([( getuid(), vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])); + )])).await?; let system_addr = common::resolve_with_system(libc::AF_INET, HOSTNAME)?; let our_addr = common::resolve_with_module(libc::AF_INET, HOSTNAME)?; assert_ne!(system_addr, our_addr); assert_eq!(our_addr, IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - timeout(Duration::from_secs(1), dbus).await?? + Ok(()) }) } @@ -164,22 +159,20 @@ fork_test! { const HOSTNAME: &str = "nudity.testcategory.com"; tokio::runtime::Runtime::new().unwrap().block_on(async { - let dbus = common::mock_dbus(HashMap::from([( + let _dbus = common::mock_dbus(HashMap::from([( getuid(), vec![CLOUDFLARE_PARENTALCONTROL_ADDRS.clone()], - )])); + )])).await; let system_addr = common::resolve_with_system(libc::AF_INET6, HOSTNAME)?; let our_addr = common::resolve_with_module(libc::AF_INET6, HOSTNAME)?; assert_ne!(system_addr, our_addr); assert_eq!(our_addr, IpAddr::V6(Ipv6Addr::UNSPECIFIED)); - - timeout(Duration::from_secs(1), dbus).await?? + Ok(()) }) } #[test] - #[ignore] fn privileged_user_bypasses_restrictions() -> Result<()> { common::setup()?; @@ -187,11 +180,10 @@ fork_test! { tokio::runtime::Runtime::new().unwrap().block_on(async { for family in [libc::AF_INET, libc::AF_INET6] { - let dbus = common::mock_dbus(HashMap::from([(getuid(), vec![ /* no restriction */])])); + let _dbus = common::mock_dbus(HashMap::from([(getuid(), vec![ /* no restriction */])])).await?; let system_addr = common::resolve_with_system(family, HOSTNAME)?; let our_addr = common::resolve_with_module(family, HOSTNAME)?; assert_eq!(system_addr, our_addr); - timeout(Duration::from_secs(1), dbus).await??? } Ok(()) })