use std::{
    collections::HashSet,
    convert::Infallible,
    future::Future,
    net::SocketAddr as IpSocketAddr,
    os::unix::net::SocketAddr as UnixSocketAddr,
    sync::atomic,
    time::Duration,
};

use axum::{
    extract::{
        connect_info::{Connected, IntoMakeServiceWithConnectInfo},
        ConnectInfo, DefaultBodyLimit, FromRequestParts, MatchedPath,
    },
    middleware::AddExtension,
    response::IntoResponse,
    routing::{any, get, on, MethodFilter, Route},
    Router,
};
use axum_server::{
    accept::Accept,
    bind,
    service::SendService,
    tls_rustls::{RustlsAcceptor, RustlsConfig},
    Address, Server,
};
use http::{
    header::{self, HeaderName},
    Method, StatusCode, Uri,
};
use hyper::body::Incoming;
use proxy_header::ProxyHeader;
use ruma::api::{
    client::{
        error::{Error as RumaError, ErrorBody, ErrorKind},
        uiaa::UiaaResponse,
    },
    federation::discovery::get_server_version,
    IncomingRequest,
};
use strum::Display;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    signal,
    task::JoinSet,
};
use tower::{Layer, Service, ServiceBuilder};
use tower_http::{
    cors::{self, CorsLayer},
    trace::TraceLayer,
    ServiceBuilderExt as _,
};
// NOTE(review): the `error!` macro is used below (`federation_self_test`,
// `handle_signals`) but is not in this import list; the `error` imported from
// `crate` further down is the error-type module (`error::Serve`, ...), not
// the macro. Confirm the macro reaches this file some other way (e.g.
// `#[macro_use]`), otherwise add `error` to this list.
use tracing::{debug, info, info_span, warn, Instrument};

use super::ServeArgs;
use crate::{
    api::{
        client_server,
        ruma_wrapper::{Ar, Ra},
        server_server::{self, AllowLoopbackRequests, LogRequestError},
        well_known,
    },
    config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
    database::KeyValueDatabase,
    error, observability, services, set_application_state,
    utils::{
        self,
        error::{Error, Result},
        proxy_protocol::{ProxyAcceptor, ProxyAcceptorConfig},
    },
    ApplicationState, Services,
};

/// Entry point of the `serve` subcommand.
///
/// Loads the configuration, installs the rustls crypto provider, initializes
/// observability and the database, installs the global [`Services`], applies
/// database migrations, starts the background tasks and finally runs the
/// HTTP servers until they all stop (see [`run_server`]).
pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> {
    // Shadows the crate-wide `Error` for the rest of this function.
    use error::ServeCommand as Error;

    let config = config::load(args.config.config.as_ref(), false).await?;

    rustls::crypto::ring::default_provider()
        .install_default()
        .expect("rustls default crypto provider should not be already set");

    // `_guard` (unlike `_`) keeps the value alive until `run` returns; the
    // reload handles are handed to `Services` below.
    let (_guard, reload_handles) = observability::init(&config)?;

    info!("Loading database");
    // Leaked on purpose: `Box::leak` yields a `'static` reference, so the
    // database lives for the rest of the process (presumably what
    // `Services::new` requires).
    let db = Box::leak(Box::new(
        KeyValueDatabase::load_or_create(&config)
            .map_err(Error::DatabaseError)?,
    ));

    // This struct will remove old Unix sockets once it's dropped.
    let _clean_up_socks = CleanUpUnixSockets(config.listen.clone());

    Services::new(db, config, Some(reload_handles))
        .map_err(Error::InitializeServices)?
        .install();

    services().globals.err_if_server_name_changed()?;

    db.apply_migrations().await.map_err(Error::DatabaseError)?;

    info!("Starting background tasks");
    services().admin.start_handler();
    services().sending.start_handler();
    KeyValueDatabase::start_cleanup_task();
    services().globals.set_emergency_access();

    info!("Starting server");
    run_server().await?;

    Ok(())
}

/// Drop guard that deletes the socket files of all Unix-socket listeners.
///
/// [`run`] holds one, so the files are removed when `run` returns, including
/// early error returns after the guard has been created.
struct CleanUpUnixSockets(Vec);

impl Drop for CleanUpUnixSockets {
    fn drop(&mut self) {
        // Remove old Unix sockets
        for listen in &self.0 {
            if let ListenTransport::Unix {
                path, ..
            } = &listen.transport
            {
                info!(
                    path = path.display().to_string(),
                    "Removing Unix socket"
                );
                // Best effort: a failure is only logged, never propagated.
                if let Err(error) = std::fs::remove_file(path) {
                    warn!(%error, "Couldn't remove Unix socket");
                }
            }
        }
    }
}

/// Sends a federation `get_server_version` request to our own `server_name`
/// (with loopback requests allowed) and checks the answer.
///
/// Fails with [`Error::BadConfig`] unless the responding server reports this
/// crate's package name; a response without a `server` field also fails.
#[tracing::instrument]
async fn federation_self_test() -> Result<()> {
    let response = server_server::send_request(
        &services().globals.config.server_name,
        get_server_version::v1::Request {},
        LogRequestError::Yes,
        AllowLoopbackRequests::Yes,
    )
    .await?;

    if response
        .server
        .as_ref()
        .is_none_or(|s| s.name.as_deref() != Some(env!("CARGO_PKG_NAME")))
    {
        error!(?response, "unexpected server version");
        return Err(Error::BadConfig(
            "Got unexpected version from our own version endpoint",
        ));
    }

    Ok(())
}

// A trait we'll implement on `axum_server::Handle` in order to be able to
// shutdown handles regardless of their generics.
trait ServerHandle: Send {
    /// Starts a graceful shutdown, bounded by `timeout` if one is given.
    fn shutdown(&self, timeout: Option);
}

impl ServerHandle for axum_server::Handle {
    fn shutdown(&self, timeout: Option) {
        self.graceful_shutdown(timeout);
    }
}

/// This type is needed to allow us to find out where incoming connections came
/// from. Before Unix socket support, we could simply use `IpSocketAddr` here,
/// but this is no longer possible.
///
/// Its `Display` output is what the request span records as
/// `source_address`.
#[derive(Clone, Display)]
enum AddrConnectInfo {
    #[strum(to_string = "{0}")]
    Ip(IpSocketAddr),

    #[strum(to_string = "[unix socket]")]
    UnixSocket,

    #[strum(to_string = "[unknown]")]
    Unknown,
}

impl Connected for AddrConnectInfo {
    fn connect_info(target: IpSocketAddr) -> Self {
        Self::Ip(target)
    }
}

impl Connected for AddrConnectInfo {
    fn connect_info(_target: UnixSocketAddr) -> Self {
        // The `UnixSocketAddr` we get here is one that we can't recover the
        // path from (`as_pathname` returns `None`), so there's no point
        // in saving it (we only use all this for logging).
        Self::UnixSocket
    }
}

/// Spawns one server per configured listener and keeps what is needed to
/// wait for them and to shut them down.
struct ServerSpawner<'cfg, M> {
    config: &'cfg Config,
    /// Middleware stack applied to every listener's router.
    middlewares: M,

    /// Certificates loaded from `config.tls`; `None` when TLS is not
    /// configured.
    tls_config: Option,
    proxy_config: ProxyAcceptorConfig,

    /// One task per listener; each yields its listener and the server's
    /// exit result.
    servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
    /// Shutdown handles of all spawned servers.
    handles: Vec>,
}

impl<'cfg, M> ServerSpawner<'cfg, M>
where
    M: Layer + Clone + Send + 'static,
    M::Service: Service<
            axum::extract::Request,
            Response = axum::response::Response,
            Error = Infallible,
        > + Clone
        + Send
        + 'static,
    >::Future: Send + 'static,
{
    /// Creates a spawner, loading the TLS certificate chain and key from
    /// disk when `config.tls` is set.
    ///
    /// Fails with `error::Serve::LoadCerts` if they cannot be loaded.
    async fn new(
        config: &'cfg Config,
        middlewares: M,
    ) -> Result {
        let tls_config = if let Some(tls) = &config.tls {
            Some(
                RustlsConfig::from_pem_file(&tls.certs, &tls.key)
                    .await
                    .map_err(|err| error::Serve::LoadCerts {
                        certs: tls.certs.clone(),
                        key: tls.key.clone(),
                        err,
                    })?,
            )
        } else {
            None
        };

        let proxy_config = ProxyAcceptorConfig::default();

        Ok(Self {
            config,
            middlewares,
            tls_config,
            proxy_config,
            servers: JoinSet::new(),
            handles: Vec::new(),
        })
    }

    /// Returns a function that transforms a lower-layer acceptor into a TLS
    /// acceptor.
    ///
    /// Fails with `error::Serve::NoTlsCerts` when `listen` requires TLS but
    /// no certificates were configured.
    fn tls_acceptor_factory(
        &self,
        listen: &ListenConfig,
    ) -> Result RustlsAcceptor, error::Serve> {
        let config = self
            .tls_config
            .clone()
            .ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;

        Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
    }

    /// Returns a function that transforms a lower-layer acceptor into a Proxy
    /// Protocol acceptor.
    fn proxy_acceptor_factory(&self) -> impl FnOnce(A) -> ProxyAcceptor {
        let config = self.proxy_config.clone();

        |inner| ProxyAcceptor::new(inner, config)
    }

    /// Attaches a fresh shutdown handle to `server`, spawns it serving `app`
    /// into `self.servers` and stores the handle in `self.handles`.
    fn spawn_server_inner(
        &mut self,
        listen: ListenConfig,
        server: Server,
        app: IntoMakeServiceWithConnectInfo,
    ) where
        AddrConnectInfo: Connected,
        Addr: Address + Send + 'static,
        Addr::Stream: Send,
        Addr::Listener: Send,
        A: Accept<
                Addr::Stream,
                AddExtension>,
            > + Clone
            + Send
            + Sync
            + 'static,
        A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
        A::Service: SendService> + Send,
        A::Future: Send,
    {
        let handle = axum_server::Handle::new();
        let server = server.handle(handle.clone()).serve(app);
        self.servers.spawn(async move {
            let result = server.await;

            // Return the listener too, so the caller can attribute errors.
            (listen, result)
        });
        self.handles.push(Box::new(handle));
    }

    /// Builds the router for `listen`'s components, wraps it in the
    /// middleware stack and spawns a server bound to the listener's address,
    /// adding Proxy Protocol and/or TLS acceptors as configured.
    fn spawn_server(
        &mut self,
        listen: ListenConfig,
    ) -> Result<(), error::Serve> {
        let app = routes(self.config, &listen.components)
            .layer(self.middlewares.clone())
            .into_make_service_with_connect_info::();

        match &listen.transport {
            ListenTransport::Tcp {
                address,
                port,
                tls,
                proxy_protocol,
            } => {
                let addr = IpSocketAddr::from((*address, *port));
                let server = bind(addr);

                // Every `.map` changes the server's acceptor type, so each
                // arm has to call `spawn_server_inner` itself.
                match (tls, proxy_protocol) {
                    (false, false) => {
                        self.spawn_server_inner(listen, server, app);
                    }
                    (false, true) => {
                        let server = server.map(self.proxy_acceptor_factory());
                        self.spawn_server_inner(listen, server, app);
                    }
                    (true, false) => {
                        let server =
                            server.map(self.tls_acceptor_factory(&listen)?);
                        self.spawn_server_inner(listen, server, app);
                    }
                    (true, true) => {
                        // The Proxy Protocol acceptor becomes the lower
                        // layer of the TLS acceptor, so the proxy header is
                        // handled before TLS.
                        let server = server
                            .map(self.proxy_acceptor_factory())
                            .map(self.tls_acceptor_factory(&listen)?);
                        self.spawn_server_inner(listen, server, app);
                    }
                }

                Ok(())
            }
            // Unix listeners have no TLS option; only the Proxy Protocol can
            // be layered on top.
            ListenTransport::Unix {
                path,
                proxy_protocol,
            } => {
                let addr = match UnixSocketAddr::from_pathname(path) {
                    Ok(addr) => addr,
                    Err(e) => {
                        // We can't use `map_err` here, as that would move
                        // `listen` into a closure, preventing us from using it
                        // later.
                        return Err(error::Serve::Listen(e, listen));
                    }
                };

                let server = bind(addr);

                if *proxy_protocol {
                    let server = server.map(self.proxy_acceptor_factory());
                    self.spawn_server_inner(listen, server, app);
                } else {
                    self.spawn_server_inner(listen, server, app);
                }

                Ok(())
            }
        }
    }
}

/// Builds the shared middleware stack, spawns a server for every configured
/// listener, starts the signal handler, optionally runs the federation
/// self-test, marks the application as ready and then waits for the servers.
///
/// Returns the first listener error encountered, or `Ok(())` once every
/// server task has finished.
#[allow(clippy::too_many_lines)]
async fn run_server() -> Result<(), error::Serve> {
    // Shadows the crate-wide `Error` for the rest of this function.
    use error::Serve as Error;

    let config = &services().globals.config;

    let x_requested_with = HeaderName::from_static("x-requested-with");

    let middlewares = ServiceBuilder::new()
        .sensitive_headers([header::AUTHORIZATION])
        .layer(axum::middleware::from_fn(spawn_task))
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(|request: &http::Request<_>| {
                    // Use the matched route template when available, falling
                    // back to the raw request path.
                    let endpoint = if let Some(endpoint) =
                        request.extensions().get::()
                    {
                        endpoint.as_str()
                    } else {
                        request.uri().path()
                    };

                    let method = request.method();

                    // When the first extension (presumably the Proxy Protocol
                    // header) is present, report its proxied source address,
                    // or `[unknown]` if it carries none. Otherwise report the
                    // connection's own peer address.
                    let source_address = request
                        .extensions()
                        .get::>()
                        .map_or_else(
                            || {
                                request
                                    .extensions()
                                    .get::>()
                                    .map(|ConnectInfo(addr)| addr.clone())
                            },
                            |h| {
                                h.proxied_address().map(|addr| {
                                    AddrConnectInfo::Ip(addr.source)
                                })
                            },
                        )
                        .unwrap_or(AddrConnectInfo::Unknown);

                    tracing::info_span!(
                        "http_request",
                        otel.name = format!("{method} {endpoint}"),
                        %method,
                        %endpoint,
                        %source_address,
                    )
                })
                .on_request(
                    |request: &http::Request<_>, _span: &tracing::Span| {
                        // can be enabled selectively using `filter =
                        // grapevine[incoming_request_curl]=trace` in config
                        tracing::trace_span!("incoming_request_curl").in_scope(
                            || {
                                tracing::trace!(
                                    cmd = utils::curlify(request),
                                    "curl command line for incoming request \
                                     (guessed hostname)"
                                );
                            },
                        );
                    },
                ),
        )
        .layer(axum::middleware::from_fn(unrecognized_method))
        .layer(
            CorsLayer::new()
                .allow_origin(cors::Any)
                .allow_methods([
                    Method::GET,
                    Method::POST,
                    Method::PUT,
                    Method::DELETE,
                    Method::OPTIONS,
                ])
                .allow_headers([
                    header::ORIGIN,
                    x_requested_with,
                    header::CONTENT_TYPE,
                    header::ACCEPT,
                    header::AUTHORIZATION,
                ])
                .max_age(Duration::from_secs(86400)),
        )
        .layer(DefaultBodyLimit::max(
            config
                .max_request_size
                .try_into()
                .expect("failed to convert max request size"),
        ))
        .layer(axum::middleware::from_fn(observability::http_metrics_layer))
        .layer(axum::middleware::from_fn(observability::traceresponse_layer));

    let mut spawner = ServerSpawner::new(config, middlewares).await?;

    // NOTE(review): this check runs only after `ServerSpawner::new` has
    // loaded the TLS certificates, so a certificate error takes precedence
    // over `NoListeners`.
    if config.listen.is_empty() {
        return Err(Error::NoListeners);
    }

    for listen in &config.listen {
        info!(listener = %listen, "Listening for incoming traffic");
        spawner.spawn_server(listen.clone())?;
    }

    // Partially moves `spawner`: the TLS config and shutdown handles go to
    // the signal handler, while `spawner.servers` is still awaited below.
    tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));

    if config.federation.enable && config.federation.self_test {
        federation_self_test()
            .await
            .map_err(error::Serve::FederationSelfTestFailed)?;
        debug!("Federation self-test completed successfully");
    }

    set_application_state(ApplicationState::Ready);

    while let Some(result) = spawner.servers.join_next().await {
        let (listen, result) =
            result.expect("should be able to join server task");
        result.map_err(|err| Error::Listen(err, listen))?;
    }

    Ok(())
}

/// Ensures the request runs in a new tokio task.
///
/// The axum request handler task gets cancelled if the connection is shut
/// down; by spawning our own task, processing continues after the client
/// disconnects.
///
/// Requests arriving after the global shutdown flag is set are rejected with
/// `503 Service Unavailable`. If the spawned task cannot be joined (e.g. it
/// panicked), `500 Internal Server Error` is returned.
async fn spawn_task(
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> std::result::Result {
    if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
        return Err(StatusCode::SERVICE_UNAVAILABLE);
    }
    tokio::spawn(next.run(req))
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}

