mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-16 15:21:24 +01:00
move 'serve' command logic into a submodule of 'cli'
The changes to 'main.rs' and 'cli/serve.rs' in this commit are almost pure code-motion.
This commit is contained in:
parent
be87774a3b
commit
86515d53cc
3 changed files with 677 additions and 659 deletions
15
src/cli.rs
15
src/cli.rs
|
|
@ -1,9 +1,16 @@
|
|||
//! Integration with `clap`
|
||||
//!
|
||||
//! CLI argument structs are defined in this module. Execution logic for each
|
||||
//! command goes in a submodule.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
use crate::error;
|
||||
|
||||
mod serve;
|
||||
|
||||
/// Command line arguments
|
||||
#[derive(Parser)]
|
||||
#[clap(
|
||||
|
|
@ -49,3 +56,11 @@ pub(crate) struct ServeArgs {
|
|||
#[clap(flatten)]
|
||||
pub(crate) config: ConfigArg,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
pub(crate) async fn run(self) -> Result<(), error::Main> {
|
||||
match self.command {
|
||||
Command::Serve(args) => serve::run(args).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
657
src/cli/serve.rs
Normal file
657
src/cli/serve.rs
Normal file
|
|
@ -0,0 +1,657 @@
|
|||
use std::{future::Future, net::SocketAddr, sync::atomic, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
||||
response::IntoResponse,
|
||||
routing::{any, get, on, MethodFilter},
|
||||
Router,
|
||||
};
|
||||
use axum_server::{
|
||||
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
|
||||
};
|
||||
use futures_util::FutureExt;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use ruma::api::{
|
||||
client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
},
|
||||
IncomingRequest,
|
||||
};
|
||||
use tokio::{signal, task::JoinSet};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
cors::{self, CorsLayer},
|
||||
trace::TraceLayer,
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::ServeArgs;
|
||||
use crate::{
|
||||
api::{
|
||||
client_server,
|
||||
ruma_wrapper::{Ar, Ra},
|
||||
server_server, well_known,
|
||||
},
|
||||
config,
|
||||
config::{Config, ListenConfig},
|
||||
database::KeyValueDatabase,
|
||||
error, observability, services, utils,
|
||||
utils::error::{Error, Result},
|
||||
};
|
||||
|
||||
pub(crate) async fn run(args: ServeArgs) -> Result<(), error::Main> {
|
||||
use error::Main as Error;
|
||||
|
||||
let config = config::load(args.config.config.as_ref()).await?;
|
||||
|
||||
let (_guard, reload_handles) = observability::init(&config)?;
|
||||
|
||||
// This is needed for opening lots of file descriptors, which tends to
|
||||
// happen more often when using RocksDB and making lots of federation
|
||||
// connections at startup. The soft limit is usually 1024, and the hard
|
||||
// limit is usually 512000; I've personally seen it hit >2000.
|
||||
//
|
||||
// * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6
|
||||
// * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741
|
||||
#[cfg(unix)]
|
||||
maximize_fd_limit()
|
||||
.expect("should be able to increase the soft limit to the hard limit");
|
||||
|
||||
info!("Loading database");
|
||||
KeyValueDatabase::load_or_create(config, reload_handles)
|
||||
.await
|
||||
.map_err(Error::DatabaseError)?;
|
||||
|
||||
info!("Starting server");
|
||||
run_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn run_server() -> Result<(), error::Serve> {
|
||||
use error::Serve as Error;
|
||||
|
||||
let config = &services().globals.config;
|
||||
|
||||
let x_requested_with = HeaderName::from_static("x-requested-with");
|
||||
|
||||
let middlewares = ServiceBuilder::new()
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.layer(axum::middleware::from_fn(spawn_task))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &http::Request<_>| {
|
||||
let endpoint = if let Some(endpoint) =
|
||||
request.extensions().get::<MatchedPath>()
|
||||
{
|
||||
endpoint.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
let method = request.method();
|
||||
|
||||
tracing::info_span!(
|
||||
"http_request",
|
||||
otel.name = format!("{method} {endpoint}"),
|
||||
%method,
|
||||
%endpoint,
|
||||
)
|
||||
})
|
||||
.on_request(
|
||||
|request: &http::Request<_>, _span: &tracing::Span| {
|
||||
// can be enabled selectively using `filter =
|
||||
// grapevine[incoming_request_curl]=trace` in config
|
||||
tracing::trace_span!("incoming_request_curl").in_scope(
|
||||
|| {
|
||||
tracing::trace!(
|
||||
cmd = utils::curlify(request),
|
||||
"curl command line for incoming request \
|
||||
(guessed hostname)"
|
||||
);
|
||||
},
|
||||
);
|
||||
},
|
||||
),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(unrecognized_method))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
.allow_headers([
|
||||
header::ORIGIN,
|
||||
x_requested_with,
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
])
|
||||
.max_age(Duration::from_secs(86400)),
|
||||
)
|
||||
.layer(DefaultBodyLimit::max(
|
||||
config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
))
|
||||
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
|
||||
|
||||
let app = routes(config).layer(middlewares).into_make_service();
|
||||
let mut handles = Vec::new();
|
||||
let mut servers = JoinSet::new();
|
||||
|
||||
let tls_config = if let Some(tls) = &config.tls {
|
||||
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|
||||
|err| Error::LoadCerts {
|
||||
certs: tls.certs.clone(),
|
||||
key: tls.key.clone(),
|
||||
err,
|
||||
},
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if config.listen.is_empty() {
|
||||
return Err(Error::NoListeners);
|
||||
}
|
||||
|
||||
for listen in &config.listen {
|
||||
info!(listener = %listen, "Listening for incoming traffic");
|
||||
match listen {
|
||||
ListenConfig::Tcp {
|
||||
address,
|
||||
port,
|
||||
tls,
|
||||
} => {
|
||||
let addr = SocketAddr::from((*address, *port));
|
||||
let handle = ServerHandle::new();
|
||||
handles.push(handle.clone());
|
||||
let server = if *tls {
|
||||
let tls_config = tls_config
|
||||
.clone()
|
||||
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
|
||||
bind_rustls(addr, tls_config)
|
||||
.handle(handle)
|
||||
.serve(app.clone())
|
||||
.left_future()
|
||||
} else {
|
||||
bind(addr).handle(handle).serve(app.clone()).right_future()
|
||||
};
|
||||
servers.spawn(
|
||||
server.then(|result| async { (listen.clone(), result) }),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
|
||||
.expect("should be able to notify systemd");
|
||||
|
||||
tokio::spawn(shutdown_signal(handles));
|
||||
|
||||
while let Some(result) = servers.join_next().await {
|
||||
let (listen, result) =
|
||||
result.expect("should be able to join server task");
|
||||
result.map_err(|err| Error::Listen(err, listen))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensures the request runs in a new tokio thread.
|
||||
///
|
||||
/// The axum request handler task gets cancelled if the connection is shut down;
|
||||
/// by spawning our own task, processing continue after the client disconnects.
|
||||
async fn spawn_task(
|
||||
req: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
tokio::spawn(next.run(req))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
async fn unrecognized_method(
|
||||
req: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let inner = next.run(req).await;
|
||||
if inner.status() == StatusCode::METHOD_NOT_ALLOWED {
|
||||
warn!(%method, %uri, "Method not allowed");
|
||||
return Ok(Ra(UiaaResponse::MatrixError(RumaError {
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Unrecognized request".to_owned(),
|
||||
},
|
||||
status_code: StatusCode::METHOD_NOT_ALLOWED,
|
||||
}))
|
||||
.into_response());
|
||||
}
|
||||
Ok(inner)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn routes(config: &Config) -> Router {
|
||||
use client_server as c2s;
|
||||
use server_server as s2s;
|
||||
|
||||
let router = Router::new()
|
||||
.ruma_route(c2s::get_supported_versions_route)
|
||||
.ruma_route(c2s::get_register_available_route)
|
||||
.ruma_route(c2s::register_route)
|
||||
.ruma_route(c2s::get_login_types_route)
|
||||
.ruma_route(c2s::login_route)
|
||||
.ruma_route(c2s::whoami_route)
|
||||
.ruma_route(c2s::logout_route)
|
||||
.ruma_route(c2s::logout_all_route)
|
||||
.ruma_route(c2s::change_password_route)
|
||||
.ruma_route(c2s::deactivate_route)
|
||||
.ruma_route(c2s::third_party_route)
|
||||
.ruma_route(c2s::request_3pid_management_token_via_email_route)
|
||||
.ruma_route(c2s::request_3pid_management_token_via_msisdn_route)
|
||||
.ruma_route(c2s::get_capabilities_route)
|
||||
.ruma_route(c2s::get_pushrules_all_route)
|
||||
.ruma_route(c2s::set_pushrule_route)
|
||||
.ruma_route(c2s::get_pushrule_route)
|
||||
.ruma_route(c2s::set_pushrule_enabled_route)
|
||||
.ruma_route(c2s::get_pushrule_enabled_route)
|
||||
.ruma_route(c2s::get_pushrule_actions_route)
|
||||
.ruma_route(c2s::set_pushrule_actions_route)
|
||||
.ruma_route(c2s::delete_pushrule_route)
|
||||
.ruma_route(c2s::get_room_event_route)
|
||||
.ruma_route(c2s::get_room_aliases_route)
|
||||
.ruma_route(c2s::get_filter_route)
|
||||
.ruma_route(c2s::create_filter_route)
|
||||
.ruma_route(c2s::set_global_account_data_route)
|
||||
.ruma_route(c2s::set_room_account_data_route)
|
||||
.ruma_route(c2s::get_global_account_data_route)
|
||||
.ruma_route(c2s::get_room_account_data_route)
|
||||
.ruma_route(c2s::set_displayname_route)
|
||||
.ruma_route(c2s::get_displayname_route)
|
||||
.ruma_route(c2s::set_avatar_url_route)
|
||||
.ruma_route(c2s::get_avatar_url_route)
|
||||
.ruma_route(c2s::get_profile_route)
|
||||
.ruma_route(c2s::upload_keys_route)
|
||||
.ruma_route(c2s::get_keys_route)
|
||||
.ruma_route(c2s::claim_keys_route)
|
||||
.ruma_route(c2s::create_backup_version_route)
|
||||
.ruma_route(c2s::update_backup_version_route)
|
||||
.ruma_route(c2s::delete_backup_version_route)
|
||||
.ruma_route(c2s::get_latest_backup_info_route)
|
||||
.ruma_route(c2s::get_backup_info_route)
|
||||
.ruma_route(c2s::add_backup_keys_route)
|
||||
.ruma_route(c2s::add_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::add_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::delete_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::delete_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::delete_backup_keys_route)
|
||||
.ruma_route(c2s::get_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::get_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::get_backup_keys_route)
|
||||
.ruma_route(c2s::set_read_marker_route)
|
||||
.ruma_route(c2s::create_receipt_route)
|
||||
.ruma_route(c2s::create_typing_event_route)
|
||||
.ruma_route(c2s::create_room_route)
|
||||
.ruma_route(c2s::redact_event_route)
|
||||
.ruma_route(c2s::report_event_route)
|
||||
.ruma_route(c2s::create_alias_route)
|
||||
.ruma_route(c2s::delete_alias_route)
|
||||
.ruma_route(c2s::get_alias_route)
|
||||
.ruma_route(c2s::join_room_by_id_route)
|
||||
.ruma_route(c2s::join_room_by_id_or_alias_route)
|
||||
.ruma_route(c2s::joined_members_route)
|
||||
.ruma_route(c2s::leave_room_route)
|
||||
.ruma_route(c2s::forget_room_route)
|
||||
.ruma_route(c2s::joined_rooms_route)
|
||||
.ruma_route(c2s::kick_user_route)
|
||||
.ruma_route(c2s::ban_user_route)
|
||||
.ruma_route(c2s::unban_user_route)
|
||||
.ruma_route(c2s::invite_user_route)
|
||||
.ruma_route(c2s::set_room_visibility_route)
|
||||
.ruma_route(c2s::get_room_visibility_route)
|
||||
.ruma_route(c2s::get_public_rooms_route)
|
||||
.ruma_route(c2s::get_public_rooms_filtered_route)
|
||||
.ruma_route(c2s::search_users_route)
|
||||
.ruma_route(c2s::get_member_events_route)
|
||||
.ruma_route(c2s::get_protocols_route)
|
||||
.ruma_route(c2s::send_message_event_route)
|
||||
.ruma_route(c2s::send_state_event_for_key_route)
|
||||
.ruma_route(c2s::get_state_events_route)
|
||||
.ruma_route(c2s::get_state_events_for_key_route)
|
||||
.ruma_route(c2s::sync_events_route)
|
||||
.ruma_route(c2s::sync_events_v4_route)
|
||||
.ruma_route(c2s::get_context_route)
|
||||
.ruma_route(c2s::get_message_events_route)
|
||||
.ruma_route(c2s::search_events_route)
|
||||
.ruma_route(c2s::turn_server_route)
|
||||
.ruma_route(c2s::send_event_to_device_route);
|
||||
|
||||
// deprecated, but unproblematic
|
||||
let router = router.ruma_route(c2s::get_media_config_legacy_route);
|
||||
let router = if config.serve_media_unauthenticated {
|
||||
router
|
||||
.ruma_route(c2s::get_content_legacy_route)
|
||||
.ruma_route(c2s::get_content_as_filename_legacy_route)
|
||||
.ruma_route(c2s::get_content_thumbnail_legacy_route)
|
||||
} else {
|
||||
router
|
||||
.route(
|
||||
"/_matrix/media/v3/download/*path",
|
||||
any(unauthenticated_media_disabled),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v3/thumbnail/*path",
|
||||
any(unauthenticated_media_disabled),
|
||||
)
|
||||
};
|
||||
|
||||
// authenticated media
|
||||
let router = router
|
||||
.ruma_route(c2s::get_media_config_route)
|
||||
.ruma_route(c2s::create_content_route)
|
||||
.ruma_route(c2s::get_content_route)
|
||||
.ruma_route(c2s::get_content_as_filename_route)
|
||||
.ruma_route(c2s::get_content_thumbnail_route);
|
||||
|
||||
let router = router
|
||||
.ruma_route(c2s::get_devices_route)
|
||||
.ruma_route(c2s::get_device_route)
|
||||
.ruma_route(c2s::update_device_route)
|
||||
.ruma_route(c2s::delete_device_route)
|
||||
.ruma_route(c2s::delete_devices_route)
|
||||
.ruma_route(c2s::get_tags_route)
|
||||
.ruma_route(c2s::update_tag_route)
|
||||
.ruma_route(c2s::delete_tag_route)
|
||||
.ruma_route(c2s::upload_signing_keys_route)
|
||||
.ruma_route(c2s::upload_signatures_route)
|
||||
.ruma_route(c2s::get_key_changes_route)
|
||||
.ruma_route(c2s::get_pushers_route)
|
||||
.ruma_route(c2s::set_pushers_route)
|
||||
.ruma_route(c2s::upgrade_room_route)
|
||||
.ruma_route(c2s::get_threads_route)
|
||||
.ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route)
|
||||
.ruma_route(c2s::get_relating_events_with_rel_type_route)
|
||||
.ruma_route(c2s::get_relating_events_route)
|
||||
.ruma_route(c2s::get_hierarchy_route);
|
||||
|
||||
// Ruma doesn't have support for multiple paths for a single endpoint yet,
|
||||
// and these routes share one Ruma request / response type pair with
|
||||
// {get,send}_state_event_for_key_route. These two endpoints also allow
|
||||
// trailing slashes.
|
||||
let router = router
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type/",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type/",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
);
|
||||
|
||||
let router = if config.observability.metrics.enable {
|
||||
router.route(
|
||||
"/metrics",
|
||||
get(|| async { observability::METRICS.export() }),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
let router = router
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
||||
get(initial_sync),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
||||
get(initial_sync),
|
||||
)
|
||||
.route("/", get(it_works))
|
||||
.fallback(not_found);
|
||||
|
||||
let router = router
|
||||
.route("/.well-known/matrix/client", get(well_known::client))
|
||||
.route("/.well-known/matrix/server", get(well_known::server));
|
||||
|
||||
if config.federation.enable {
|
||||
router
|
||||
.ruma_route(s2s::get_server_version_route)
|
||||
.route("/_matrix/key/v2/server", get(s2s::get_server_keys_route))
|
||||
.route(
|
||||
"/_matrix/key/v2/server/:key_id",
|
||||
get(s2s::get_server_keys_deprecated_route),
|
||||
)
|
||||
.ruma_route(s2s::get_public_rooms_route)
|
||||
.ruma_route(s2s::get_public_rooms_filtered_route)
|
||||
.ruma_route(s2s::send_transaction_message_route)
|
||||
.ruma_route(s2s::get_event_route)
|
||||
.ruma_route(s2s::get_backfill_route)
|
||||
.ruma_route(s2s::get_missing_events_route)
|
||||
.ruma_route(s2s::get_event_authorization_route)
|
||||
.ruma_route(s2s::get_room_state_route)
|
||||
.ruma_route(s2s::get_room_state_ids_route)
|
||||
.ruma_route(s2s::create_join_event_template_route)
|
||||
.ruma_route(s2s::create_join_event_v1_route)
|
||||
.ruma_route(s2s::create_join_event_v2_route)
|
||||
.ruma_route(s2s::create_invite_route)
|
||||
.ruma_route(s2s::get_devices_route)
|
||||
.ruma_route(s2s::get_room_information_route)
|
||||
.ruma_route(s2s::get_profile_information_route)
|
||||
.ruma_route(s2s::get_keys_route)
|
||||
.ruma_route(s2s::claim_keys_route)
|
||||
.ruma_route(s2s::media_download_route)
|
||||
.ruma_route(s2s::media_thumbnail_route)
|
||||
} else {
|
||||
router
|
||||
.route("/_matrix/federation/*path", any(federation_disabled))
|
||||
.route("/_matrix/key/*path", any(federation_disabled))
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown_signal(handles: Vec<ServerHandle>) {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
let sig: &str;
|
||||
|
||||
tokio::select! {
|
||||
() = ctrl_c => { sig = "Ctrl+C"; },
|
||||
() = terminate => { sig = "SIGTERM"; },
|
||||
}
|
||||
|
||||
warn!(signal = %sig, "Shutting down due to signal");
|
||||
|
||||
services().globals.shutdown();
|
||||
|
||||
for handle in handles {
|
||||
handle.graceful_shutdown(Some(Duration::from_secs(30)));
|
||||
}
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
|
||||
.expect("should be able to notify systemd");
|
||||
}
|
||||
|
||||
async fn federation_disabled(_: Uri) -> impl IntoResponse {
|
||||
Error::bad_config("Federation is disabled.")
|
||||
}
|
||||
|
||||
async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Unauthenticated media access is disabled",
|
||||
)
|
||||
}
|
||||
|
||||
async fn not_found(method: Method, uri: Uri) -> impl IntoResponse {
|
||||
debug!(%method, %uri, "Unknown route");
|
||||
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
|
||||
}
|
||||
|
||||
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(
|
||||
ErrorKind::GuestAccessForbidden,
|
||||
"Guest access not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
async fn it_works() -> &'static str {
|
||||
"Hello from Grapevine!"
|
||||
}
|
||||
|
||||
trait RouterExt {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static;
|
||||
}
|
||||
|
||||
impl RouterExt for Router {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static,
|
||||
{
|
||||
handler.add_to_router(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait RumaHandler<T> {
|
||||
// Can't transform to a handler without boxing or relying on the
|
||||
// nightly-only impl-trait-in-traits feature. Moving a small amount of
|
||||
// extra logic into the trait allows bypassing both.
|
||||
fn add_to_router(self, router: Router) -> Router;
|
||||
}
|
||||
|
||||
macro_rules! impl_ruma_handler {
|
||||
( $($ty:ident),* $(,)? ) => {
|
||||
#[axum::async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<Req, Resp, E, F, Fut, $($ty,)*>
|
||||
RumaHandler<($($ty,)* Ar<Req>,)> for F
|
||||
where
|
||||
Req: IncomingRequest + Send + 'static,
|
||||
Resp: IntoResponse,
|
||||
F: FnOnce($($ty,)* Ar<Req>) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = Result<Resp, E>>
|
||||
+ Send,
|
||||
E: IntoResponse,
|
||||
$( $ty: FromRequestParts<()> + Send + 'static, )*
|
||||
{
|
||||
fn add_to_router(self, mut router: Router) -> Router {
|
||||
let meta = Req::METADATA;
|
||||
let method_filter = method_to_filter(meta.method);
|
||||
|
||||
for path in meta.history.all_paths() {
|
||||
let handler = self.clone();
|
||||
|
||||
router = router.route(
|
||||
path,
|
||||
on(
|
||||
method_filter,
|
||||
|$( $ty: $ty, )* req: Ar<Req>| async move {
|
||||
let span = info_span!(
|
||||
"run_ruma_handler",
|
||||
auth.user = ?req.sender_user,
|
||||
auth.device = ?req.sender_device,
|
||||
auth.servername = ?req.sender_servername,
|
||||
auth.appservice_id = ?req.appservice_info
|
||||
.as_ref()
|
||||
.map(|i| &i.registration.id)
|
||||
);
|
||||
handler($($ty,)* req).instrument(span).await
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ruma_handler!();
|
||||
impl_ruma_handler!(T1);
|
||||
impl_ruma_handler!(T1, T2);
|
||||
impl_ruma_handler!(T1, T2, T3);
|
||||
impl_ruma_handler!(T1, T2, T3, T4);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
|
||||
|
||||
fn method_to_filter(method: Method) -> MethodFilter {
|
||||
match method {
|
||||
Method::DELETE => MethodFilter::DELETE,
|
||||
Method::GET => MethodFilter::GET,
|
||||
Method::HEAD => MethodFilter::HEAD,
|
||||
Method::OPTIONS => MethodFilter::OPTIONS,
|
||||
Method::PATCH => MethodFilter::PATCH,
|
||||
Method::POST => MethodFilter::POST,
|
||||
Method::PUT => MethodFilter::PUT,
|
||||
Method::TRACE => MethodFilter::TRACE,
|
||||
m => panic!("Unsupported HTTP method: {m:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tracing::instrument(err)]
|
||||
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
|
||||
use nix::sys::resource::{getrlimit, setrlimit, Resource};
|
||||
|
||||
let res = Resource::RLIMIT_NOFILE;
|
||||
|
||||
let (soft_limit, hard_limit) = getrlimit(res)?;
|
||||
|
||||
debug!(soft_limit, "Current nofile soft limit");
|
||||
|
||||
setrlimit(res, hard_limit, hard_limit)?;
|
||||
|
||||
debug!(hard_limit, "Increased nofile soft limit to the hard limit");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
664
src/main.rs
664
src/main.rs
|
|
@ -2,44 +2,10 @@
|
|||
// work anyway
|
||||
#![cfg_attr(not(any(feature = "sqlite", feature = "rocksdb")), allow(unused))]
|
||||
|
||||
use std::{
|
||||
future::Future,
|
||||
net::SocketAddr,
|
||||
process::ExitCode,
|
||||
sync::{atomic, RwLock},
|
||||
time::Duration,
|
||||
};
|
||||
use std::{process::ExitCode, sync::RwLock};
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
||||
response::IntoResponse,
|
||||
routing::{any, get, on, MethodFilter},
|
||||
Router,
|
||||
};
|
||||
use axum_server::{
|
||||
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
|
||||
};
|
||||
use clap::Parser;
|
||||
use futures_util::FutureExt;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use ruma::api::{
|
||||
client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
},
|
||||
IncomingRequest,
|
||||
};
|
||||
use tokio::{signal, task::JoinSet};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
cors::{self, CorsLayer},
|
||||
trace::TraceLayer,
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
use tracing::error;
|
||||
|
||||
mod api;
|
||||
mod cli;
|
||||
|
|
@ -51,10 +17,7 @@ mod service;
|
|||
mod utils;
|
||||
|
||||
pub(crate) use api::ruma_wrapper::{Ar, Ra};
|
||||
use api::{client_server, server_server, well_known};
|
||||
use cli::{Args, Command};
|
||||
pub(crate) use config::{Config, ListenConfig};
|
||||
pub(crate) use database::KeyValueDatabase;
|
||||
pub(crate) use config::Config;
|
||||
pub(crate) use service::{pdu::PduEvent, Services};
|
||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
|
|
@ -91,7 +54,8 @@ fn version() -> String {
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> ExitCode {
|
||||
let Err(e) = try_main().await else {
|
||||
let args = cli::Args::parse();
|
||||
let Err(e) = args.run().await else {
|
||||
return ExitCode::SUCCESS;
|
||||
};
|
||||
|
||||
|
|
@ -105,621 +69,3 @@ async fn main() -> ExitCode {
|
|||
|
||||
ExitCode::FAILURE
|
||||
}
|
||||
|
||||
/// Fallible entrypoint
|
||||
async fn try_main() -> Result<(), error::Main> {
|
||||
use error::Main as Error;
|
||||
|
||||
let args = Args::parse();
|
||||
// This is a placeholder, the logic specific to the 'serve' command will be
|
||||
// moved to another file in a later commit
|
||||
let Command::Serve(args) = args.command;
|
||||
|
||||
let config = config::load(args.config.config.as_ref()).await?;
|
||||
|
||||
let (_guard, reload_handles) = observability::init(&config)?;
|
||||
|
||||
// This is needed for opening lots of file descriptors, which tends to
|
||||
// happen more often when using RocksDB and making lots of federation
|
||||
// connections at startup. The soft limit is usually 1024, and the hard
|
||||
// limit is usually 512000; I've personally seen it hit >2000.
|
||||
//
|
||||
// * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6
|
||||
// * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741
|
||||
#[cfg(unix)]
|
||||
maximize_fd_limit()
|
||||
.expect("should be able to increase the soft limit to the hard limit");
|
||||
|
||||
info!("Loading database");
|
||||
KeyValueDatabase::load_or_create(config, reload_handles)
|
||||
.await
|
||||
.map_err(Error::DatabaseError)?;
|
||||
|
||||
info!("Starting server");
|
||||
run_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn run_server() -> Result<(), error::Serve> {
|
||||
use error::Serve as Error;
|
||||
|
||||
let config = &services().globals.config;
|
||||
|
||||
let x_requested_with = HeaderName::from_static("x-requested-with");
|
||||
|
||||
let middlewares = ServiceBuilder::new()
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.layer(axum::middleware::from_fn(spawn_task))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(|request: &http::Request<_>| {
|
||||
let endpoint = if let Some(endpoint) =
|
||||
request.extensions().get::<MatchedPath>()
|
||||
{
|
||||
endpoint.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
let method = request.method();
|
||||
|
||||
tracing::info_span!(
|
||||
"http_request",
|
||||
otel.name = format!("{method} {endpoint}"),
|
||||
%method,
|
||||
%endpoint,
|
||||
)
|
||||
})
|
||||
.on_request(
|
||||
|request: &http::Request<_>, _span: &tracing::Span| {
|
||||
// can be enabled selectively using `filter =
|
||||
// grapevine[incoming_request_curl]=trace` in config
|
||||
tracing::trace_span!("incoming_request_curl").in_scope(
|
||||
|| {
|
||||
tracing::trace!(
|
||||
cmd = utils::curlify(request),
|
||||
"curl command line for incoming request \
|
||||
(guessed hostname)"
|
||||
);
|
||||
},
|
||||
);
|
||||
},
|
||||
),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(unrecognized_method))
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods([
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
])
|
||||
.allow_headers([
|
||||
header::ORIGIN,
|
||||
x_requested_with,
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
])
|
||||
.max_age(Duration::from_secs(86400)),
|
||||
)
|
||||
.layer(DefaultBodyLimit::max(
|
||||
config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
))
|
||||
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
|
||||
|
||||
let app = routes(config).layer(middlewares).into_make_service();
|
||||
let mut handles = Vec::new();
|
||||
let mut servers = JoinSet::new();
|
||||
|
||||
let tls_config = if let Some(tls) = &config.tls {
|
||||
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|
||||
|err| Error::LoadCerts {
|
||||
certs: tls.certs.clone(),
|
||||
key: tls.key.clone(),
|
||||
err,
|
||||
},
|
||||
)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if config.listen.is_empty() {
|
||||
return Err(Error::NoListeners);
|
||||
}
|
||||
|
||||
for listen in &config.listen {
|
||||
info!(listener = %listen, "Listening for incoming traffic");
|
||||
match listen {
|
||||
ListenConfig::Tcp {
|
||||
address,
|
||||
port,
|
||||
tls,
|
||||
} => {
|
||||
let addr = SocketAddr::from((*address, *port));
|
||||
let handle = ServerHandle::new();
|
||||
handles.push(handle.clone());
|
||||
let server = if *tls {
|
||||
let tls_config = tls_config
|
||||
.clone()
|
||||
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
|
||||
bind_rustls(addr, tls_config)
|
||||
.handle(handle)
|
||||
.serve(app.clone())
|
||||
.left_future()
|
||||
} else {
|
||||
bind(addr).handle(handle).serve(app.clone()).right_future()
|
||||
};
|
||||
servers.spawn(
|
||||
server.then(|result| async { (listen.clone(), result) }),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
|
||||
.expect("should be able to notify systemd");
|
||||
|
||||
tokio::spawn(shutdown_signal(handles));
|
||||
|
||||
while let Some(result) = servers.join_next().await {
|
||||
let (listen, result) =
|
||||
result.expect("should be able to join server task");
|
||||
result.map_err(|err| Error::Listen(err, listen))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensures the request runs in a new tokio thread.
|
||||
///
|
||||
/// The axum request handler task gets cancelled if the connection is shut down;
|
||||
/// by spawning our own task, processing continue after the client disconnects.
|
||||
async fn spawn_task(
|
||||
req: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
tokio::spawn(next.run(req))
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
async fn unrecognized_method(
|
||||
req: axum::extract::Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let inner = next.run(req).await;
|
||||
if inner.status() == StatusCode::METHOD_NOT_ALLOWED {
|
||||
warn!(%method, %uri, "Method not allowed");
|
||||
return Ok(Ra(UiaaResponse::MatrixError(RumaError {
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Unrecognized request".to_owned(),
|
||||
},
|
||||
status_code: StatusCode::METHOD_NOT_ALLOWED,
|
||||
}))
|
||||
.into_response());
|
||||
}
|
||||
Ok(inner)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn routes(config: &Config) -> Router {
|
||||
use client_server as c2s;
|
||||
use server_server as s2s;
|
||||
|
||||
let router = Router::new()
|
||||
.ruma_route(c2s::get_supported_versions_route)
|
||||
.ruma_route(c2s::get_register_available_route)
|
||||
.ruma_route(c2s::register_route)
|
||||
.ruma_route(c2s::get_login_types_route)
|
||||
.ruma_route(c2s::login_route)
|
||||
.ruma_route(c2s::whoami_route)
|
||||
.ruma_route(c2s::logout_route)
|
||||
.ruma_route(c2s::logout_all_route)
|
||||
.ruma_route(c2s::change_password_route)
|
||||
.ruma_route(c2s::deactivate_route)
|
||||
.ruma_route(c2s::third_party_route)
|
||||
.ruma_route(c2s::request_3pid_management_token_via_email_route)
|
||||
.ruma_route(c2s::request_3pid_management_token_via_msisdn_route)
|
||||
.ruma_route(c2s::get_capabilities_route)
|
||||
.ruma_route(c2s::get_pushrules_all_route)
|
||||
.ruma_route(c2s::set_pushrule_route)
|
||||
.ruma_route(c2s::get_pushrule_route)
|
||||
.ruma_route(c2s::set_pushrule_enabled_route)
|
||||
.ruma_route(c2s::get_pushrule_enabled_route)
|
||||
.ruma_route(c2s::get_pushrule_actions_route)
|
||||
.ruma_route(c2s::set_pushrule_actions_route)
|
||||
.ruma_route(c2s::delete_pushrule_route)
|
||||
.ruma_route(c2s::get_room_event_route)
|
||||
.ruma_route(c2s::get_room_aliases_route)
|
||||
.ruma_route(c2s::get_filter_route)
|
||||
.ruma_route(c2s::create_filter_route)
|
||||
.ruma_route(c2s::set_global_account_data_route)
|
||||
.ruma_route(c2s::set_room_account_data_route)
|
||||
.ruma_route(c2s::get_global_account_data_route)
|
||||
.ruma_route(c2s::get_room_account_data_route)
|
||||
.ruma_route(c2s::set_displayname_route)
|
||||
.ruma_route(c2s::get_displayname_route)
|
||||
.ruma_route(c2s::set_avatar_url_route)
|
||||
.ruma_route(c2s::get_avatar_url_route)
|
||||
.ruma_route(c2s::get_profile_route)
|
||||
.ruma_route(c2s::upload_keys_route)
|
||||
.ruma_route(c2s::get_keys_route)
|
||||
.ruma_route(c2s::claim_keys_route)
|
||||
.ruma_route(c2s::create_backup_version_route)
|
||||
.ruma_route(c2s::update_backup_version_route)
|
||||
.ruma_route(c2s::delete_backup_version_route)
|
||||
.ruma_route(c2s::get_latest_backup_info_route)
|
||||
.ruma_route(c2s::get_backup_info_route)
|
||||
.ruma_route(c2s::add_backup_keys_route)
|
||||
.ruma_route(c2s::add_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::add_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::delete_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::delete_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::delete_backup_keys_route)
|
||||
.ruma_route(c2s::get_backup_keys_for_room_route)
|
||||
.ruma_route(c2s::get_backup_keys_for_session_route)
|
||||
.ruma_route(c2s::get_backup_keys_route)
|
||||
.ruma_route(c2s::set_read_marker_route)
|
||||
.ruma_route(c2s::create_receipt_route)
|
||||
.ruma_route(c2s::create_typing_event_route)
|
||||
.ruma_route(c2s::create_room_route)
|
||||
.ruma_route(c2s::redact_event_route)
|
||||
.ruma_route(c2s::report_event_route)
|
||||
.ruma_route(c2s::create_alias_route)
|
||||
.ruma_route(c2s::delete_alias_route)
|
||||
.ruma_route(c2s::get_alias_route)
|
||||
.ruma_route(c2s::join_room_by_id_route)
|
||||
.ruma_route(c2s::join_room_by_id_or_alias_route)
|
||||
.ruma_route(c2s::joined_members_route)
|
||||
.ruma_route(c2s::leave_room_route)
|
||||
.ruma_route(c2s::forget_room_route)
|
||||
.ruma_route(c2s::joined_rooms_route)
|
||||
.ruma_route(c2s::kick_user_route)
|
||||
.ruma_route(c2s::ban_user_route)
|
||||
.ruma_route(c2s::unban_user_route)
|
||||
.ruma_route(c2s::invite_user_route)
|
||||
.ruma_route(c2s::set_room_visibility_route)
|
||||
.ruma_route(c2s::get_room_visibility_route)
|
||||
.ruma_route(c2s::get_public_rooms_route)
|
||||
.ruma_route(c2s::get_public_rooms_filtered_route)
|
||||
.ruma_route(c2s::search_users_route)
|
||||
.ruma_route(c2s::get_member_events_route)
|
||||
.ruma_route(c2s::get_protocols_route)
|
||||
.ruma_route(c2s::send_message_event_route)
|
||||
.ruma_route(c2s::send_state_event_for_key_route)
|
||||
.ruma_route(c2s::get_state_events_route)
|
||||
.ruma_route(c2s::get_state_events_for_key_route)
|
||||
.ruma_route(c2s::sync_events_route)
|
||||
.ruma_route(c2s::sync_events_v4_route)
|
||||
.ruma_route(c2s::get_context_route)
|
||||
.ruma_route(c2s::get_message_events_route)
|
||||
.ruma_route(c2s::search_events_route)
|
||||
.ruma_route(c2s::turn_server_route)
|
||||
.ruma_route(c2s::send_event_to_device_route);
|
||||
|
||||
// deprecated, but unproblematic
|
||||
let router = router.ruma_route(c2s::get_media_config_legacy_route);
|
||||
let router = if config.serve_media_unauthenticated {
|
||||
router
|
||||
.ruma_route(c2s::get_content_legacy_route)
|
||||
.ruma_route(c2s::get_content_as_filename_legacy_route)
|
||||
.ruma_route(c2s::get_content_thumbnail_legacy_route)
|
||||
} else {
|
||||
router
|
||||
.route(
|
||||
"/_matrix/media/v3/download/*path",
|
||||
any(unauthenticated_media_disabled),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/media/v3/thumbnail/*path",
|
||||
any(unauthenticated_media_disabled),
|
||||
)
|
||||
};
|
||||
|
||||
// authenticated media
|
||||
let router = router
|
||||
.ruma_route(c2s::get_media_config_route)
|
||||
.ruma_route(c2s::create_content_route)
|
||||
.ruma_route(c2s::get_content_route)
|
||||
.ruma_route(c2s::get_content_as_filename_route)
|
||||
.ruma_route(c2s::get_content_thumbnail_route);
|
||||
|
||||
let router = router
|
||||
.ruma_route(c2s::get_devices_route)
|
||||
.ruma_route(c2s::get_device_route)
|
||||
.ruma_route(c2s::update_device_route)
|
||||
.ruma_route(c2s::delete_device_route)
|
||||
.ruma_route(c2s::delete_devices_route)
|
||||
.ruma_route(c2s::get_tags_route)
|
||||
.ruma_route(c2s::update_tag_route)
|
||||
.ruma_route(c2s::delete_tag_route)
|
||||
.ruma_route(c2s::upload_signing_keys_route)
|
||||
.ruma_route(c2s::upload_signatures_route)
|
||||
.ruma_route(c2s::get_key_changes_route)
|
||||
.ruma_route(c2s::get_pushers_route)
|
||||
.ruma_route(c2s::set_pushers_route)
|
||||
.ruma_route(c2s::upgrade_room_route)
|
||||
.ruma_route(c2s::get_threads_route)
|
||||
.ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route)
|
||||
.ruma_route(c2s::get_relating_events_with_rel_type_route)
|
||||
.ruma_route(c2s::get_relating_events_route)
|
||||
.ruma_route(c2s::get_hierarchy_route);
|
||||
|
||||
// Ruma doesn't have support for multiple paths for a single endpoint yet,
|
||||
// and these routes share one Ruma request / response type pair with
|
||||
// {get,send}_state_event_for_key_route. These two endpoints also allow
|
||||
// trailing slashes.
|
||||
let router = router
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/state/:event_type/",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/state/:event_type/",
|
||||
get(c2s::get_state_events_for_empty_key_route)
|
||||
.put(c2s::send_state_event_for_empty_key_route),
|
||||
);
|
||||
|
||||
let router = if config.observability.metrics.enable {
|
||||
router.route(
|
||||
"/metrics",
|
||||
get(|| async { observability::METRICS.export() }),
|
||||
)
|
||||
} else {
|
||||
router
|
||||
};
|
||||
|
||||
let router = router
|
||||
.route(
|
||||
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
||||
get(initial_sync),
|
||||
)
|
||||
.route(
|
||||
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
||||
get(initial_sync),
|
||||
)
|
||||
.route("/", get(it_works))
|
||||
.fallback(not_found);
|
||||
|
||||
let router = router
|
||||
.route("/.well-known/matrix/client", get(well_known::client))
|
||||
.route("/.well-known/matrix/server", get(well_known::server));
|
||||
|
||||
if config.federation.enable {
|
||||
router
|
||||
.ruma_route(s2s::get_server_version_route)
|
||||
.route("/_matrix/key/v2/server", get(s2s::get_server_keys_route))
|
||||
.route(
|
||||
"/_matrix/key/v2/server/:key_id",
|
||||
get(s2s::get_server_keys_deprecated_route),
|
||||
)
|
||||
.ruma_route(s2s::get_public_rooms_route)
|
||||
.ruma_route(s2s::get_public_rooms_filtered_route)
|
||||
.ruma_route(s2s::send_transaction_message_route)
|
||||
.ruma_route(s2s::get_event_route)
|
||||
.ruma_route(s2s::get_backfill_route)
|
||||
.ruma_route(s2s::get_missing_events_route)
|
||||
.ruma_route(s2s::get_event_authorization_route)
|
||||
.ruma_route(s2s::get_room_state_route)
|
||||
.ruma_route(s2s::get_room_state_ids_route)
|
||||
.ruma_route(s2s::create_join_event_template_route)
|
||||
.ruma_route(s2s::create_join_event_v1_route)
|
||||
.ruma_route(s2s::create_join_event_v2_route)
|
||||
.ruma_route(s2s::create_invite_route)
|
||||
.ruma_route(s2s::get_devices_route)
|
||||
.ruma_route(s2s::get_room_information_route)
|
||||
.ruma_route(s2s::get_profile_information_route)
|
||||
.ruma_route(s2s::get_keys_route)
|
||||
.ruma_route(s2s::claim_keys_route)
|
||||
.ruma_route(s2s::media_download_route)
|
||||
.ruma_route(s2s::media_thumbnail_route)
|
||||
} else {
|
||||
router
|
||||
.route("/_matrix/federation/*path", any(federation_disabled))
|
||||
.route("/_matrix/key/*path", any(federation_disabled))
|
||||
}
|
||||
}
|
||||
|
||||
async fn shutdown_signal(handles: Vec<ServerHandle>) {
|
||||
let ctrl_c = async {
|
||||
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
|
||||
};
|
||||
|
||||
#[cfg(unix)]
|
||||
let terminate = async {
|
||||
signal::unix::signal(signal::unix::SignalKind::terminate())
|
||||
.expect("failed to install signal handler")
|
||||
.recv()
|
||||
.await;
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let terminate = std::future::pending::<()>();
|
||||
|
||||
let sig: &str;
|
||||
|
||||
tokio::select! {
|
||||
() = ctrl_c => { sig = "Ctrl+C"; },
|
||||
() = terminate => { sig = "SIGTERM"; },
|
||||
}
|
||||
|
||||
warn!(signal = %sig, "Shutting down due to signal");
|
||||
|
||||
services().globals.shutdown();
|
||||
|
||||
for handle in handles {
|
||||
handle.graceful_shutdown(Some(Duration::from_secs(30)));
|
||||
}
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
|
||||
.expect("should be able to notify systemd");
|
||||
}
|
||||
|
||||
async fn federation_disabled(_: Uri) -> impl IntoResponse {
|
||||
Error::bad_config("Federation is disabled.")
|
||||
}
|
||||
|
||||
async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Unauthenticated media access is disabled",
|
||||
)
|
||||
}
|
||||
|
||||
async fn not_found(method: Method, uri: Uri) -> impl IntoResponse {
|
||||
debug!(%method, %uri, "Unknown route");
|
||||
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
|
||||
}
|
||||
|
||||
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
||||
Error::BadRequest(
|
||||
ErrorKind::GuestAccessForbidden,
|
||||
"Guest access not implemented",
|
||||
)
|
||||
}
|
||||
|
||||
async fn it_works() -> &'static str {
|
||||
"Hello from Grapevine!"
|
||||
}
|
||||
|
||||
trait RouterExt {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static;
|
||||
}
|
||||
|
||||
impl RouterExt for Router {
|
||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||
where
|
||||
H: RumaHandler<T>,
|
||||
T: 'static,
|
||||
{
|
||||
handler.add_to_router(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait RumaHandler<T> {
|
||||
// Can't transform to a handler without boxing or relying on the
|
||||
// nightly-only impl-trait-in-traits feature. Moving a small amount of
|
||||
// extra logic into the trait allows bypassing both.
|
||||
fn add_to_router(self, router: Router) -> Router;
|
||||
}
|
||||
|
||||
macro_rules! impl_ruma_handler {
|
||||
( $($ty:ident),* $(,)? ) => {
|
||||
#[axum::async_trait]
|
||||
#[allow(non_snake_case)]
|
||||
impl<Req, Resp, E, F, Fut, $($ty,)*>
|
||||
RumaHandler<($($ty,)* Ar<Req>,)> for F
|
||||
where
|
||||
Req: IncomingRequest + Send + 'static,
|
||||
Resp: IntoResponse,
|
||||
F: FnOnce($($ty,)* Ar<Req>) -> Fut + Clone + Send + 'static,
|
||||
Fut: Future<Output = Result<Resp, E>>
|
||||
+ Send,
|
||||
E: IntoResponse,
|
||||
$( $ty: FromRequestParts<()> + Send + 'static, )*
|
||||
{
|
||||
fn add_to_router(self, mut router: Router) -> Router {
|
||||
let meta = Req::METADATA;
|
||||
let method_filter = method_to_filter(meta.method);
|
||||
|
||||
for path in meta.history.all_paths() {
|
||||
let handler = self.clone();
|
||||
|
||||
router = router.route(
|
||||
path,
|
||||
on(
|
||||
method_filter,
|
||||
|$( $ty: $ty, )* req: Ar<Req>| async move {
|
||||
let span = info_span!(
|
||||
"run_ruma_handler",
|
||||
auth.user = ?req.sender_user,
|
||||
auth.device = ?req.sender_device,
|
||||
auth.servername = ?req.sender_servername,
|
||||
auth.appservice_id = ?req.appservice_info
|
||||
.as_ref()
|
||||
.map(|i| &i.registration.id)
|
||||
);
|
||||
handler($($ty,)* req).instrument(span).await
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
router
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_ruma_handler!();
|
||||
impl_ruma_handler!(T1);
|
||||
impl_ruma_handler!(T1, T2);
|
||||
impl_ruma_handler!(T1, T2, T3);
|
||||
impl_ruma_handler!(T1, T2, T3, T4);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
|
||||
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
|
||||
|
||||
fn method_to_filter(method: Method) -> MethodFilter {
|
||||
match method {
|
||||
Method::DELETE => MethodFilter::DELETE,
|
||||
Method::GET => MethodFilter::GET,
|
||||
Method::HEAD => MethodFilter::HEAD,
|
||||
Method::OPTIONS => MethodFilter::OPTIONS,
|
||||
Method::PATCH => MethodFilter::PATCH,
|
||||
Method::POST => MethodFilter::POST,
|
||||
Method::PUT => MethodFilter::PUT,
|
||||
Method::TRACE => MethodFilter::TRACE,
|
||||
m => panic!("Unsupported HTTP method: {m:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[tracing::instrument(err)]
|
||||
fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
|
||||
use nix::sys::resource::{getrlimit, setrlimit, Resource};
|
||||
|
||||
let res = Resource::RLIMIT_NOFILE;
|
||||
|
||||
let (soft_limit, hard_limit) = getrlimit(res)?;
|
||||
|
||||
debug!(soft_limit, "Current nofile soft limit");
|
||||
|
||||
setrlimit(res, hard_limit, hard_limit)?;
|
||||
|
||||
debug!(hard_limit, "Increased nofile soft limit to the hard limit");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue