malcontent/nss/resolver.cc

305 lines
9.8 KiB
C++
Raw Normal View History

2024-01-05 18:10:45 +01:00
// SPDX-FileCopyrightText: Matteo Settenvini <matteo.settenvini@montecristosoftware.eu>
// SPDX-License-Identifier: GPL-3.0-or-later
#include "resolver.hh"
#include "cares_init.hh"
#include "helpers.hh"
#include "logger.hh"
#include "wrapper.hh"
#include <ares.h>
#include <arpa/nameser.h>
#include <nss.h>
#include <sys/epoll.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>
#include <algorithm>
#include <cerrno>
#include <cstddef>
#include <cstring>
#include <format>
#include <memory>
#include <new>
#include <regex>
#include <stdexcept>
#include <system_error>
using namespace std::literals;
namespace /* anonymous */ {
// Host name that Resolver::resolve() answers itself, without issuing a DNS
// query, by reporting "host not found" (see the check at the top of resolve()).
constexpr const auto CANARY_HOSTNAME = "use-application-dns.net"sv;
// Creates a c-ares channel and configures it with the given DNS servers.
auto init_channel(const std::vector<std::string> dns) -> ares_channel_t *;
// Parses every entry of `dns` and installs the result as the server list of `channel`.
auto setup_servers(ares_channel_t *channel, const std::vector<std::string> dns) -> void;
// Parses one "<ip>[:port][#hostname]" entry into `into`; throws std::invalid_argument on failure.
auto parse_address(const std::string& addr, ares_addr_port_node& into) -> void;
// Deleter for a singly linked list of `ares_addr_port_node` whose nodes were
// allocated with `new` (setup_servers() creates them via std::make_unique).
struct CAresAddrListDeletor {
    // Frees `list` and every node reachable through `next`.
    // Walks the list iteratively so that the stack depth does not grow with
    // the number of configured servers. A null `list` is a no-op.
    auto operator()(ares_addr_port_node *list) const -> void {
        while (list != nullptr) {
            auto *const following = list->next;
            delete list;
            list = following;
        }
    }
};
// Owning handle to the head of such a list.
using CAresAddrList = std::unique_ptr<ares_addr_port_node, CAresAddrListDeletor>;
// Context handed to c-ares as the opaque callback argument of a query.
// Resolver::resolve() builds one on its stack and Resolver::resolve_cb()
// casts the `void *` back to this type.
struct CallbackArgs {
// resolver whose resolved() member handles the reply
malcontent::Resolver *resolver;
// request being served; forwarded untouched to resolved()
malcontent::ResolverArgs *args;
// value resolve() returns; stays TRYAGAIN until the callback overwrites it
nss_status return_status = NSS_STATUS_TRYAGAIN;
};
} // ~ anonymous namespace
// Builds a resolver that sends its queries to the servers listed in `dns`
// (each entry in the "<ip>[:port][#hostname]" form accepted by parse_address()).
// Obtains the CAresLibrary instance and creates the c-ares channel through
// init_channel(); whatever init_channel() throws propagates to the caller.
// NOTE(review): the library handle is presumably there to guarantee c-ares is
// initialised before the channel is created; that relies on the member
// declaration order in resolver.hh, which is not visible here.
malcontent::Resolver::Resolver(std::vector<std::string> dns)
: _ensure_cares_initialized(CAresLibrary::instance())
, _channel(init_channel(std::move(dns)))
{
}
// Synchronously resolves `args.name` for the address family `args.family`.
//
// The canary host name is answered locally with "host not found". Any other
// name is sent as an A (AF_INET) or AAAA (AF_INET6) query on the c-ares
// channel, and this function then drives the channel with epoll until c-ares
// has no outstanding query left, i.e. until resolve_cb() has run.
//
// Returns the status computed by resolved() through the callback.
// Throws std::invalid_argument for any other address family, and
// std::system_error when an epoll system call fails.
auto malcontent::Resolver::resolve(ResolverArgs& args) -> nss_status {
    if (args.name == CANARY_HOSTNAME) {
        // disable DoH if user did not explicitly turn it on
        set_if_valid(args.errnop, 0);
        set_if_valid(args.h_errnop, HErrno::HostNotFound);
        return NSS_STATUS_SUCCESS;
    }

    ns_type type;
    switch (args.family) {
    case AF_INET: {
        type = ns_t_a;
        break;
    }
    case AF_INET6: {
        // AAAA is the record type ares_parse_aaaa_reply() understands;
        // A6 is a different, deprecated record type.
        type = ns_t_aaaa;
        break;
    }
    default: {
        throw std::invalid_argument("only AF_INET and AF_INET6 are supported families");
    }
    }

    // Owns the epoll instance for the duration of the call. On destruction it
    // also cancels whatever is still queued on the channel, so that no query
    // can outlive the stack-allocated callback context below (ares_cancel()
    // is a no-op once the query has completed).
    // NOTE(review): this assumes resolve() is the only source of queries on
    // this channel, which is all this file shows.
    struct Session {
        ares_channel_t *channel;
        int epollfd;

        Session(ares_channel_t *c, int fd) : channel(c), epollfd(fd) {}
        Session(const Session&) = delete;
        auto operator=(const Session&) -> Session& = delete;
        ~Session() {
            ares_cancel(channel);
            ::close(epollfd);
        }
    };

    auto *const channel = _channel.get();

    const int epollfd = epoll_create1(EPOLL_CLOEXEC);
    if (epollfd == -1) {
        throw std::system_error{ errno, std::system_category(), "epoll_create1" };
    }

    // `closure` is declared before `session` so that it is still alive when
    // the session destructor cancels a pending query (which runs the callback).
    CallbackArgs closure { .resolver = this, .args = &args };
    const Session session { channel, epollfd };

    ares_query(channel, args.name, ns_c_in, type, &Resolver::resolve_cb, &closure);

    // Sockets currently registered with epoll.
    ares_socket_t watched[ARES_GETSOCK_MAXNUM];
    int n_watched = 0;

    // Brings the epoll interest list in line with the sockets c-ares wants to
    // be polled right now. This has to be repeated because c-ares may open
    // further sockets while the query is running (retries, other servers,
    // TCP fallback). ares_getsock() is used rather than ares_fds() because the
    // calling process might hold descriptors that do not fit into an fd_set.
    const auto sync_watched_sockets = [&] {
        ares_socket_t socks[ARES_GETSOCK_MAXNUM];
        const int bits = ares_getsock(channel, socks, ARES_GETSOCK_MAXNUM);

        ares_socket_t wanted[ARES_GETSOCK_MAXNUM];
        int n_wanted = 0;
        for (int i = 0; i < ARES_GETSOCK_MAXNUM; ++i) {
            epoll_event ev{};
            if (ARES_GETSOCK_READABLE(bits, i)) {
                ev.events |= EPOLLIN;
            }
            if (ARES_GETSOCK_WRITABLE(bits, i)) {
                ev.events |= EPOLLOUT;
            }
            if (ev.events == 0) {
                continue; // slot not in use
            }
            ev.data.fd = socks[i];

            // Try ADD first: a socket that was closed and reopened under the
            // same number has silently dropped out of the epoll set.
            if (epoll_ctl(epollfd, EPOLL_CTL_ADD, socks[i], &ev) == -1
                && (errno != EEXIST
                    || epoll_ctl(epollfd, EPOLL_CTL_MOD, socks[i], &ev) == -1)) {
                throw std::system_error{ errno, std::system_category(), "epoll_ctl" };
            }
            wanted[n_wanted++] = socks[i];
        }

        // Stop watching sockets c-ares lost interest in. Errors are ignored on
        // purpose: a socket that was closed meanwhile is already gone.
        for (int i = 0; i < n_watched; ++i) {
            if (std::find(wanted, wanted + n_wanted, watched[i]) == wanted + n_wanted) {
                epoll_ctl(epollfd, EPOLL_CTL_DEL, watched[i], nullptr);
            }
        }
        std::copy(wanted, wanted + n_wanted, watched);
        n_watched = n_wanted;
    };

    while (true) {
        timeval tv;
        const timeval *const pending = ares_timeout(channel, nullptr, &tv);
        if (pending == nullptr) {
            // no query outstanding any more: the callback has been invoked
            break;
        }

        sync_watched_sockets();

        // Round up so that a sub-millisecond remainder does not busy-loop.
        const int timeout_ms =
            static_cast<int>(pending->tv_sec * 1000 + (pending->tv_usec + 999) / 1000);

        epoll_event ev{};
        const int ready = epoll_wait(epollfd, &ev, 1, timeout_ms);
        if (ready == -1) {
            if (errno == EINTR) {
                continue; // interrupted by a signal, simply wait again
            }
            throw std::system_error{ errno, std::system_category(), "epoll_wait" };
        }
        if (ready == 0) {
            // Nothing became ready in time: let c-ares handle its timeouts
            // (retry, or fail the query through the callback).
            ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
            continue;
        }

        // Error / hang-up conditions are reported to c-ares as "readable" so
        // that its read attempt discovers the failure.
        const bool readable = (ev.events & (EPOLLIN | EPOLLERR | EPOLLHUP)) != 0;
        const bool writable = (ev.events & EPOLLOUT) != 0;
        const ares_socket_t read_fd = readable ? ev.data.fd : ARES_SOCKET_BAD;
        const ares_socket_t write_fd = writable ? ev.data.fd : ARES_SOCKET_BAD;
        ares_process_fd(channel, read_fd, write_fd);
    }

    return closure.return_status;
}
// Completion callback registered with ares_query().
// `arg` is the CallbackArgs that resolve() passed along with the query; the
// remaining parameters are forwarded unchanged to Resolver::resolved(), and
// its verdict is stored in the context for resolve() to return.
auto malcontent::Resolver::resolve_cb(void *arg,
                                      int status,
                                      int timeouts,
                                      unsigned char *abuf,
                                      int alen) -> void {
    auto *const context = static_cast<CallbackArgs *>(arg);
    auto& request = *context->args;
    context->return_status = context->resolver->resolved(request, status, timeouts, abuf, alen);
}
// Turns the outcome of a c-ares query into an NSS answer.
//
//   args    request being served; its result/buffer receive the host entry and
//           its errnop / h_errnop / ttlp slots are filled through set_if_valid()
//   status  c-ares status code of the query
//   abuf    raw DNS reply of `alen` bytes (only parsed when status is ARES_SUCCESS)
//
// Returns the NSS status to hand back to the caller of resolve().
// Throws std::invalid_argument when a successful reply arrives for a family
// other than AF_INET / AF_INET6.
//
// NOTE(review): the not-found, no-data and timeout cases report
// NSS_STATUS_SUCCESS while signalling the failure through h_errno only. That
// is kept exactly as found; confirm against the NSS wrapper that this is the
// intended contract.
auto malcontent::Resolver::resolved(ResolverArgs& args,
                                    int status,
                                    int /*timeouts*/,
                                    unsigned char *abuf,
                                    int alen) -> nss_status {
    switch (status) {
    case ARES_SUCCESS: {
        hostent *parsed = nullptr;
        int parse_ret;
        int ttl = 0; // reported TTL; stays 0 unless exactly one TTL was parsed
        switch (args.family) {
        case AF_INET: {
            int n_ttls = 1;
            ares_addrttl first{};
            parse_ret = ares_parse_a_reply(abuf, alen, &parsed, &first, &n_ttls);
            if (parse_ret == ARES_SUCCESS && n_ttls == 1) {
                ttl = first.ttl;
            }
            break;
        }
        case AF_INET6: {
            int n_ttls = 1;
            ares_addr6ttl first{};
            parse_ret = ares_parse_aaaa_reply(abuf, alen, &parsed, &first, &n_ttls);
            if (parse_ret == ARES_SUCCESS && n_ttls == 1) {
                ttl = first.ttl;
            }
            break;
        }
        default: {
            throw std::invalid_argument("only AF_INET and AF_INET6 are supported families");
        }
        }
        set_if_valid(args.ttlp, ttl);

        if (parse_ret != ARES_SUCCESS) {
            set_if_valid(args.errnop, EAGAIN);
            set_if_valid(args.h_errnop, HErrno::Internal);
            return NSS_STATUS_TRYAGAIN;
        }

        // Released on every exit path, including exceptions other than
        // std::bad_alloc escaping from copy_hostent().
        const std::unique_ptr<hostent, decltype(&ares_free_hostent)> results{
            parsed, &ares_free_hostent
        };

        try {
            copy_hostent(*results, *args.result, args.buffer, args.buflen);
        } catch (const std::bad_alloc&) {
            // buffer is too small
            set_if_valid(args.errnop, ERANGE);
            set_if_valid(args.h_errnop, HErrno::Internal);
            return NSS_STATUS_TRYAGAIN;
        }

        set_if_valid(args.errnop, 0);
        set_if_valid(args.h_errnop, HErrno::Success);
        return NSS_STATUS_SUCCESS;
    }
    case ARES_ENOTFOUND: {
        set_if_valid(args.errnop, 0);
        set_if_valid(args.h_errnop, HErrno::HostNotFound);
        return NSS_STATUS_SUCCESS;
    }
    case ARES_ENODATA: {
        set_if_valid(args.errnop, 0);
        set_if_valid(args.h_errnop, HErrno::NoData);
        return NSS_STATUS_SUCCESS;
    }
    case ARES_ETIMEOUT: {
        set_if_valid(args.errnop, EAGAIN);
        set_if_valid(args.h_errnop, HErrno::TryAgain);
        return NSS_STATUS_SUCCESS;
    }
    case ARES_ECANCELLED:
    case ARES_EDESTRUCTION:
    case ARES_ENOMEM:
    default: {
        set_if_valid(args.errnop, EAGAIN);
        set_if_valid(args.h_errnop, HErrno::Internal);
        return NSS_STATUS_TRYAGAIN;
    }
    }
}
namespace /* anonymous */ {
// Creates a c-ares channel and points it at the DNS servers listed in `dns`.
// Returns the new channel; the caller takes ownership of it.
// Throws std::system_error when ares_init() fails, and lets any exception of
// setup_servers() through after destroying the half-configured channel.
auto init_channel(std::vector<std::string> dns) -> ares_channel_t * {
    ares_channel_t *channel = nullptr;
    const int ret = ares_init(&channel);
    if (ret != ARES_SUCCESS) {
        const auto err = std::string("ares_init: ") + ares_strerror(ret);
        throw std::system_error(ret, std::generic_category(), err.c_str());
    }

    try {
        // `dns` is a non-const by-value parameter so that this really moves
        setup_servers(channel, std::move(dns));
    } catch (...) {
        // nobody owns the raw channel yet: release it before propagating
        ares_destroy(channel);
        throw;
    }
    return channel;
}
// Replaces the server list of `channel` with the addresses in `dns`.
// Every entry is parsed with parse_address(); a parse failure is logged and
// rethrown. Throws std::system_error when c-ares rejects the list.
auto setup_servers(ares_channel_t *channel, const std::vector<std::string> dns) -> void {
    CAresAddrList list;
    // Each node is pushed at the front, so walking `dns` backwards yields a
    // list in the order the servers were given.
    for (auto it = dns.crbegin(); it != dns.crend(); ++it) {
        auto new_node = std::make_unique<ares_addr_port_node>();
        try {
            parse_address(*it, *new_node);
            auto *const head = new_node.release();
            head->next = list.release();
            list.reset(head);
            malcontent::Logger::debug(std::format("adding {} to the list of user DNS resolvers", *it));
        } catch (const std::exception& e) {
            malcontent::Logger::error(e.what());
            throw;
        }
    }

    const int ret = ares_set_servers_ports(channel, list.get());
    if (ret != ARES_SUCCESS) {
        // name the function that actually failed
        const auto err = std::string("ares_set_servers_ports: ") + ares_strerror(ret);
        throw std::system_error(ret, std::generic_category(), err.c_str());
    }
}
// Parses one DNS server specification into `into`.
//
// Accepted forms, each optionally followed by "#hostname":
//   IPv4:  "1.2.3.4"  or  "1.2.3.4:port"
//   IPv6:  "2001:db8::1",  "[2001:db8::1]"  or  "[2001:db8::1]:port"
//
// Sets `into.family` and `into.addr`. With a "#hostname" suffix the port goes
// to `into.tcp_port` (default 853), otherwise to `into.udp_port` (default 53).
// Throws std::invalid_argument describing the problem on any failure.
auto parse_address(const std::string& addr, ares_addr_port_node& into) -> void {
    static const auto ADDR4_REGEX = std::regex{ R"(([0-9\.]+)(?::([0-9]+))?(?:#(.*))?)" };
    // groups: 1 = bare address, 2 = bracketed address, 3 = port, 4 = hostname.
    // The colon after "]" is optional so that "[addr]" works without a port,
    // and hex digits are accepted in either case.
    static const auto ADDR6_REGEX = std::regex{
        R"((?:([0-9a-fA-F:]+)|\[([0-9a-fA-F:]+)\](?::([0-9]+)?)?)(?:#(.*))?)"
    };

    try {
        std::smatch matches;
        size_t ip_idx, port_idx, host_idx;
        if (std::regex_match(addr, matches, ADDR4_REGEX)) {
            into.family = AF_INET;
            ip_idx = 1;
            port_idx = 2;
            host_idx = 3;
        } else if (std::regex_match(addr, matches, ADDR6_REGEX)) {
            into.family = AF_INET6;
            ip_idx = matches[1].matched ? 1 : 2;
            port_idx = 3;
            host_idx = 4;
        } else {
            throw std::invalid_argument{"expecting '<ip>[:port][#hostname]'"};
        }

        const auto ip = matches[ip_idx].str();
        const int converted = ares_inet_pton(into.family, ip.c_str(), &into.addr);
        if (converted < 0) {
            throw std::system_error {errno, std::system_category()};
        }
        if (converted == 0) {
            // malformed address: errno is not set in this case
            throw std::invalid_argument{"'" + ip + "' is not a valid IP address"};
        }

        // Explicit port if one was given, `fallback` otherwise.
        const auto port_or = [&](int fallback) -> int {
            if (!matches[port_idx].matched) {
                return fallback;
            }
            const int port = std::stoi(matches[port_idx].str());
            if (port > 65535) {
                throw std::out_of_range{"port number exceeds 65535"};
            }
            return port;
        };

        if (matches[host_idx].matched) { // hostname -> TLS
            // FIXME: as of now we are ignoring the hostname verification
            into.tcp_port = port_or(853);
        } else { // no hostname -> no TLS
            into.udp_port = port_or(53);
        }
    } catch (const std::exception& e) {
        throw std::invalid_argument("unable to parse DNS server address '" + addr + "': " + e.what());
    }
}
} // ~ namespace anonymous