/// Replaces any `405 Method Not Allowed` response from the inner service
/// with a Matrix `M_UNRECOGNIZED` error body, keeping the 405 status code.
async fn unrecognized_method(
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> std::result::Result {
    // Cloned up front because `req` is consumed by `next.run`.
    let method = req.method().clone();
    let uri = req.uri().clone();
    let inner = next.run(req).await;
    if inner.status() == StatusCode::METHOD_NOT_ALLOWED {
        warn!(%method, %uri, "Method not allowed");
        return Ok(Ra(UiaaResponse::MatrixError(RumaError {
            body: ErrorBody::Standard {
                kind: ErrorKind::Unrecognized,
                message: "M_UNRECOGNIZED: Unrecognized request".to_owned(),
            },
            status_code: StatusCode::METHOD_NOT_ALLOWED,
        }))
        .into_response());
    }
    Ok(inner)
}

/// Routes for legacy unauthenticated `/_matrix/media/*` APIs (used by both
/// clients and federation)
///
/// The media config endpoint is always registered. Download and thumbnail
/// endpoints are registered only when `media.allow_unauthenticated_access`
/// is set. Otherwise their `v3` paths are answered by
/// [`unauthenticated_media_disabled`].
fn legacy_media_routes(config: &Config) -> Router {
    use client_server as c2s;

    let router = Router::new();

    // deprecated, but unproblematic
    let router = router.ruma_route(c2s::get_media_config_legacy_route);

    if config.media.allow_unauthenticated_access {
        router
            .ruma_route(c2s::get_content_legacy_route)
            .ruma_route(c2s::get_content_as_filename_legacy_route)
            .ruma_route(c2s::get_content_thumbnail_legacy_route)
    } else {
        router
            .route(
                "/_matrix/media/v3/download/*path",
                any(unauthenticated_media_disabled),
            )
            .route(
                "/_matrix/media/v3/thumbnail/*path",
                any(unauthenticated_media_disabled),
            )
    }
}

/// Routes for the client-server API, including authenticated media and a
/// few hand-written routes Ruma cannot express.
#[allow(clippy::too_many_lines)]
fn client_routes() -> Router {
    use client_server as c2s;

    let router = Router::new()
        .ruma_route(c2s::get_supported_versions_route)
        .ruma_route(c2s::get_register_available_route)
        .ruma_route(c2s::register_route)
        .ruma_route(c2s::get_login_types_route)
        .ruma_route(c2s::login_route)
        .ruma_route(c2s::whoami_route)
        .ruma_route(c2s::logout_route)
        .ruma_route(c2s::logout_all_route)
        .ruma_route(c2s::change_password_route)
        .ruma_route(c2s::deactivate_route)
        .ruma_route(c2s::third_party_route)
        .ruma_route(c2s::request_3pid_management_token_via_email_route)
        .ruma_route(c2s::request_3pid_management_token_via_msisdn_route)
        .ruma_route(c2s::get_capabilities_route)
        .ruma_route(c2s::get_pushrules_all_route)
        .ruma_route(c2s::set_pushrule_route)
        .ruma_route(c2s::get_pushrule_route)
        .ruma_route(c2s::set_pushrule_enabled_route)
        .ruma_route(c2s::get_pushrule_enabled_route)
        .ruma_route(c2s::get_pushrule_actions_route)
        .ruma_route(c2s::set_pushrule_actions_route)
        .ruma_route(c2s::delete_pushrule_route)
        .ruma_route(c2s::get_room_event_route)
        .ruma_route(c2s::get_room_aliases_route)
        .ruma_route(c2s::get_filter_route)
        .ruma_route(c2s::create_filter_route)
        .ruma_route(c2s::set_global_account_data_route)
        .ruma_route(c2s::set_room_account_data_route)
        .ruma_route(c2s::get_global_account_data_route)
        .ruma_route(c2s::get_room_account_data_route)
        .ruma_route(c2s::set_displayname_route)
        .ruma_route(c2s::get_displayname_route)
        .ruma_route(c2s::set_avatar_url_route)
        .ruma_route(c2s::get_avatar_url_route)
        .ruma_route(c2s::get_profile_route)
        .ruma_route(c2s::upload_keys_route)
        .ruma_route(c2s::get_keys_route)
        .ruma_route(c2s::claim_keys_route)
        .ruma_route(c2s::create_backup_version_route)
        .ruma_route(c2s::update_backup_version_route)
        .ruma_route(c2s::delete_backup_version_route)
        .ruma_route(c2s::get_latest_backup_info_route)
        .ruma_route(c2s::get_backup_info_route)
        .ruma_route(c2s::add_backup_keys_route)
        .ruma_route(c2s::add_backup_keys_for_room_route)
        .ruma_route(c2s::add_backup_keys_for_session_route)
        .ruma_route(c2s::delete_backup_keys_for_room_route)
        .ruma_route(c2s::delete_backup_keys_for_session_route)
        .ruma_route(c2s::delete_backup_keys_route)
        .ruma_route(c2s::get_backup_keys_for_room_route)
        .ruma_route(c2s::get_backup_keys_for_session_route)
        .ruma_route(c2s::get_backup_keys_route)
        .ruma_route(c2s::set_read_marker_route)
        .ruma_route(c2s::create_receipt_route)
        .ruma_route(c2s::create_typing_event_route)
        .ruma_route(c2s::create_room_route)
        .ruma_route(c2s::redact_event_route)
        .ruma_route(c2s::report_event_route)
        .ruma_route(c2s::create_alias_route)
        .ruma_route(c2s::delete_alias_route)
        .ruma_route(c2s::get_alias_route)
        .ruma_route(c2s::join_room_by_id_route)
        .ruma_route(c2s::join_room_by_id_or_alias_route)
        .ruma_route(c2s::joined_members_route)
        .ruma_route(c2s::leave_room_route)
        .ruma_route(c2s::forget_room_route)
        .ruma_route(c2s::joined_rooms_route)
        .ruma_route(c2s::kick_user_route)
        .ruma_route(c2s::ban_user_route)
        .ruma_route(c2s::unban_user_route)
        .ruma_route(c2s::invite_user_route)
        .ruma_route(c2s::set_room_visibility_route)
        .ruma_route(c2s::get_room_visibility_route)
        .ruma_route(c2s::get_public_rooms_route)
        .ruma_route(c2s::get_public_rooms_filtered_route)
        .ruma_route(c2s::search_users_route)
        .ruma_route(c2s::get_member_events_route)
        .ruma_route(c2s::get_protocols_route)
        .ruma_route(c2s::send_message_event_route)
        .ruma_route(c2s::send_state_event_for_key_route)
        .ruma_route(c2s::get_state_events_route)
        .ruma_route(c2s::get_state_events_for_key_route)
        .ruma_route(c2s::v3::sync_events_route)
        .ruma_route(c2s::msc3575::sync_events_v4_route)
        .ruma_route(c2s::get_context_route)
        .ruma_route(c2s::get_message_events_route)
        .ruma_route(c2s::search_events_route)
        .ruma_route(c2s::turn_server_route)
        .ruma_route(c2s::send_event_to_device_route);

    // authenticated media
    let router = router
        .ruma_route(c2s::get_media_config_route)
        .ruma_route(c2s::create_content_route)
        .ruma_route(c2s::get_content_route)
        .ruma_route(c2s::get_content_as_filename_route)
        .ruma_route(c2s::get_content_thumbnail_route);

    let router = router
        .ruma_route(c2s::get_devices_route)
        .ruma_route(c2s::get_device_route)
        .ruma_route(c2s::update_device_route)
        .ruma_route(c2s::delete_device_route)
        .ruma_route(c2s::delete_devices_route)
        .ruma_route(c2s::get_tags_route)
        .ruma_route(c2s::update_tag_route)
        .ruma_route(c2s::delete_tag_route)
        .ruma_route(c2s::upload_signing_keys_route)
        .ruma_route(c2s::upload_signatures_route)
        .ruma_route(c2s::get_key_changes_route)
        .ruma_route(c2s::get_pushers_route)
        .ruma_route(c2s::set_pushers_route)
        .ruma_route(c2s::upgrade_room_route)
        .ruma_route(c2s::get_threads_route)
        .ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route)
        .ruma_route(c2s::get_relating_events_with_rel_type_route)
        .ruma_route(c2s::get_relating_events_route)
        .ruma_route(c2s::get_hierarchy_route);

    // Ruma doesn't have support for multiple paths for a single endpoint yet,
    // and these routes share one Ruma request / response type pair with
    // {get,send}_state_event_for_key_route. These two endpoints also allow
    // trailing slashes.
    let router = router
        .route(
            "/_matrix/client/r0/rooms/:room_id/state/:event_type",
            get(c2s::get_state_events_for_empty_key_route)
                .put(c2s::send_state_event_for_empty_key_route),
        )
        .route(
            "/_matrix/client/v3/rooms/:room_id/state/:event_type",
            get(c2s::get_state_events_for_empty_key_route)
                .put(c2s::send_state_event_for_empty_key_route),
        )
        .route(
            "/_matrix/client/r0/rooms/:room_id/state/:event_type/",
            get(c2s::get_state_events_for_empty_key_route)
                .put(c2s::send_state_event_for_empty_key_route),
        )
        .route(
            "/_matrix/client/v3/rooms/:room_id/state/:event_type/",
            get(c2s::get_state_events_for_empty_key_route)
                .put(c2s::send_state_event_for_empty_key_route),
        );

    // `initialSync` is not implemented; see `initial_sync`.
    router
        .route(
            "/_matrix/client/r0/rooms/:room_id/initialSync",
            get(initial_sync),
        )
        .route(
            "/_matrix/client/v3/rooms/:room_id/initialSync",
            get(initial_sync),
        )
}

/// Routes for the server-server (federation) API and server key endpoints.
///
/// When federation is disabled, every path under `/_matrix/federation/` and
/// `/_matrix/key/` is answered by [`federation_disabled`] instead.
fn federation_routes(config: &Config) -> Router {
    use server_server as s2s;

    if config.federation.enable {
        Router::new()
            .ruma_route(s2s::get_server_version_route)
            .route("/_matrix/key/v2/server", get(s2s::get_server_keys_route))
            .route(
                "/_matrix/key/v2/server/:key_id",
                get(s2s::get_server_keys_deprecated_route),
            )
            .ruma_route(s2s::get_public_rooms_route)
            .ruma_route(s2s::get_public_rooms_filtered_route)
            .ruma_route(s2s::send_transaction_message_route)
            .ruma_route(s2s::get_event_route)
            .ruma_route(s2s::get_backfill_route)
            .ruma_route(s2s::get_missing_events_route)
            .ruma_route(s2s::get_event_authorization_route)
            .ruma_route(s2s::get_room_state_route)
            .ruma_route(s2s::get_room_state_ids_route)
            .ruma_route(s2s::create_join_event_template_route)
            .ruma_route(s2s::create_join_event_v1_route)
            .ruma_route(s2s::create_join_event_v2_route)
            .ruma_route(s2s::create_invite_route)
            .ruma_route(s2s::get_devices_route)
            .ruma_route(s2s::get_room_information_route)
            .ruma_route(s2s::get_profile_information_route)
            .ruma_route(s2s::get_keys_route)
            .ruma_route(s2s::claim_keys_route)
            .ruma_route(s2s::media_download_route)
            .ruma_route(s2s::media_thumbnail_route)
    } else {
        Router::new()
            .route("/_matrix/federation/*path", any(federation_disabled))
            .route("/_matrix/key/*path", any(federation_disabled))
    }
}

/// Serves the metrics export at `/metrics` when metrics are enabled;
/// otherwise returns an empty router.
fn metrics_routes(config: &Config) -> Router {
    if config.observability.metrics.enable {
        Router::new().route(
            "/metrics",
            get(|| async { observability::METRICS.export() }),
        )
    } else {
        Router::new()
    }
}

/// Routes for the `/.well-known/matrix/{client,server}` discovery documents.
fn well_known_routes() -> Router {
    Router::new()
        .route("/.well-known/matrix/client", get(well_known::client))
        .route("/.well-known/matrix/server", get(well_known::server))
}

/// Builds the router for one listener by merging the routes of each of its
/// components.
///
/// Legacy media routes are added when the listener serves the client or the
/// federation API. `/` always answers with [`it_works`], and unmatched paths
/// go to [`not_found`].
fn routes(config: &Config, components: &HashSet) -> Router {
    let mut router = Router::new();
    for &component in components {
        router = router.merge(match component {
            ListenComponent::Client => client_routes(),
            ListenComponent::Federation => federation_routes(config),
            ListenComponent::Metrics => metrics_routes(config),
            ListenComponent::WellKnown => well_known_routes(),
        });
    }

    if components.contains(&ListenComponent::Client)
        || components.contains(&ListenComponent::Federation)
    {
        router = router.merge(legacy_media_routes(config));
    }

    router.route("/", get(it_works)).fallback(not_found)
}

/// Reloads `tls_config` from the certificate and key paths in the current
/// `tls` configuration.
///
/// Fails with `error::Serve::LoadCerts` if the files cannot be loaded.
///
/// # Panics
///
/// Panics if the configuration has no `tls` section. The `tls_config` passed
/// in only exists when one was loaded from that section at startup.
async fn reload_tls_config(
    tls_config: &RustlsConfig,
) -> Result<(), error::Serve> {
    let config = services()
        .globals
        .config
        .tls
        .as_ref()
        .expect("TLS config should exist if TLS listener exists");

    tls_config.reload_from_pem_file(&config.certs, &config.key).await.map_err(
        |err| error::Serve::LoadCerts {
            certs: config.certs.clone(),
            key: config.key.clone(),
            err,
        },
    )?;

    Ok(())
}

/// Handles process signals until a termination signal arrives.
///
/// SIGHUP reloads the TLS certificates, if any; a reload failure is only
/// logged. SIGTERM or Ctrl+C ends the wait and triggers shutdown:
/// `services().globals.shutdown()` is called, every server is asked to shut
/// down gracefully with a 30 second timeout, and the application state is
/// set to `Stopping`. On non-Unix targets only Ctrl+C is listened for.
async fn handle_signals(
    tls_config: Option,
    handles: Vec>,
) {
    #[cfg(unix)]
    async fn wait_signal(sig: signal::unix::SignalKind) {
        signal::unix::signal(sig)
            .expect("failed to install signal handler")
            .recv()
            .await;
    }

    #[cfg(unix)]
    let terminate = || wait_signal(signal::unix::SignalKind::terminate());
    #[cfg(not(unix))]
    let terminate = || std::future::pending::<()>();

    #[cfg(unix)]
    let sighup = || wait_signal(signal::unix::SignalKind::hangup());
    #[cfg(not(unix))]
    let sighup = || std::future::pending::<()>();

    let ctrl_c = || async {
        signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
    };

    // Loops on SIGHUP; breaks with the name of the terminating signal.
    let sig = loop {
        tokio::select! {
            () = sighup() => {
                info!("Received reload request");

                set_application_state(ApplicationState::Reloading);

                if let Some(tls_config) = tls_config.as_ref() {
                    if let Err(error) = reload_tls_config(tls_config).await {
                        error!(?error, "Failed to reload TLS config");
                    }
                }

                set_application_state(ApplicationState::Ready);
            },
            () = terminate() => { break "SIGTERM"; },
            () = ctrl_c() => { break "Ctrl+C"; },
        }
    };

    warn!(signal = %sig, "Shutting down due to signal");

    services().globals.shutdown();

    for handle in handles {
        handle.shutdown(Some(Duration::from_secs(30)));
    }

    set_application_state(ApplicationState::Stopping);
}

