From 5dccfafe80ca83e7a795179ca32f4e1166dd0d86 Mon Sep 17 00:00:00 2001 From: Lambda Date: Wed, 17 Jul 2024 18:46:28 +0000 Subject: [PATCH] Refactor server resolution --- src/api/server_server.rs | 3 +- src/api/server_server/resolution.rs | 366 ++++++++++++++++++++++++++ src/api/server_server/send_request.rs | 351 +----------------------- src/service/globals.rs | 4 +- 4 files changed, 383 insertions(+), 341 deletions(-) create mode 100644 src/api/server_server/resolution.rs diff --git a/src/api/server_server.rs b/src/api/server_server.rs index d8b4a46c..6f33e899 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -70,9 +70,10 @@ use crate::{ Ar, Error, PduEvent, Ra, Result, }; +pub(crate) mod resolution; mod send_request; -pub(crate) use send_request::{send_request, FedDest}; +pub(crate) use send_request::send_request; /// # `GET /_matrix/federation/v1/version` /// diff --git a/src/api/server_server/resolution.rs b/src/api/server_server/resolution.rs new file mode 100644 index 00000000..315b5f51 --- /dev/null +++ b/src/api/server_server/resolution.rs @@ -0,0 +1,366 @@ +use std::{ + borrow::Cow, + fmt::Debug, + net::{IpAddr, SocketAddr}, + str::FromStr, +}; + +use ruma::ServerName; +use thiserror::Error; +use tracing::{debug, error, warn}; + +use crate::{services, Result}; +/// Wraps either a literal IP address or a hostname, plus an optional port. +/// +/// # Examples: +/// ```rust +/// # use grapevine::api::server_server::FedDest; +/// # fn main() -> Result<(), std::net::AddrParseError> { +/// FedDest::Literal("198.51.100.3:8448".parse()?); +/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); +/// FedDest::Named("matrix.example.org".to_owned(), "".to_owned()); +/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); +/// FedDest::Named("198.51.100.5".to_owned(), "".to_owned()); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +enum FedDest { + BareLiteral(IpAddr), + PortLiteral(SocketAddr), + Named(String, Option), +} + +impl FedDest { + fn host_and_port_or_default(&self) -> (Cow<'_, str>, u16) { + const DEFAULT_PORT: u16 = 8448; + + match self { + FedDest::BareLiteral(addr) => { + (Cow::Owned(addr.to_string()), DEFAULT_PORT) + } + FedDest::PortLiteral(addr) => { + (Cow::Owned(addr.ip().to_string()), addr.port()) + } + FedDest::Named(host, port) => { + (Cow::Borrowed(host), port.unwrap_or(DEFAULT_PORT)) + } + } + } +} + +#[derive(Debug, Error)] +enum InvalidFedDest { + #[error("invalid port {0}")] + InvalidPort(String), +} + +impl FromStr for FedDest { + type Err = InvalidFedDest; + + fn from_str(s: &str) -> Result { + if let Ok(destination) = s.parse::() { + Ok(FedDest::PortLiteral(destination)) + } else if let Ok(ip_addr) = s.parse::() { + Ok(FedDest::BareLiteral(ip_addr)) + } else if let Some((host, port)) = s.split_once(':') { + Ok(FedDest::Named( + host.to_owned(), + Some(port.parse().map_err(|_| { + InvalidFedDest::InvalidPort(port.to_owned()) + })?), + )) + } else { + Ok(FedDest::Named(s.to_owned(), None)) + } + } +} + +#[derive(Debug, Clone)] +enum WellKnownResult { + Success { + delegated_dest: FedDest, + }, + Error, +} + +#[derive(Debug, Clone)] +enum SrvResult { + Success { + host: String, + port: u16, + }, + Error, +} + +#[derive(Debug, Clone)] +struct LookupResult { + well_known: WellKnownResult, + srv: Option, +} + +#[derive(Debug, Clone)] +pub(crate) struct ResolutionResult { + original: FedDest, + // None if `original` is an IP literal or has a port (used as-is) + lookup: Option, +} + +impl ResolutionResult { + pub(crate) fn host_header(&self) -> String { + let dest = match &self.lookup { + None + | Some(LookupResult { + well_known: WellKnownResult::Error, + .. + }) => { + // no lookup performed or well-known lookup failed + &self.original + } + + Some(LookupResult { + well_known: + WellKnownResult::Success { + delegated_dest, + }, + .. + }) => { + // well-known lookup succeeded + delegated_dest + } + }; + + match dest { + FedDest::BareLiteral(addr) => format!("{addr}"), + FedDest::PortLiteral(addr) => format!("{addr}"), + FedDest::Named(host, port) => { + if let Some(port) = port { + format!("{host}:{port}") + } else { + host.clone() + } + } + } + } + + pub(crate) fn base_url(&self) -> String { + let (host, port) = match &self.lookup { + None + | Some(LookupResult { + well_known: WellKnownResult::Error, + srv: None | Some(SrvResult::Error), + }) => { + // all lookups failed, or no lookups were performed + self.original.host_and_port_or_default() + } + + Some(LookupResult { + well_known: + WellKnownResult::Success { + delegated_dest, + }, + srv: None | Some(SrvResult::Error), + }) => { + // SRV lookup failed, but well-known lookup succeeded + delegated_dest.host_and_port_or_default() + } + + Some(LookupResult { + srv: + Some(SrvResult::Success { + host, + port, + }), + .. + }) => { + // SRV lookup succeeded (result of well-known lookup isn't + // relevant) + (Cow::Borrowed(host.as_str()), *port) + } + }; + + format!("https://{host}:{port}") + } +} + +/// Returns: `actual_destination`, `Host` header +/// Implemented according to the specification at +/// Numbers in comments below refer to bullet points in linked section of +/// specification +#[allow(clippy::too_many_lines)] +#[tracing::instrument(ret(level = "debug"))] +pub(crate) async fn find_actual_destination( + destination: &'_ ServerName, +) -> ResolutionResult { + let original: FedDest = destination + .as_str() + .parse() + .expect("ServerName should always be a valid FedDest"); + + let hostname = match &original { + FedDest::BareLiteral(_) | FedDest::PortLiteral(_) => { + debug!("1: IP literal"); + return ResolutionResult { + original, + lookup: None, + }; + } + FedDest::Named(_, Some(_)) => { + debug!("2: Hostname with port"); + return ResolutionResult { + original, + lookup: None, + }; + } + FedDest::Named(host, None) => host, + }; + + debug!("Requesting .well-known"); + let well_known_error = 'well_known: { + let Some(delegated_hostname) = request_well_known(hostname).await + else { + debug!("Invalid/failed .well-known response"); + break 'well_known WellKnownResult::Error; + }; + let Ok(delegated_dest) = delegated_hostname.parse() else { + debug!("Malformed delegation in .well-known"); + break 'well_known WellKnownResult::Error; + }; + + debug!("3: A .well-known file is available"); + + let srv = match &delegated_dest { + FedDest::BareLiteral(_) | FedDest::PortLiteral(_) => { + debug!("3.1: IP literal in .well-known file"); + + None + } + FedDest::Named(_, Some(_)) => { + debug!("3.2: Hostname with port in .well-known file"); + + None + } + FedDest::Named(delegated_hostname, None) => { + let srv = query_and_store_srv_record(delegated_hostname).await; + if let SrvResult::Success { + .. + } = &srv + { + debug!( + "3.3/3.4: SRV lookup of delegated destination \ + successful" + ); + } else { + debug!( + "3.5: SRV lookup failed, using delegated destination" + ); + } + + Some(srv) + } + }; + + return ResolutionResult { + original, + lookup: Some(LookupResult { + well_known: WellKnownResult::Success { + delegated_dest, + }, + srv, + }), + }; + }; + + let srv = query_and_store_srv_record(hostname).await; + if let SrvResult::Success { + .. + } = &srv + { + debug!("4/5: SRV lookup of original destination successful"); + } else { + debug!("6: SRV lookup failed, using original destination"); + } + + ResolutionResult { + original, + lookup: Some(LookupResult { + well_known: well_known_error, + srv: Some(srv), + }), + } +} + +#[tracing::instrument(ret(level = "debug"))] +async fn query_given_srv_record(record: &str) -> SrvResult { + services() + .globals + .dns_resolver() + .srv_lookup(record) + .await + .ok() + .and_then(|srv| { + srv.iter().next().map(|result| SrvResult::Success { + host: result + .target() + .to_string() + .trim_end_matches('.') + .to_owned(), + port: result.port(), + }) + }) + .unwrap_or(SrvResult::Error) +} + +#[tracing::instrument(ret(level = "debug"))] +async fn query_and_store_srv_record(hostname: &'_ str) -> SrvResult { + let hostname = hostname.trim_end_matches('.'); + + let mut result = + query_given_srv_record(&format!("_matrix-fed._tcp.{hostname}.")).await; + if matches!(result, SrvResult::Error) { + result = + query_given_srv_record(&format!("_matrix._tcp.{hostname}.")).await; + } + + let SrvResult::Success { + host, + port, + } = &result + else { + return result; + }; + + if let Ok(override_ip) = + services().globals.dns_resolver().lookup_ip(host).await + { + services() + .globals + .tls_name_override + .write() + .unwrap() + .insert(hostname.to_owned(), (override_ip.iter().collect(), *port)); + } else { + warn!("Using SRV record, but could not resolve to IP"); + } + + result +} + +#[tracing::instrument(ret(level = "debug"))] +async fn request_well_known(destination: &str) -> Option { + let response = services() + .globals + .default_client() + .get(&format!("https://{destination}/.well-known/matrix/server")) + .send() + .await; + debug!("Got well known response"); + if let Err(e) = &response { + debug!("Well known error: {e:?}"); + return None; + } + let text = response.ok()?.text().await; + debug!("Got well known response text"); + let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; + Some(body.get("m.server")?.as_str()?.to_owned()) +} diff --git a/src/api/server_server/send_request.rs b/src/api/server_server/send_request.rs index 22d1c611..8cf77dc2 100644 --- a/src/api/server_server/send_request.rs +++ b/src/api/server_server/send_request.rs @@ -1,8 +1,4 @@ -use std::{ - fmt::Debug, - mem, - net::{IpAddr, SocketAddr}, -}; +use std::{fmt::Debug, mem}; use axum_extra::headers::{Authorization, HeaderMapExt}; use bytes::Bytes; @@ -18,6 +14,7 @@ use ruma::{ use thiserror::Error; use tracing::{debug, error, field, warn}; +use super::resolution::find_actual_destination; use crate::{ observability::{FoundIn, Lookup, METRICS}, services, @@ -25,62 +22,6 @@ use crate::{ Error, Result, }; -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A [`FedDest::Named`] might contain an IP address in string form if -/// there was no port specified to construct a [`SocketAddr`] with. -/// -/// # Examples: -/// ```rust -/// # use grapevine::api::server_server::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), "".to_owned()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), "".to_owned()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub(crate) enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -impl FedDest { - fn to_https_string(&self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - fn to_uri_string(&self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => { - port.strip_prefix(':').and_then(|x| x.parse().ok()) - } - } - } -} - #[derive(Debug, Error)] enum RequestSignError { #[error("invalid JSON in request body")] @@ -258,29 +199,27 @@ where .get(destination) .cloned(); - let (actual_destination, host) = if let Some(result) = cached_result { + let resolution = if let Some(result) = cached_result { METRICS.record_lookup(Lookup::FederationDestination, FoundIn::Cache); result } else { write_destination_to_cache = true; - let result = find_actual_destination(destination).await; - - (result.0, result.1.to_uri_string()) + find_actual_destination(destination).await }; - let actual_destination_str = actual_destination.to_https_string(); + let base_url = resolution.base_url(); let http_request = request .try_into_http_request::>( - &actual_destination_str, + &base_url, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_11], ) .map_err(|error| { warn!( %error, - actual_destination = actual_destination_str, + base_url, "Failed to serialize request", ); Error::BadServerResponse("Invalid request") @@ -298,10 +237,12 @@ where let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && write_destination_to_cache { METRICS.record_lookup(Lookup::FederationDestination, FoundIn::Remote); - services().globals.actual_destination_cache.write().await.insert( - OwnedServerName::from(destination), - (actual_destination, host), - ); + services() + .globals + .actual_destination_cache + .write() + .await + .insert(OwnedServerName::from(destination), resolution); } response.map_err(|e| { @@ -309,269 +250,3 @@ where Error::BadServerResponse("Server returned bad 200 response.") }) } - -fn get_ip_with_port(destination_str: &str) -> Option { - if let Ok(destination) = destination_str.parse::() { - Some(FedDest::Literal(destination)) - } else if let Ok(ip_addr) = destination_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(destination_str: &str) -> FedDest { - let (host, port) = match destination_str.find(':') { - None => (destination_str, ":8448"), - Some(pos) => destination_str.split_at(pos), - }; - FedDest::Named(host.to_owned(), port.to_owned()) -} - -/// Returns: `actual_destination`, `Host` header -/// Implemented according to the specification at -/// Numbers in comments below refer to bullet points in linked section of -/// specification -#[allow(clippy::too_many_lines)] -#[tracing::instrument(ret(level = "debug"))] -async fn find_actual_destination( - destination: &'_ ServerName, -) -> (FedDest, FedDest) { - debug!("Finding actual destination"); - let destination_str = destination.as_str().to_owned(); - let mut hostname = destination_str.clone(); - let actual_destination = match get_ip_with_port(&destination_str) { - Some(host_port) => { - debug!("1: IP literal with provided or default port"); - host_port - } - None => { - if let Some(pos) = destination_str.find(':') { - debug!("2: Hostname with included port"); - let (host, port) = destination_str.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!(%destination, "Requesting well known"); - if let Some(delegated_hostname) = - request_well_known(destination.as_str()).await - { - debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname) - .to_uri_string(); - if let Some(host_and_port) = - get_ip_with_port(&delegated_hostname) - { - host_and_port - } else if let Some(pos) = delegated_hostname.find(':') { - debug!("3.2: Hostname with port in .well-known file"); - let (host, port) = delegated_hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = - query_srv_record(&delegated_hostname).await - { - debug!("3.3: SRV lookup successful"); - let force_port = hostname_override.port(); - - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - delegated_hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - warn!( - "Using SRV record, but could not resolve \ - to IP" - ); - } - - if let Some(port) = force_port { - FedDest::Named( - delegated_hostname, - format!(":{port}"), - ) - } else { - add_port_to_hostname(&delegated_hostname) - } - } else { - debug!( - "3.4: No SRV records, just use the hostname \ - from .well-known" - ); - add_port_to_hostname(&delegated_hostname) - } - } - } else { - debug!("4: No .well-known or an error occured"); - if let Some(hostname_override) = - query_srv_record(&destination_str).await - { - debug!("4: SRV record found"); - let force_port = hostname_override.port(); - - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - warn!( - "Using SRV record, but could not resolve to IP" - ); - } - - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{port}")) - } else { - add_port_to_hostname(&hostname) - } - } else { - debug!("5: No SRV record found"); - add_port_to_hostname(&destination_str) - } - } - } - } - }; - debug!(?actual_destination, "Resolved actual destination"); - - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) - }; - (actual_destination, hostname) -} - -#[tracing::instrument(ret(level = "debug"))] -async fn query_given_srv_record(record: &str) -> Option { - services() - .globals - .dns_resolver() - .srv_lookup(record) - .await - .map(|srv| { - srv.iter().next().map(|result| { - FedDest::Named( - result - .target() - .to_string() - .trim_end_matches('.') - .to_owned(), - format!(":{}", result.port()), - ) - }) - }) - .unwrap_or(None) -} - -#[tracing::instrument(ret(level = "debug"))] -async fn query_srv_record(hostname: &'_ str) -> Option { - let hostname = hostname.trim_end_matches('.'); - - if let Some(host_port) = - query_given_srv_record(&format!("_matrix-fed._tcp.{hostname}.")).await - { - Some(host_port) - } else { - query_given_srv_record(&format!("_matrix._tcp.{hostname}.")).await - } -} - -#[tracing::instrument(ret(level = "debug"))] -async fn request_well_known(destination: &str) -> Option { - let response = services() - .globals - .default_client() - .get(&format!("https://{destination}/.well-known/matrix/server")) - .send() - .await; - debug!("Got well known response"); - if let Err(error) = &response { - debug!(%error, "Failed to request .well-known"); - return None; - } - let text = response.ok()?.text().await; - debug!("Got well known response text"); - let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; - Some(body.get("m.server")?.as_str()?.to_owned()) -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } -} diff --git a/src/service/globals.rs b/src/service/globals.rs index 75b545fa..4d083fe3 100644 --- a/src/service/globals.rs +++ b/src/service/globals.rs @@ -33,7 +33,7 @@ use tracing::{error, Instrument}; use trust_dns_resolver::TokioAsyncResolver; use crate::{ - api::server_server::FedDest, + api::server_server::resolution::ResolutionResult, observability::FilterReloadHandles, service::media::MediaFileKey, services, @@ -41,7 +41,7 @@ use crate::{ Config, Error, Result, }; -type WellKnownMap = HashMap; +type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; // Time if last failed try, number of failed tries type RateLimitState = (Instant, u32);