diff --git a/src/cli.rs b/src/cli.rs index cb6f91bb..cb8f4154 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -1,9 +1,16 @@ //! Integration with `clap` +//! +//! CLI argument structs are defined in this module. Execution logic for each +//! command goes in a submodule. use std::path::PathBuf; use clap::{Parser, Subcommand}; +use crate::error; + +mod serve; + /// Command line arguments #[derive(Parser)] #[clap( @@ -49,3 +56,11 @@ pub(crate) struct ServeArgs { #[clap(flatten)] pub(crate) config: ConfigArg, } + +impl Args { + pub(crate) async fn run(self) -> Result<(), error::Main> { + match self.command { + Command::Serve(args) => serve::run(args).await, + } + } +} diff --git a/src/cli/serve.rs b/src/cli/serve.rs new file mode 100644 index 00000000..71980e42 --- /dev/null +++ b/src/cli/serve.rs @@ -0,0 +1,657 @@ +use std::{future::Future, net::SocketAddr, sync::atomic, time::Duration}; + +use axum::{ + extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, + response::IntoResponse, + routing::{any, get, on, MethodFilter}, + Router, +}; +use axum_server::{ + bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, +}; +use futures_util::FutureExt; +use http::{ + header::{self, HeaderName}, + Method, StatusCode, Uri, +}; +use ruma::api::{ + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + IncomingRequest, +}; +use tokio::{signal, task::JoinSet}; +use tower::ServiceBuilder; +use tower_http::{ + cors::{self, CorsLayer}, + trace::TraceLayer, + ServiceBuilderExt as _, +}; +use tracing::{debug, info, info_span, warn, Instrument}; + +use super::ServeArgs; +use crate::{ + api::{ + client_server, + ruma_wrapper::{Ar, Ra}, + server_server, well_known, + }, + config, + config::{Config, ListenConfig}, + database::KeyValueDatabase, + error, observability, services, utils, + utils::error::{Error, Result}, +}; + +pub(crate) async fn run(args: ServeArgs) -> Result<(), error::Main> { + use error::Main as Error; + + let config = config::load(args.config.config.as_ref()).await?; + + let (_guard, reload_handles) = observability::init(&config)?; + + // This is needed for opening lots of file descriptors, which tends to + // happen more often when using RocksDB and making lots of federation + // connections at startup. The soft limit is usually 1024, and the hard + // limit is usually 512000; I've personally seen it hit >2000. + // + // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 + // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 + #[cfg(unix)] + maximize_fd_limit() + .expect("should be able to increase the soft limit to the hard limit"); + + info!("Loading database"); + KeyValueDatabase::load_or_create(config, reload_handles) + .await + .map_err(Error::DatabaseError)?; + + info!("Starting server"); + run_server().await?; + + Ok(()) +} + +#[allow(clippy::too_many_lines)] +async fn run_server() -> Result<(), error::Serve> { + use error::Serve as Error; + + let config = &services().globals.config; + + let x_requested_with = HeaderName::from_static("x-requested-with"); + + let middlewares = ServiceBuilder::new() + .sensitive_headers([header::AUTHORIZATION]) + .layer(axum::middleware::from_fn(spawn_task)) + .layer( + TraceLayer::new_for_http() + .make_span_with(|request: &http::Request<_>| { + let endpoint = if let Some(endpoint) = + request.extensions().get::() + { + endpoint.as_str() + } else { + request.uri().path() + }; + + let method = request.method(); + + tracing::info_span!( + "http_request", + otel.name = format!("{method} {endpoint}"), + %method, + %endpoint, + ) + }) + .on_request( + |request: &http::Request<_>, _span: &tracing::Span| { + // can be enabled selectively using `filter = + // grapevine[incoming_request_curl]=trace` in config + tracing::trace_span!("incoming_request_curl").in_scope( + || { + tracing::trace!( + cmd = utils::curlify(request), + "curl command line for incoming request \ + (guessed hostname)" + ); + }, + ); + }, + ), + ) + .layer(axum::middleware::from_fn(unrecognized_method)) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]) + .allow_headers([ + header::ORIGIN, + x_requested_with, + header::CONTENT_TYPE, + header::ACCEPT, + header::AUTHORIZATION, + ]) + .max_age(Duration::from_secs(86400)), + ) + .layer(DefaultBodyLimit::max( + config + .max_request_size + .try_into() + .expect("failed to convert max request size"), + )) + .layer(axum::middleware::from_fn(observability::http_metrics_layer)); + + let app = routes(config).layer(middlewares).into_make_service(); + let mut handles = Vec::new(); + let mut servers = JoinSet::new(); + + let tls_config = if let Some(tls) = &config.tls { + Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err( + |err| Error::LoadCerts { + certs: tls.certs.clone(), + key: tls.key.clone(), + err, + }, + )?) + } else { + None + }; + + if config.listen.is_empty() { + return Err(Error::NoListeners); + } + + for listen in &config.listen { + info!(listener = %listen, "Listening for incoming traffic"); + match listen { + ListenConfig::Tcp { + address, + port, + tls, + } => { + let addr = SocketAddr::from((*address, *port)); + let handle = ServerHandle::new(); + handles.push(handle.clone()); + let server = if *tls { + let tls_config = tls_config + .clone() + .ok_or_else(|| Error::NoTlsCerts(listen.clone()))?; + bind_rustls(addr, tls_config) + .handle(handle) + .serve(app.clone()) + .left_future() + } else { + bind(addr).handle(handle).serve(app.clone()).right_future() + }; + servers.spawn( + server.then(|result| async { (listen.clone(), result) }), + ); + } + } + } + + #[cfg(feature = "systemd")] + sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) + .expect("should be able to notify systemd"); + + tokio::spawn(shutdown_signal(handles)); + + while let Some(result) = servers.join_next().await { + let (listen, result) = + result.expect("should be able to join server task"); + result.map_err(|err| Error::Listen(err, listen))?; + } + + Ok(()) +} + +/// Ensures the request runs in a new tokio thread. +/// +/// The axum request handler task gets cancelled if the connection is shut down; +/// by spawning our own task, processing continue after the client disconnects. +async fn spawn_task( + req: axum::extract::Request, + next: axum::middleware::Next, +) -> std::result::Result { + if services().globals.shutdown.load(atomic::Ordering::Relaxed) { + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + tokio::spawn(next.run(req)) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) +} + +async fn unrecognized_method( + req: axum::extract::Request, + next: axum::middleware::Next, +) -> std::result::Result { + let method = req.method().clone(); + let uri = req.uri().clone(); + let inner = next.run(req).await; + if inner.status() == StatusCode::METHOD_NOT_ALLOWED { + warn!(%method, %uri, "Method not allowed"); + return Ok(Ra(UiaaResponse::MatrixError(RumaError { + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Unrecognized request".to_owned(), + }, + status_code: StatusCode::METHOD_NOT_ALLOWED, + })) + .into_response()); + } + Ok(inner) +} + +#[allow(clippy::too_many_lines)] +fn routes(config: &Config) -> Router { + use client_server as c2s; + use server_server as s2s; + + let router = Router::new() + .ruma_route(c2s::get_supported_versions_route) + .ruma_route(c2s::get_register_available_route) + .ruma_route(c2s::register_route) + .ruma_route(c2s::get_login_types_route) + .ruma_route(c2s::login_route) + .ruma_route(c2s::whoami_route) + .ruma_route(c2s::logout_route) + .ruma_route(c2s::logout_all_route) + .ruma_route(c2s::change_password_route) + .ruma_route(c2s::deactivate_route) + .ruma_route(c2s::third_party_route) + .ruma_route(c2s::request_3pid_management_token_via_email_route) + .ruma_route(c2s::request_3pid_management_token_via_msisdn_route) + .ruma_route(c2s::get_capabilities_route) + .ruma_route(c2s::get_pushrules_all_route) + .ruma_route(c2s::set_pushrule_route) + .ruma_route(c2s::get_pushrule_route) + .ruma_route(c2s::set_pushrule_enabled_route) + .ruma_route(c2s::get_pushrule_enabled_route) + .ruma_route(c2s::get_pushrule_actions_route) + .ruma_route(c2s::set_pushrule_actions_route) + .ruma_route(c2s::delete_pushrule_route) + .ruma_route(c2s::get_room_event_route) + .ruma_route(c2s::get_room_aliases_route) + .ruma_route(c2s::get_filter_route) + .ruma_route(c2s::create_filter_route) + .ruma_route(c2s::set_global_account_data_route) + .ruma_route(c2s::set_room_account_data_route) + .ruma_route(c2s::get_global_account_data_route) + .ruma_route(c2s::get_room_account_data_route) + .ruma_route(c2s::set_displayname_route) + .ruma_route(c2s::get_displayname_route) + .ruma_route(c2s::set_avatar_url_route) + .ruma_route(c2s::get_avatar_url_route) + .ruma_route(c2s::get_profile_route) + .ruma_route(c2s::upload_keys_route) + .ruma_route(c2s::get_keys_route) + .ruma_route(c2s::claim_keys_route) + .ruma_route(c2s::create_backup_version_route) + .ruma_route(c2s::update_backup_version_route) + .ruma_route(c2s::delete_backup_version_route) + .ruma_route(c2s::get_latest_backup_info_route) + .ruma_route(c2s::get_backup_info_route) + .ruma_route(c2s::add_backup_keys_route) + .ruma_route(c2s::add_backup_keys_for_room_route) + .ruma_route(c2s::add_backup_keys_for_session_route) + .ruma_route(c2s::delete_backup_keys_for_room_route) + .ruma_route(c2s::delete_backup_keys_for_session_route) + .ruma_route(c2s::delete_backup_keys_route) + .ruma_route(c2s::get_backup_keys_for_room_route) + .ruma_route(c2s::get_backup_keys_for_session_route) + .ruma_route(c2s::get_backup_keys_route) + .ruma_route(c2s::set_read_marker_route) + .ruma_route(c2s::create_receipt_route) + .ruma_route(c2s::create_typing_event_route) + .ruma_route(c2s::create_room_route) + .ruma_route(c2s::redact_event_route) + .ruma_route(c2s::report_event_route) + .ruma_route(c2s::create_alias_route) + .ruma_route(c2s::delete_alias_route) + .ruma_route(c2s::get_alias_route) + .ruma_route(c2s::join_room_by_id_route) + .ruma_route(c2s::join_room_by_id_or_alias_route) + .ruma_route(c2s::joined_members_route) + .ruma_route(c2s::leave_room_route) + .ruma_route(c2s::forget_room_route) + .ruma_route(c2s::joined_rooms_route) + .ruma_route(c2s::kick_user_route) + .ruma_route(c2s::ban_user_route) + .ruma_route(c2s::unban_user_route) + .ruma_route(c2s::invite_user_route) + .ruma_route(c2s::set_room_visibility_route) + .ruma_route(c2s::get_room_visibility_route) + .ruma_route(c2s::get_public_rooms_route) + .ruma_route(c2s::get_public_rooms_filtered_route) + .ruma_route(c2s::search_users_route) + .ruma_route(c2s::get_member_events_route) + .ruma_route(c2s::get_protocols_route) + .ruma_route(c2s::send_message_event_route) + .ruma_route(c2s::send_state_event_for_key_route) + .ruma_route(c2s::get_state_events_route) + .ruma_route(c2s::get_state_events_for_key_route) + .ruma_route(c2s::sync_events_route) + .ruma_route(c2s::sync_events_v4_route) + .ruma_route(c2s::get_context_route) + .ruma_route(c2s::get_message_events_route) + .ruma_route(c2s::search_events_route) + .ruma_route(c2s::turn_server_route) + .ruma_route(c2s::send_event_to_device_route); + + // deprecated, but unproblematic + let router = router.ruma_route(c2s::get_media_config_legacy_route); + let router = if config.serve_media_unauthenticated { + router + .ruma_route(c2s::get_content_legacy_route) + .ruma_route(c2s::get_content_as_filename_legacy_route) + .ruma_route(c2s::get_content_thumbnail_legacy_route) + } else { + router + .route( + "/_matrix/media/v3/download/*path", + any(unauthenticated_media_disabled), + ) + .route( + "/_matrix/media/v3/thumbnail/*path", + any(unauthenticated_media_disabled), + ) + }; + + // authenticated media + let router = router + .ruma_route(c2s::get_media_config_route) + .ruma_route(c2s::create_content_route) + .ruma_route(c2s::get_content_route) + .ruma_route(c2s::get_content_as_filename_route) + .ruma_route(c2s::get_content_thumbnail_route); + + let router = router + .ruma_route(c2s::get_devices_route) + .ruma_route(c2s::get_device_route) + .ruma_route(c2s::update_device_route) + .ruma_route(c2s::delete_device_route) + .ruma_route(c2s::delete_devices_route) + .ruma_route(c2s::get_tags_route) + .ruma_route(c2s::update_tag_route) + .ruma_route(c2s::delete_tag_route) + .ruma_route(c2s::upload_signing_keys_route) + .ruma_route(c2s::upload_signatures_route) + .ruma_route(c2s::get_key_changes_route) + .ruma_route(c2s::get_pushers_route) + .ruma_route(c2s::set_pushers_route) + .ruma_route(c2s::upgrade_room_route) + .ruma_route(c2s::get_threads_route) + .ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(c2s::get_relating_events_with_rel_type_route) + .ruma_route(c2s::get_relating_events_route) + .ruma_route(c2s::get_hierarchy_route); + + // Ruma doesn't have support for multiple paths for a single endpoint yet, + // and these routes share one Ruma request / response type pair with + // {get,send}_state_event_for_key_route. These two endpoints also allow + // trailing slashes. + let router = router + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type", + get(c2s::get_state_events_for_empty_key_route) + .put(c2s::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type", + get(c2s::get_state_events_for_empty_key_route) + .put(c2s::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type/", + get(c2s::get_state_events_for_empty_key_route) + .put(c2s::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type/", + get(c2s::get_state_events_for_empty_key_route) + .put(c2s::send_state_event_for_empty_key_route), + ); + + let router = if config.observability.metrics.enable { + router.route( + "/metrics", + get(|| async { observability::METRICS.export() }), + ) + } else { + router + }; + + let router = router + .route( + "/_matrix/client/r0/rooms/:room_id/initialSync", + get(initial_sync), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/initialSync", + get(initial_sync), + ) + .route("/", get(it_works)) + .fallback(not_found); + + let router = router + .route("/.well-known/matrix/client", get(well_known::client)) + .route("/.well-known/matrix/server", get(well_known::server)); + + if config.federation.enable { + router + .ruma_route(s2s::get_server_version_route) + .route("/_matrix/key/v2/server", get(s2s::get_server_keys_route)) + .route( + "/_matrix/key/v2/server/:key_id", + get(s2s::get_server_keys_deprecated_route), + ) + .ruma_route(s2s::get_public_rooms_route) + .ruma_route(s2s::get_public_rooms_filtered_route) + .ruma_route(s2s::send_transaction_message_route) + .ruma_route(s2s::get_event_route) + .ruma_route(s2s::get_backfill_route) + .ruma_route(s2s::get_missing_events_route) + .ruma_route(s2s::get_event_authorization_route) + .ruma_route(s2s::get_room_state_route) + .ruma_route(s2s::get_room_state_ids_route) + .ruma_route(s2s::create_join_event_template_route) + .ruma_route(s2s::create_join_event_v1_route) + .ruma_route(s2s::create_join_event_v2_route) + .ruma_route(s2s::create_invite_route) + .ruma_route(s2s::get_devices_route) + .ruma_route(s2s::get_room_information_route) + .ruma_route(s2s::get_profile_information_route) + .ruma_route(s2s::get_keys_route) + .ruma_route(s2s::claim_keys_route) + .ruma_route(s2s::media_download_route) + .ruma_route(s2s::media_thumbnail_route) + } else { + router + .route("/_matrix/federation/*path", any(federation_disabled)) + .route("/_matrix/key/*path", any(federation_disabled)) + } +} + +async fn shutdown_signal(handles: Vec) { + let ctrl_c = async { + signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + let sig: &str; + + tokio::select! { + () = ctrl_c => { sig = "Ctrl+C"; }, + () = terminate => { sig = "SIGTERM"; }, + } + + warn!(signal = %sig, "Shutting down due to signal"); + + services().globals.shutdown(); + + for handle in handles { + handle.graceful_shutdown(Some(Duration::from_secs(30))); + } + + #[cfg(feature = "systemd")] + sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]) + .expect("should be able to notify systemd"); +} + +async fn federation_disabled(_: Uri) -> impl IntoResponse { + Error::bad_config("Federation is disabled.") +} + +async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse { + Error::BadRequest( + ErrorKind::NotFound, + "Unauthenticated media access is disabled", + ) +} + +async fn not_found(method: Method, uri: Uri) -> impl IntoResponse { + debug!(%method, %uri, "Unknown route"); + Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") +} + +async fn initial_sync(_uri: Uri) -> impl IntoResponse { + Error::BadRequest( + ErrorKind::GuestAccessForbidden, + "Guest access not implemented", + ) +} + +async fn it_works() -> &'static str { + "Hello from Grapevine!" +} + +trait RouterExt { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static; +} + +impl RouterExt for Router { + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static, + { + handler.add_to_router(self) + } +} + +pub(crate) trait RumaHandler { + // Can't transform to a handler without boxing or relying on the + // nightly-only impl-trait-in-traits feature. Moving a small amount of + // extra logic into the trait allows bypassing both. + fn add_to_router(self, router: Router) -> Router; +} + +macro_rules! impl_ruma_handler { + ( $($ty:ident),* $(,)? ) => { + #[axum::async_trait] + #[allow(non_snake_case)] + impl + RumaHandler<($($ty,)* Ar,)> for F + where + Req: IncomingRequest + Send + 'static, + Resp: IntoResponse, + F: FnOnce($($ty,)* Ar) -> Fut + Clone + Send + 'static, + Fut: Future> + + Send, + E: IntoResponse, + $( $ty: FromRequestParts<()> + Send + 'static, )* + { + fn add_to_router(self, mut router: Router) -> Router { + let meta = Req::METADATA; + let method_filter = method_to_filter(meta.method); + + for path in meta.history.all_paths() { + let handler = self.clone(); + + router = router.route( + path, + on( + method_filter, + |$( $ty: $ty, )* req: Ar| async move { + let span = info_span!( + "run_ruma_handler", + auth.user = ?req.sender_user, + auth.device = ?req.sender_device, + auth.servername = ?req.sender_servername, + auth.appservice_id = ?req.appservice_info + .as_ref() + .map(|i| &i.registration.id) + ); + handler($($ty,)* req).instrument(span).await + } + ) + ) + } + + router + } + } + }; +} + +impl_ruma_handler!(); +impl_ruma_handler!(T1); +impl_ruma_handler!(T1, T2); +impl_ruma_handler!(T1, T2, T3); +impl_ruma_handler!(T1, T2, T3, T4); +impl_ruma_handler!(T1, T2, T3, T4, T5); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); +impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); + +fn method_to_filter(method: Method) -> MethodFilter { + match method { + Method::DELETE => MethodFilter::DELETE, + Method::GET => MethodFilter::GET, + Method::HEAD => MethodFilter::HEAD, + Method::OPTIONS => MethodFilter::OPTIONS, + Method::PATCH => MethodFilter::PATCH, + Method::POST => MethodFilter::POST, + Method::PUT => MethodFilter::PUT, + Method::TRACE => MethodFilter::TRACE, + m => panic!("Unsupported HTTP method: {m:?}"), + } +} + +#[cfg(unix)] +#[tracing::instrument(err)] +fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { + use nix::sys::resource::{getrlimit, setrlimit, Resource}; + + let res = Resource::RLIMIT_NOFILE; + + let (soft_limit, hard_limit) = getrlimit(res)?; + + debug!(soft_limit, "Current nofile soft limit"); + + setrlimit(res, hard_limit, hard_limit)?; + + debug!(hard_limit, "Increased nofile soft limit to the hard limit"); + + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index e90e9f7d..080b9714 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,44 +2,10 @@ // work anyway #![cfg_attr(not(any(feature = "sqlite", feature = "rocksdb")), allow(unused))] -use std::{ - future::Future, - net::SocketAddr, - process::ExitCode, - sync::{atomic, RwLock}, - time::Duration, -}; +use std::{process::ExitCode, sync::RwLock}; -use axum::{ - extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, - response::IntoResponse, - routing::{any, get, on, MethodFilter}, - Router, -}; -use axum_server::{ - bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle, -}; use clap::Parser; -use futures_util::FutureExt; -use http::{ - header::{self, HeaderName}, - Method, StatusCode, Uri, -}; -use ruma::api::{ - client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, - }, - IncomingRequest, -}; -use tokio::{signal, task::JoinSet}; -use tower::ServiceBuilder; -use tower_http::{ - cors::{self, CorsLayer}, - trace::TraceLayer, - ServiceBuilderExt as _, -}; -use tracing::{debug, error, info, info_span, warn, Instrument}; +use tracing::error; mod api; mod cli; @@ -51,10 +17,7 @@ mod service; mod utils; pub(crate) use api::ruma_wrapper::{Ar, Ra}; -use api::{client_server, server_server, well_known}; -use cli::{Args, Command}; -pub(crate) use config::{Config, ListenConfig}; -pub(crate) use database::KeyValueDatabase; +pub(crate) use config::Config; pub(crate) use service::{pdu::PduEvent, Services}; #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] use tikv_jemallocator::Jemalloc; @@ -91,7 +54,8 @@ fn version() -> String { #[tokio::main] async fn main() -> ExitCode { - let Err(e) = try_main().await else { + let args = cli::Args::parse(); + let Err(e) = args.run().await else { return ExitCode::SUCCESS; }; @@ -105,621 +69,3 @@ async fn main() -> ExitCode { ExitCode::FAILURE } - -/// Fallible entrypoint -async fn try_main() -> Result<(), error::Main> { - use error::Main as Error; - - let args = Args::parse(); - // This is a placeholder, the logic specific to the 'serve' command will be - // moved to another file in a later commit - let Command::Serve(args) = args.command; - - let config = config::load(args.config.config.as_ref()).await?; - - let (_guard, reload_handles) = observability::init(&config)?; - - // This is needed for opening lots of file descriptors, which tends to - // happen more often when using RocksDB and making lots of federation - // connections at startup. The soft limit is usually 1024, and the hard - // limit is usually 512000; I've personally seen it hit >2000. - // - // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 - // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 - #[cfg(unix)] - maximize_fd_limit() - .expect("should be able to increase the soft limit to the hard limit"); - - info!("Loading database"); - KeyValueDatabase::load_or_create(config, reload_handles) - .await - .map_err(Error::DatabaseError)?; - - info!("Starting server"); - run_server().await?; - - Ok(()) -} - -#[allow(clippy::too_many_lines)] -async fn run_server() -> Result<(), error::Serve> { - use error::Serve as Error; - - let config = &services().globals.config; - - let x_requested_with = HeaderName::from_static("x-requested-with"); - - let middlewares = ServiceBuilder::new() - .sensitive_headers([header::AUTHORIZATION]) - .layer(axum::middleware::from_fn(spawn_task)) - .layer( - TraceLayer::new_for_http() - .make_span_with(|request: &http::Request<_>| { - let endpoint = if let Some(endpoint) = - request.extensions().get::() - { - endpoint.as_str() - } else { - request.uri().path() - }; - - let method = request.method(); - - tracing::info_span!( - "http_request", - otel.name = format!("{method} {endpoint}"), - %method, - %endpoint, - ) - }) - .on_request( - |request: &http::Request<_>, _span: &tracing::Span| { - // can be enabled selectively using `filter = - // grapevine[incoming_request_curl]=trace` in config - tracing::trace_span!("incoming_request_curl").in_scope( - || { - tracing::trace!( - cmd = utils::curlify(request), - "curl command line for incoming request \ - (guessed hostname)" - ); - }, - ); - }, - ), - ) - .layer(axum::middleware::from_fn(unrecognized_method)) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods([ - Method::GET, - Method::POST, - Method::PUT, - Method::DELETE, - Method::OPTIONS, - ]) - .allow_headers([ - header::ORIGIN, - x_requested_with, - header::CONTENT_TYPE, - header::ACCEPT, - header::AUTHORIZATION, - ]) - .max_age(Duration::from_secs(86400)), - ) - .layer(DefaultBodyLimit::max( - config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - )) - .layer(axum::middleware::from_fn(observability::http_metrics_layer)); - - let app = routes(config).layer(middlewares).into_make_service(); - let mut handles = Vec::new(); - let mut servers = JoinSet::new(); - - let tls_config = if let Some(tls) = &config.tls { - Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err( - |err| Error::LoadCerts { - certs: tls.certs.clone(), - key: tls.key.clone(), - err, - }, - )?) - } else { - None - }; - - if config.listen.is_empty() { - return Err(Error::NoListeners); - } - - for listen in &config.listen { - info!(listener = %listen, "Listening for incoming traffic"); - match listen { - ListenConfig::Tcp { - address, - port, - tls, - } => { - let addr = SocketAddr::from((*address, *port)); - let handle = ServerHandle::new(); - handles.push(handle.clone()); - let server = if *tls { - let tls_config = tls_config - .clone() - .ok_or_else(|| Error::NoTlsCerts(listen.clone()))?; - bind_rustls(addr, tls_config) - .handle(handle) - .serve(app.clone()) - .left_future() - } else { - bind(addr).handle(handle).serve(app.clone()).right_future() - }; - servers.spawn( - server.then(|result| async { (listen.clone(), result) }), - ); - } - } - } - - #[cfg(feature = "systemd")] - sd_notify::notify(true, &[sd_notify::NotifyState::Ready]) - .expect("should be able to notify systemd"); - - tokio::spawn(shutdown_signal(handles)); - - while let Some(result) = servers.join_next().await { - let (listen, result) = - result.expect("should be able to join server task"); - result.map_err(|err| Error::Listen(err, listen))?; - } - - Ok(()) -} - -/// Ensures the request runs in a new tokio thread. -/// -/// The axum request handler task gets cancelled if the connection is shut down; -/// by spawning our own task, processing continue after the client disconnects. -async fn spawn_task( - req: axum::extract::Request, - next: axum::middleware::Next, -) -> std::result::Result { - if services().globals.shutdown.load(atomic::Ordering::Relaxed) { - return Err(StatusCode::SERVICE_UNAVAILABLE); - } - tokio::spawn(next.run(req)) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) -} - -async fn unrecognized_method( - req: axum::extract::Request, - next: axum::middleware::Next, -) -> std::result::Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let inner = next.run(req).await; - if inner.status() == StatusCode::METHOD_NOT_ALLOWED { - warn!(%method, %uri, "Method not allowed"); - return Ok(Ra(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Unrecognized request".to_owned(), - }, - status_code: StatusCode::METHOD_NOT_ALLOWED, - })) - .into_response()); - } - Ok(inner) -} - -#[allow(clippy::too_many_lines)] -fn routes(config: &Config) -> Router { - use client_server as c2s; - use server_server as s2s; - - let router = Router::new() - .ruma_route(c2s::get_supported_versions_route) - .ruma_route(c2s::get_register_available_route) - .ruma_route(c2s::register_route) - .ruma_route(c2s::get_login_types_route) - .ruma_route(c2s::login_route) - .ruma_route(c2s::whoami_route) - .ruma_route(c2s::logout_route) - .ruma_route(c2s::logout_all_route) - .ruma_route(c2s::change_password_route) - .ruma_route(c2s::deactivate_route) - .ruma_route(c2s::third_party_route) - .ruma_route(c2s::request_3pid_management_token_via_email_route) - .ruma_route(c2s::request_3pid_management_token_via_msisdn_route) - .ruma_route(c2s::get_capabilities_route) - .ruma_route(c2s::get_pushrules_all_route) - .ruma_route(c2s::set_pushrule_route) - .ruma_route(c2s::get_pushrule_route) - .ruma_route(c2s::set_pushrule_enabled_route) - .ruma_route(c2s::get_pushrule_enabled_route) - .ruma_route(c2s::get_pushrule_actions_route) - .ruma_route(c2s::set_pushrule_actions_route) - .ruma_route(c2s::delete_pushrule_route) - .ruma_route(c2s::get_room_event_route) - .ruma_route(c2s::get_room_aliases_route) - .ruma_route(c2s::get_filter_route) - .ruma_route(c2s::create_filter_route) - .ruma_route(c2s::set_global_account_data_route) - .ruma_route(c2s::set_room_account_data_route) - .ruma_route(c2s::get_global_account_data_route) - .ruma_route(c2s::get_room_account_data_route) - .ruma_route(c2s::set_displayname_route) - .ruma_route(c2s::get_displayname_route) - .ruma_route(c2s::set_avatar_url_route) - .ruma_route(c2s::get_avatar_url_route) - .ruma_route(c2s::get_profile_route) - .ruma_route(c2s::upload_keys_route) - .ruma_route(c2s::get_keys_route) - .ruma_route(c2s::claim_keys_route) - .ruma_route(c2s::create_backup_version_route) - .ruma_route(c2s::update_backup_version_route) - .ruma_route(c2s::delete_backup_version_route) - .ruma_route(c2s::get_latest_backup_info_route) - .ruma_route(c2s::get_backup_info_route) - .ruma_route(c2s::add_backup_keys_route) - .ruma_route(c2s::add_backup_keys_for_room_route) - .ruma_route(c2s::add_backup_keys_for_session_route) - .ruma_route(c2s::delete_backup_keys_for_room_route) - .ruma_route(c2s::delete_backup_keys_for_session_route) - .ruma_route(c2s::delete_backup_keys_route) - .ruma_route(c2s::get_backup_keys_for_room_route) - .ruma_route(c2s::get_backup_keys_for_session_route) - .ruma_route(c2s::get_backup_keys_route) - .ruma_route(c2s::set_read_marker_route) - .ruma_route(c2s::create_receipt_route) - .ruma_route(c2s::create_typing_event_route) - .ruma_route(c2s::create_room_route) - .ruma_route(c2s::redact_event_route) - .ruma_route(c2s::report_event_route) - .ruma_route(c2s::create_alias_route) - .ruma_route(c2s::delete_alias_route) - .ruma_route(c2s::get_alias_route) - .ruma_route(c2s::join_room_by_id_route) - .ruma_route(c2s::join_room_by_id_or_alias_route) - .ruma_route(c2s::joined_members_route) - .ruma_route(c2s::leave_room_route) - .ruma_route(c2s::forget_room_route) - .ruma_route(c2s::joined_rooms_route) - .ruma_route(c2s::kick_user_route) - .ruma_route(c2s::ban_user_route) - .ruma_route(c2s::unban_user_route) - .ruma_route(c2s::invite_user_route) - .ruma_route(c2s::set_room_visibility_route) - .ruma_route(c2s::get_room_visibility_route) - .ruma_route(c2s::get_public_rooms_route) - .ruma_route(c2s::get_public_rooms_filtered_route) - .ruma_route(c2s::search_users_route) - .ruma_route(c2s::get_member_events_route) - .ruma_route(c2s::get_protocols_route) - .ruma_route(c2s::send_message_event_route) - .ruma_route(c2s::send_state_event_for_key_route) - .ruma_route(c2s::get_state_events_route) - .ruma_route(c2s::get_state_events_for_key_route) - .ruma_route(c2s::sync_events_route) - .ruma_route(c2s::sync_events_v4_route) - .ruma_route(c2s::get_context_route) - .ruma_route(c2s::get_message_events_route) - .ruma_route(c2s::search_events_route) - .ruma_route(c2s::turn_server_route) - .ruma_route(c2s::send_event_to_device_route); - - // deprecated, but unproblematic - let router = router.ruma_route(c2s::get_media_config_legacy_route); - let router = if config.serve_media_unauthenticated { - router - .ruma_route(c2s::get_content_legacy_route) - .ruma_route(c2s::get_content_as_filename_legacy_route) - .ruma_route(c2s::get_content_thumbnail_legacy_route) - } else { - router - .route( - "/_matrix/media/v3/download/*path", - any(unauthenticated_media_disabled), - ) - .route( - "/_matrix/media/v3/thumbnail/*path", - any(unauthenticated_media_disabled), - ) - }; - - // authenticated media - let router = router - .ruma_route(c2s::get_media_config_route) - .ruma_route(c2s::create_content_route) - .ruma_route(c2s::get_content_route) - .ruma_route(c2s::get_content_as_filename_route) - .ruma_route(c2s::get_content_thumbnail_route); - - let router = router - .ruma_route(c2s::get_devices_route) - .ruma_route(c2s::get_device_route) - .ruma_route(c2s::update_device_route) - .ruma_route(c2s::delete_device_route) - .ruma_route(c2s::delete_devices_route) - .ruma_route(c2s::get_tags_route) - .ruma_route(c2s::update_tag_route) - .ruma_route(c2s::delete_tag_route) - .ruma_route(c2s::upload_signing_keys_route) - .ruma_route(c2s::upload_signatures_route) - .ruma_route(c2s::get_key_changes_route) - .ruma_route(c2s::get_pushers_route) - .ruma_route(c2s::set_pushers_route) - .ruma_route(c2s::upgrade_room_route) - .ruma_route(c2s::get_threads_route) - .ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(c2s::get_relating_events_with_rel_type_route) - .ruma_route(c2s::get_relating_events_route) - .ruma_route(c2s::get_hierarchy_route); - - // Ruma doesn't have support for multiple paths for a single endpoint yet, - // and these routes share one Ruma request / response type pair with - // {get,send}_state_event_for_key_route. These two endpoints also allow - // trailing slashes. - let router = router - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type", - get(c2s::get_state_events_for_empty_key_route) - .put(c2s::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type", - get(c2s::get_state_events_for_empty_key_route) - .put(c2s::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type/", - get(c2s::get_state_events_for_empty_key_route) - .put(c2s::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type/", - get(c2s::get_state_events_for_empty_key_route) - .put(c2s::send_state_event_for_empty_key_route), - ); - - let router = if config.observability.metrics.enable { - router.route( - "/metrics", - get(|| async { observability::METRICS.export() }), - ) - } else { - router - }; - - let router = router - .route( - "/_matrix/client/r0/rooms/:room_id/initialSync", - get(initial_sync), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/initialSync", - get(initial_sync), - ) - .route("/", get(it_works)) - .fallback(not_found); - - let router = router - .route("/.well-known/matrix/client", get(well_known::client)) - .route("/.well-known/matrix/server", get(well_known::server)); - - if config.federation.enable { - router - .ruma_route(s2s::get_server_version_route) - .route("/_matrix/key/v2/server", get(s2s::get_server_keys_route)) - .route( - "/_matrix/key/v2/server/:key_id", - get(s2s::get_server_keys_deprecated_route), - ) - .ruma_route(s2s::get_public_rooms_route) - .ruma_route(s2s::get_public_rooms_filtered_route) - .ruma_route(s2s::send_transaction_message_route) - .ruma_route(s2s::get_event_route) - .ruma_route(s2s::get_backfill_route) - .ruma_route(s2s::get_missing_events_route) - .ruma_route(s2s::get_event_authorization_route) - .ruma_route(s2s::get_room_state_route) - .ruma_route(s2s::get_room_state_ids_route) - .ruma_route(s2s::create_join_event_template_route) - .ruma_route(s2s::create_join_event_v1_route) - .ruma_route(s2s::create_join_event_v2_route) - .ruma_route(s2s::create_invite_route) - .ruma_route(s2s::get_devices_route) - .ruma_route(s2s::get_room_information_route) - .ruma_route(s2s::get_profile_information_route) - .ruma_route(s2s::get_keys_route) - .ruma_route(s2s::claim_keys_route) - .ruma_route(s2s::media_download_route) - .ruma_route(s2s::media_thumbnail_route) - } else { - router - .route("/_matrix/federation/*path", any(federation_disabled)) - .route("/_matrix/key/*path", any(federation_disabled)) - } -} - -async fn shutdown_signal(handles: Vec) { - let ctrl_c = async { - signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - let sig: &str; - - tokio::select! { - () = ctrl_c => { sig = "Ctrl+C"; }, - () = terminate => { sig = "SIGTERM"; }, - } - - warn!(signal = %sig, "Shutting down due to signal"); - - services().globals.shutdown(); - - for handle in handles { - handle.graceful_shutdown(Some(Duration::from_secs(30))); - } - - #[cfg(feature = "systemd")] - sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]) - .expect("should be able to notify systemd"); -} - -async fn federation_disabled(_: Uri) -> impl IntoResponse { - Error::bad_config("Federation is disabled.") -} - -async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse { - Error::BadRequest( - ErrorKind::NotFound, - "Unauthenticated media access is disabled", - ) -} - -async fn not_found(method: Method, uri: Uri) -> impl IntoResponse { - debug!(%method, %uri, "Unknown route"); - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") -} - -async fn initial_sync(_uri: Uri) -> impl IntoResponse { - Error::BadRequest( - ErrorKind::GuestAccessForbidden, - "Guest access not implemented", - ) -} - -async fn it_works() -> &'static str { - "Hello from Grapevine!" -} - -trait RouterExt { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static; -} - -impl RouterExt for Router { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static, - { - handler.add_to_router(self) - } -} - -pub(crate) trait RumaHandler { - // Can't transform to a handler without boxing or relying on the - // nightly-only impl-trait-in-traits feature. Moving a small amount of - // extra logic into the trait allows bypassing both. - fn add_to_router(self, router: Router) -> Router; -} - -macro_rules! impl_ruma_handler { - ( $($ty:ident),* $(,)? ) => { - #[axum::async_trait] - #[allow(non_snake_case)] - impl - RumaHandler<($($ty,)* Ar,)> for F - where - Req: IncomingRequest + Send + 'static, - Resp: IntoResponse, - F: FnOnce($($ty,)* Ar) -> Fut + Clone + Send + 'static, - Fut: Future> - + Send, - E: IntoResponse, - $( $ty: FromRequestParts<()> + Send + 'static, )* - { - fn add_to_router(self, mut router: Router) -> Router { - let meta = Req::METADATA; - let method_filter = method_to_filter(meta.method); - - for path in meta.history.all_paths() { - let handler = self.clone(); - - router = router.route( - path, - on( - method_filter, - |$( $ty: $ty, )* req: Ar| async move { - let span = info_span!( - "run_ruma_handler", - auth.user = ?req.sender_user, - auth.device = ?req.sender_device, - auth.servername = ?req.sender_servername, - auth.appservice_id = ?req.appservice_info - .as_ref() - .map(|i| &i.registration.id) - ); - handler($($ty,)* req).instrument(span).await - } - ) - ) - } - - router - } - } - }; -} - -impl_ruma_handler!(); -impl_ruma_handler!(T1); -impl_ruma_handler!(T1, T2); -impl_ruma_handler!(T1, T2, T3); -impl_ruma_handler!(T1, T2, T3, T4); -impl_ruma_handler!(T1, T2, T3, T4, T5); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); -impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); - -fn method_to_filter(method: Method) -> MethodFilter { - match method { - Method::DELETE => MethodFilter::DELETE, - Method::GET => MethodFilter::GET, - Method::HEAD => MethodFilter::HEAD, - Method::OPTIONS => MethodFilter::OPTIONS, - Method::PATCH => MethodFilter::PATCH, - Method::POST => MethodFilter::POST, - Method::PUT => MethodFilter::PUT, - Method::TRACE => MethodFilter::TRACE, - m => panic!("Unsupported HTTP method: {m:?}"), - } -} - -#[cfg(unix)] -#[tracing::instrument(err)] -fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { - use nix::sys::resource::{getrlimit, setrlimit, Resource}; - - let res = Resource::RLIMIT_NOFILE; - - let (soft_limit, hard_limit) = getrlimit(res)?; - - debug!(soft_limit, "Current nofile soft limit"); - - setrlimit(res, hard_limit, hard_limit)?; - - debug!(hard_limit, "Increased nofile soft limit to the hard limit"); - - Ok(()) -}