/// Catch-all handler for federation and key paths when federation is
/// disabled.
async fn federation_disabled(_: Uri) -> impl IntoResponse {
    Error::bad_config("Federation is disabled.")
}

/// Catch-all handler for legacy media download and thumbnail paths when
/// unauthenticated media access is disabled.
async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse {
    Error::BadRequest(
        ErrorKind::NotFound,
        "Unauthenticated media access is disabled",
    )
}

/// Router fallback: logs the unknown route at debug level and answers with
/// an `Unrecognized` error.
async fn not_found(method: Method, uri: Uri) -> impl IntoResponse {
    debug!(%method, %uri, "Unknown route");
    Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
}

/// Stub for `/rooms/:room_id/initialSync`; always answers with a
/// `GuestAccessForbidden` error.
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
    Error::BadRequest(
        ErrorKind::GuestAccessForbidden,
        "Guest access not implemented",
    )
}

/// Handler for `/`; returns a fixed greeting.
async fn it_works() -> &'static str {
    "Hello from Grapevine!"
}

/// Extension trait for registering a [`RumaHandler`] on a [`Router`].
trait RouterExt {
    fn ruma_route(self, handler: H) -> Self
    where
        H: RumaHandler,
        T: 'static;
}

impl RouterExt for Router {
    fn ruma_route(self, handler: H) -> Self
    where
        H: RumaHandler,
        T: 'static,
    {
        handler.add_to_router(self)
    }
}

/// A handler for a Ruma endpoint that can register itself on a router.
pub(crate) trait RumaHandler {
    // Can't transform to a handler without boxing or relying on the
    // nightly-only impl-trait-in-traits feature. Moving a small amount of
    // extra logic into the trait allows bypassing both.
    fn add_to_router(self, router: Router) -> Router;
}

/// Implements [`RumaHandler`] for handler functions that take the listed
/// extractor types followed by an `Ar` request.
///
/// The generated `add_to_router` registers a clone of the handler on every
/// path in the endpoint's metadata history, for the endpoint's HTTP method.
/// Each call runs inside a `run_ruma_handler` span that records the
/// authenticated user, device, server name and appservice ID.
macro_rules! impl_ruma_handler {
    (
        $($ty:ident),* $(,)?
    ) => {
        #[axum::async_trait]
        #[allow(non_snake_case)]
        impl RumaHandler<($($ty,)* Ar,)>
            for F
        where
            Req: IncomingRequest + Send + 'static,
            Resp: IntoResponse,
            F: FnOnce($($ty,)* Ar) -> Fut + Clone + Send + 'static,
            Fut: Future>
                + Send,
            E: IntoResponse,
            $( $ty: FromRequestParts<()> + Send + 'static, )*
        {
            fn add_to_router(self, mut router: Router) -> Router {
                let meta = Req::METADATA;
                let method_filter = method_to_filter(meta.method);

                for path in meta.history.all_paths() {
                    let handler = self.clone();

                    router = router.route(
                        path,
                        on(
                            method_filter,
                            |$( $ty: $ty, )* req: Ar| async move {
                                let span = info_span!(
                                    "run_ruma_handler",
                                    auth.user = ?req.sender_user,
                                    auth.device = ?req.sender_device,
                                    auth.servername = ?req.sender_servername,
                                    auth.appservice_id = ?req.appservice_info
                                        .as_ref()
                                        .map(|i| &i.registration.id)
                                );
                                handler($($ty,)* req).instrument(span).await
                            }
                        )
                    )
                }

                router
            }
        }
    };
}

impl_ruma_handler!();
impl_ruma_handler!(T1);
impl_ruma_handler!(T1, T2);
impl_ruma_handler!(T1, T2, T3);
impl_ruma_handler!(T1, T2, T3, T4);
impl_ruma_handler!(T1, T2, T3, T4, T5);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);

/// Maps an HTTP method to the corresponding axum [`MethodFilter`].
///
/// # Panics
///
/// Panics for any method not listed below, such as `CONNECT` or extension
/// methods.
fn method_to_filter(method: Method) -> MethodFilter {
    match method {
        Method::DELETE => MethodFilter::DELETE,
        Method::GET => MethodFilter::GET,
        Method::HEAD => MethodFilter::HEAD,
        Method::OPTIONS => MethodFilter::OPTIONS,
        Method::PATCH => MethodFilter::PATCH,
        Method::POST => MethodFilter::POST,
        Method::PUT => MethodFilter::PUT,
        Method::TRACE => MethodFilter::TRACE,
        m => panic!("Unsupported HTTP method: {m:?}"),
    }
}