mirror of
https://gitlab.computer.surgery/matrix/grapevine.git
synced 2025-12-16 15:21:24 +01:00
1000 lines
34 KiB
Rust
1000 lines
34 KiB
Rust
use std::{
|
|
collections::HashSet, convert::Infallible, future::Future,
|
|
net::SocketAddr as IpSocketAddr,
|
|
os::unix::net::SocketAddr as UnixSocketAddr, sync::atomic, time::Duration,
|
|
};
|
|
|
|
use axum::{
|
|
extract::{
|
|
connect_info::{Connected, IntoMakeServiceWithConnectInfo},
|
|
ConnectInfo, DefaultBodyLimit, FromRequestParts, MatchedPath,
|
|
},
|
|
middleware::AddExtension,
|
|
response::IntoResponse,
|
|
routing::{any, get, on, MethodFilter, Route},
|
|
Router,
|
|
};
|
|
use axum_server::{
|
|
accept::Accept,
|
|
bind,
|
|
service::SendService,
|
|
tls_rustls::{RustlsAcceptor, RustlsConfig},
|
|
Address, Server,
|
|
};
|
|
use http::{
|
|
header::{self, HeaderName},
|
|
Method, StatusCode, Uri,
|
|
};
|
|
use hyper::body::Incoming;
|
|
use proxy_header::ProxyHeader;
|
|
use ruma::api::{
|
|
client::{
|
|
error::{Error as RumaError, ErrorBody, ErrorKind},
|
|
uiaa::UiaaResponse,
|
|
},
|
|
federation::discovery::get_server_version,
|
|
IncomingRequest,
|
|
};
|
|
use strum::Display;
|
|
use tokio::{
|
|
io::{AsyncRead, AsyncWrite},
|
|
signal,
|
|
task::JoinSet,
|
|
};
|
|
use tower::{Layer, Service, ServiceBuilder};
|
|
use tower_http::{
|
|
cors::{self, CorsLayer},
|
|
trace::TraceLayer,
|
|
ServiceBuilderExt as _,
|
|
};
|
|
use tracing::{debug, info, info_span, warn, Instrument};
|
|
|
|
use super::ServeArgs;
|
|
use crate::{
|
|
api::{
|
|
client_server,
|
|
ruma_wrapper::{Ar, Ra},
|
|
server_server::{self, AllowLoopbackRequests, LogRequestError},
|
|
well_known,
|
|
},
|
|
config::{self, Config, ListenComponent, ListenConfig, ListenTransport},
|
|
database::KeyValueDatabase,
|
|
error, observability, services, set_application_state,
|
|
utils::{
|
|
self,
|
|
error::{Error, Result},
|
|
proxy_protocol::{ProxyAcceptor, ProxyAcceptorConfig},
|
|
},
|
|
ApplicationState, Services,
|
|
};
|
|
|
|
pub(crate) async fn run(args: ServeArgs) -> Result<(), error::ServeCommand> {
|
|
use error::ServeCommand as Error;
|
|
|
|
let config = config::load(args.config.config.as_ref(), false).await?;
|
|
|
|
rustls::crypto::ring::default_provider()
|
|
.install_default()
|
|
.expect("rustls default crypto provider should not be already set");
|
|
|
|
let (_guard, reload_handles) = observability::init(&config)?;
|
|
|
|
info!("Loading database");
|
|
let db = Box::leak(Box::new(
|
|
KeyValueDatabase::load_or_create(&config)
|
|
.map_err(Error::DatabaseError)?,
|
|
));
|
|
|
|
// This struct will remove old Unix sockets once it's dropped.
|
|
let _clean_up_socks = CleanUpUnixSockets(config.listen.clone());
|
|
|
|
Services::new(db, config, Some(reload_handles))
|
|
.map_err(Error::InitializeServices)?
|
|
.install();
|
|
|
|
services().globals.err_if_server_name_changed()?;
|
|
|
|
db.apply_migrations().await.map_err(Error::DatabaseError)?;
|
|
|
|
info!("Starting background tasks");
|
|
services().admin.start_handler();
|
|
services().sending.start_handler();
|
|
KeyValueDatabase::start_cleanup_task();
|
|
services().globals.set_emergency_access();
|
|
|
|
info!("Starting server");
|
|
run_server().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
struct CleanUpUnixSockets(Vec<ListenConfig>);
|
|
|
|
impl Drop for CleanUpUnixSockets {
|
|
fn drop(&mut self) {
|
|
// Remove old Unix sockets
|
|
for listen in &self.0 {
|
|
if let ListenTransport::Unix {
|
|
path,
|
|
..
|
|
} = &listen.transport
|
|
{
|
|
info!(
|
|
path = path.display().to_string(),
|
|
"Removing Unix socket"
|
|
);
|
|
if let Err(error) = std::fs::remove_file(path) {
|
|
warn!(%error, "Couldn't remove Unix socket");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument]
|
|
async fn federation_self_test() -> Result<()> {
|
|
let response = server_server::send_request(
|
|
&services().globals.config.server_name,
|
|
get_server_version::v1::Request {},
|
|
LogRequestError::Yes,
|
|
AllowLoopbackRequests::Yes,
|
|
)
|
|
.await?;
|
|
|
|
if response
|
|
.server
|
|
.as_ref()
|
|
.is_none_or(|s| s.name.as_deref() != Some(env!("CARGO_PKG_NAME")))
|
|
{
|
|
error!(?response, "unexpected server version");
|
|
return Err(Error::BadConfig(
|
|
"Got unexpected version from our own version endpoint",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// A trait we'll implement on `axum_server::Handle` in order to be able to
|
|
// shutdown handles regardless of their generics.
|
|
trait ServerHandle: Send {
|
|
fn shutdown(&self, timeout: Option<Duration>);
|
|
}
|
|
|
|
impl<A: Address + Send> ServerHandle for axum_server::Handle<A> {
|
|
fn shutdown(&self, timeout: Option<Duration>) {
|
|
self.graceful_shutdown(timeout);
|
|
}
|
|
}
|
|
|
|
/// This type is needed to allow us to find out where incoming connections came
|
|
/// from. Before Unix socket support, we could simply use `IpSocketAddr` here,
|
|
/// but this is no longer possible.
|
|
#[derive(Clone, Display)]
|
|
enum AddrConnectInfo {
|
|
#[strum(to_string = "{0}")]
|
|
Ip(IpSocketAddr),
|
|
|
|
#[strum(to_string = "[unix socket]")]
|
|
UnixSocket,
|
|
|
|
#[strum(to_string = "[unknown]")]
|
|
Unknown,
|
|
}
|
|
|
|
impl Connected<IpSocketAddr> for AddrConnectInfo {
|
|
fn connect_info(target: IpSocketAddr) -> Self {
|
|
Self::Ip(target)
|
|
}
|
|
}
|
|
|
|
impl Connected<UnixSocketAddr> for AddrConnectInfo {
|
|
fn connect_info(_target: UnixSocketAddr) -> Self {
|
|
// The `UnixSocketAddr` we get here is one that we can't recover the
|
|
// path from (`as_pathname` returns `None`), so there's no point
|
|
// in saving it (we only use all this for logging).
|
|
Self::UnixSocket
|
|
}
|
|
}
|
|
|
|
struct ServerSpawner<'cfg, M> {
|
|
config: &'cfg Config,
|
|
middlewares: M,
|
|
|
|
tls_config: Option<RustlsConfig>,
|
|
proxy_config: ProxyAcceptorConfig,
|
|
servers: JoinSet<(ListenConfig, std::io::Result<()>)>,
|
|
handles: Vec<Box<dyn ServerHandle>>,
|
|
}
|
|
|
|
impl<'cfg, M> ServerSpawner<'cfg, M>
|
|
where
|
|
M: Layer<Route> + Clone + Send + 'static,
|
|
M::Service: Service<
|
|
axum::extract::Request,
|
|
Response = axum::response::Response,
|
|
Error = Infallible,
|
|
> + Clone
|
|
+ Send
|
|
+ 'static,
|
|
<M::Service as Service<axum::extract::Request>>::Future: Send + 'static,
|
|
{
|
|
async fn new(
|
|
config: &'cfg Config,
|
|
middlewares: M,
|
|
) -> Result<Self, error::Serve> {
|
|
let tls_config = if let Some(tls) = &config.tls {
|
|
Some(
|
|
RustlsConfig::from_pem_file(&tls.certs, &tls.key)
|
|
.await
|
|
.map_err(|err| error::Serve::LoadCerts {
|
|
certs: tls.certs.clone(),
|
|
key: tls.key.clone(),
|
|
err,
|
|
})?,
|
|
)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let proxy_config = ProxyAcceptorConfig::default();
|
|
|
|
Ok(Self {
|
|
config,
|
|
middlewares,
|
|
tls_config,
|
|
proxy_config,
|
|
servers: JoinSet::new(),
|
|
handles: Vec::new(),
|
|
})
|
|
}
|
|
|
|
/// Returns a function that transforms a lower-layer acceptor into a TLS
|
|
/// acceptor.
|
|
fn tls_acceptor_factory<A>(
|
|
&self,
|
|
listen: &ListenConfig,
|
|
) -> Result<impl FnOnce(A) -> RustlsAcceptor<A>, error::Serve> {
|
|
let config = self
|
|
.tls_config
|
|
.clone()
|
|
.ok_or_else(|| error::Serve::NoTlsCerts(listen.clone()))?;
|
|
|
|
Ok(|inner| RustlsAcceptor::new(config).acceptor(inner))
|
|
}
|
|
|
|
/// Returns a function that transforms a lower-layer acceptor into a Proxy
|
|
/// Protocol acceptor.
|
|
fn proxy_acceptor_factory<A>(&self) -> impl FnOnce(A) -> ProxyAcceptor<A> {
|
|
let config = self.proxy_config.clone();
|
|
|
|
|inner| ProxyAcceptor::new(inner, config)
|
|
}
|
|
|
|
fn spawn_server_inner<Addr, A>(
|
|
&mut self,
|
|
listen: ListenConfig,
|
|
server: Server<Addr, A>,
|
|
app: IntoMakeServiceWithConnectInfo<Router, AddrConnectInfo>,
|
|
) where
|
|
AddrConnectInfo: Connected<Addr>,
|
|
Addr: Address + Send + 'static,
|
|
Addr::Stream: Send,
|
|
Addr::Listener: Send,
|
|
A: Accept<
|
|
Addr::Stream,
|
|
AddExtension<Router, ConnectInfo<AddrConnectInfo>>,
|
|
> + Clone
|
|
+ Send
|
|
+ Sync
|
|
+ 'static,
|
|
A::Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync,
|
|
A::Service: SendService<http::Request<Incoming>> + Send,
|
|
A::Future: Send,
|
|
{
|
|
let handle = axum_server::Handle::new();
|
|
let server = server.handle(handle.clone()).serve(app);
|
|
self.servers.spawn(async move {
|
|
let result = server.await;
|
|
|
|
(listen, result)
|
|
});
|
|
self.handles.push(Box::new(handle));
|
|
}
|
|
|
|
fn spawn_server(
|
|
&mut self,
|
|
listen: ListenConfig,
|
|
) -> Result<(), error::Serve> {
|
|
let app = routes(self.config, &listen.components)
|
|
.layer(self.middlewares.clone())
|
|
.into_make_service_with_connect_info::<AddrConnectInfo>();
|
|
|
|
match &listen.transport {
|
|
ListenTransport::Tcp {
|
|
address,
|
|
port,
|
|
tls,
|
|
proxy_protocol,
|
|
} => {
|
|
let addr = IpSocketAddr::from((*address, *port));
|
|
let server = bind(addr);
|
|
|
|
match (tls, proxy_protocol) {
|
|
(false, false) => {
|
|
self.spawn_server_inner(listen, server, app);
|
|
}
|
|
(false, true) => {
|
|
let server = server.map(self.proxy_acceptor_factory());
|
|
self.spawn_server_inner(listen, server, app);
|
|
}
|
|
(true, false) => {
|
|
let server =
|
|
server.map(self.tls_acceptor_factory(&listen)?);
|
|
self.spawn_server_inner(listen, server, app);
|
|
}
|
|
(true, true) => {
|
|
let server = server
|
|
.map(self.proxy_acceptor_factory())
|
|
.map(self.tls_acceptor_factory(&listen)?);
|
|
self.spawn_server_inner(listen, server, app);
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
ListenTransport::Unix {
|
|
path,
|
|
proxy_protocol,
|
|
} => {
|
|
let addr = match UnixSocketAddr::from_pathname(path) {
|
|
Ok(addr) => addr,
|
|
Err(e) => {
|
|
// We can't use `map_err` here, as that would move
|
|
// `listen` into a closure, preventing us from using it
|
|
// later.
|
|
return Err(error::Serve::Listen(e, listen));
|
|
}
|
|
};
|
|
let server = bind(addr);
|
|
|
|
if *proxy_protocol {
|
|
let server = server.map(self.proxy_acceptor_factory());
|
|
self.spawn_server_inner(listen, server, app);
|
|
} else {
|
|
self.spawn_server_inner(listen, server, app);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
async fn run_server() -> Result<(), error::Serve> {
|
|
use error::Serve as Error;
|
|
|
|
let config = &services().globals.config;
|
|
|
|
let x_requested_with = HeaderName::from_static("x-requested-with");
|
|
|
|
let middlewares = ServiceBuilder::new()
|
|
.sensitive_headers([header::AUTHORIZATION])
|
|
.layer(axum::middleware::from_fn(spawn_task))
|
|
.layer(
|
|
TraceLayer::new_for_http()
|
|
.make_span_with(|request: &http::Request<_>| {
|
|
let endpoint = if let Some(endpoint) =
|
|
request.extensions().get::<MatchedPath>()
|
|
{
|
|
endpoint.as_str()
|
|
} else {
|
|
request.uri().path()
|
|
};
|
|
|
|
let method = request.method();
|
|
|
|
let source_address = request
|
|
.extensions()
|
|
.get::<ProxyHeader<'_>>()
|
|
.map_or_else(
|
|
|| {
|
|
request
|
|
.extensions()
|
|
.get::<ConnectInfo<AddrConnectInfo>>()
|
|
.map(|ConnectInfo(addr)| addr.clone())
|
|
},
|
|
|h| {
|
|
h.proxied_address().map(|addr| {
|
|
AddrConnectInfo::Ip(addr.source)
|
|
})
|
|
},
|
|
)
|
|
.unwrap_or(AddrConnectInfo::Unknown);
|
|
|
|
tracing::info_span!(
|
|
"http_request",
|
|
otel.name = format!("{method} {endpoint}"),
|
|
%method,
|
|
%endpoint,
|
|
%source_address,
|
|
)
|
|
})
|
|
.on_request(
|
|
|request: &http::Request<_>, _span: &tracing::Span| {
|
|
// can be enabled selectively using `filter =
|
|
// grapevine[incoming_request_curl]=trace` in config
|
|
tracing::trace_span!("incoming_request_curl").in_scope(
|
|
|| {
|
|
tracing::trace!(
|
|
cmd = utils::curlify(request),
|
|
"curl command line for incoming request \
|
|
(guessed hostname)"
|
|
);
|
|
},
|
|
);
|
|
},
|
|
),
|
|
)
|
|
.layer(axum::middleware::from_fn(unrecognized_method))
|
|
.layer(
|
|
CorsLayer::new()
|
|
.allow_origin(cors::Any)
|
|
.allow_methods([
|
|
Method::GET,
|
|
Method::POST,
|
|
Method::PUT,
|
|
Method::DELETE,
|
|
Method::OPTIONS,
|
|
])
|
|
.allow_headers([
|
|
header::ORIGIN,
|
|
x_requested_with,
|
|
header::CONTENT_TYPE,
|
|
header::ACCEPT,
|
|
header::AUTHORIZATION,
|
|
])
|
|
.max_age(Duration::from_secs(86400)),
|
|
)
|
|
.layer(DefaultBodyLimit::max(
|
|
config
|
|
.max_request_size
|
|
.try_into()
|
|
.expect("failed to convert max request size"),
|
|
))
|
|
.layer(axum::middleware::from_fn(observability::http_metrics_layer))
|
|
.layer(axum::middleware::from_fn(observability::traceresponse_layer));
|
|
|
|
let mut spawner = ServerSpawner::new(config, middlewares).await?;
|
|
|
|
if config.listen.is_empty() {
|
|
return Err(Error::NoListeners);
|
|
}
|
|
|
|
for listen in &config.listen {
|
|
info!(listener = %listen, "Listening for incoming traffic");
|
|
spawner.spawn_server(listen.clone())?;
|
|
}
|
|
|
|
tokio::spawn(handle_signals(spawner.tls_config, spawner.handles));
|
|
|
|
if config.federation.enable && config.federation.self_test {
|
|
federation_self_test()
|
|
.await
|
|
.map_err(error::Serve::FederationSelfTestFailed)?;
|
|
debug!("Federation self-test completed successfully");
|
|
}
|
|
|
|
set_application_state(ApplicationState::Ready);
|
|
|
|
while let Some(result) = spawner.servers.join_next().await {
|
|
let (listen, result) =
|
|
result.expect("should be able to join server task");
|
|
result.map_err(|err| Error::Listen(err, listen))?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Ensures the request runs in a new tokio thread.
|
|
///
|
|
/// The axum request handler task gets cancelled if the connection is shut down;
|
|
/// by spawning our own task, processing continue after the client disconnects.
|
|
async fn spawn_task(
|
|
req: axum::extract::Request,
|
|
next: axum::middleware::Next,
|
|
) -> std::result::Result<axum::response::Response, StatusCode> {
|
|
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
|
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
|
}
|
|
tokio::spawn(next.run(req))
|
|
.await
|
|
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
|
}
|
|
|
|
async fn unrecognized_method(
|
|
req: axum::extract::Request,
|
|
next: axum::middleware::Next,
|
|
) -> std::result::Result<axum::response::Response, StatusCode> {
|
|
let method = req.method().clone();
|
|
let uri = req.uri().clone();
|
|
let inner = next.run(req).await;
|
|
if inner.status() == StatusCode::METHOD_NOT_ALLOWED {
|
|
warn!(%method, %uri, "Method not allowed");
|
|
return Ok(Ra(UiaaResponse::MatrixError(RumaError {
|
|
body: ErrorBody::Standard {
|
|
kind: ErrorKind::Unrecognized,
|
|
message: "M_UNRECOGNIZED: Unrecognized request".to_owned(),
|
|
},
|
|
status_code: StatusCode::METHOD_NOT_ALLOWED,
|
|
}))
|
|
.into_response());
|
|
}
|
|
Ok(inner)
|
|
}
|
|
|
|
/// Routes for legacy unauthenticated `/_matrix/media/*` APIs (used by both
|
|
/// clients and federation)
|
|
fn legacy_media_routes(config: &Config) -> Router {
|
|
use client_server as c2s;
|
|
|
|
let router = Router::new();
|
|
|
|
// deprecated, but unproblematic
|
|
let router = router.ruma_route(c2s::get_media_config_legacy_route);
|
|
|
|
if config.media.allow_unauthenticated_access {
|
|
router
|
|
.ruma_route(c2s::get_content_legacy_route)
|
|
.ruma_route(c2s::get_content_as_filename_legacy_route)
|
|
.ruma_route(c2s::get_content_thumbnail_legacy_route)
|
|
} else {
|
|
router
|
|
.route(
|
|
"/_matrix/media/v3/download/*path",
|
|
any(unauthenticated_media_disabled),
|
|
)
|
|
.route(
|
|
"/_matrix/media/v3/thumbnail/*path",
|
|
any(unauthenticated_media_disabled),
|
|
)
|
|
}
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
fn client_routes() -> Router {
|
|
use client_server as c2s;
|
|
|
|
let router = Router::new()
|
|
.ruma_route(c2s::get_supported_versions_route)
|
|
.ruma_route(c2s::get_register_available_route)
|
|
.ruma_route(c2s::register_route)
|
|
.ruma_route(c2s::get_login_types_route)
|
|
.ruma_route(c2s::login_route)
|
|
.ruma_route(c2s::whoami_route)
|
|
.ruma_route(c2s::logout_route)
|
|
.ruma_route(c2s::logout_all_route)
|
|
.ruma_route(c2s::change_password_route)
|
|
.ruma_route(c2s::deactivate_route)
|
|
.ruma_route(c2s::third_party_route)
|
|
.ruma_route(c2s::request_3pid_management_token_via_email_route)
|
|
.ruma_route(c2s::request_3pid_management_token_via_msisdn_route)
|
|
.ruma_route(c2s::get_capabilities_route)
|
|
.ruma_route(c2s::get_pushrules_all_route)
|
|
.ruma_route(c2s::set_pushrule_route)
|
|
.ruma_route(c2s::get_pushrule_route)
|
|
.ruma_route(c2s::set_pushrule_enabled_route)
|
|
.ruma_route(c2s::get_pushrule_enabled_route)
|
|
.ruma_route(c2s::get_pushrule_actions_route)
|
|
.ruma_route(c2s::set_pushrule_actions_route)
|
|
.ruma_route(c2s::delete_pushrule_route)
|
|
.ruma_route(c2s::get_room_event_route)
|
|
.ruma_route(c2s::get_room_aliases_route)
|
|
.ruma_route(c2s::get_filter_route)
|
|
.ruma_route(c2s::create_filter_route)
|
|
.ruma_route(c2s::set_global_account_data_route)
|
|
.ruma_route(c2s::set_room_account_data_route)
|
|
.ruma_route(c2s::get_global_account_data_route)
|
|
.ruma_route(c2s::get_room_account_data_route)
|
|
.ruma_route(c2s::set_displayname_route)
|
|
.ruma_route(c2s::get_displayname_route)
|
|
.ruma_route(c2s::set_avatar_url_route)
|
|
.ruma_route(c2s::get_avatar_url_route)
|
|
.ruma_route(c2s::get_profile_route)
|
|
.ruma_route(c2s::upload_keys_route)
|
|
.ruma_route(c2s::get_keys_route)
|
|
.ruma_route(c2s::claim_keys_route)
|
|
.ruma_route(c2s::create_backup_version_route)
|
|
.ruma_route(c2s::update_backup_version_route)
|
|
.ruma_route(c2s::delete_backup_version_route)
|
|
.ruma_route(c2s::get_latest_backup_info_route)
|
|
.ruma_route(c2s::get_backup_info_route)
|
|
.ruma_route(c2s::add_backup_keys_route)
|
|
.ruma_route(c2s::add_backup_keys_for_room_route)
|
|
.ruma_route(c2s::add_backup_keys_for_session_route)
|
|
.ruma_route(c2s::delete_backup_keys_for_room_route)
|
|
.ruma_route(c2s::delete_backup_keys_for_session_route)
|
|
.ruma_route(c2s::delete_backup_keys_route)
|
|
.ruma_route(c2s::get_backup_keys_for_room_route)
|
|
.ruma_route(c2s::get_backup_keys_for_session_route)
|
|
.ruma_route(c2s::get_backup_keys_route)
|
|
.ruma_route(c2s::set_read_marker_route)
|
|
.ruma_route(c2s::create_receipt_route)
|
|
.ruma_route(c2s::create_typing_event_route)
|
|
.ruma_route(c2s::create_room_route)
|
|
.ruma_route(c2s::redact_event_route)
|
|
.ruma_route(c2s::report_event_route)
|
|
.ruma_route(c2s::create_alias_route)
|
|
.ruma_route(c2s::delete_alias_route)
|
|
.ruma_route(c2s::get_alias_route)
|
|
.ruma_route(c2s::join_room_by_id_route)
|
|
.ruma_route(c2s::join_room_by_id_or_alias_route)
|
|
.ruma_route(c2s::joined_members_route)
|
|
.ruma_route(c2s::leave_room_route)
|
|
.ruma_route(c2s::forget_room_route)
|
|
.ruma_route(c2s::joined_rooms_route)
|
|
.ruma_route(c2s::kick_user_route)
|
|
.ruma_route(c2s::ban_user_route)
|
|
.ruma_route(c2s::unban_user_route)
|
|
.ruma_route(c2s::invite_user_route)
|
|
.ruma_route(c2s::set_room_visibility_route)
|
|
.ruma_route(c2s::get_room_visibility_route)
|
|
.ruma_route(c2s::get_public_rooms_route)
|
|
.ruma_route(c2s::get_public_rooms_filtered_route)
|
|
.ruma_route(c2s::search_users_route)
|
|
.ruma_route(c2s::get_member_events_route)
|
|
.ruma_route(c2s::get_protocols_route)
|
|
.ruma_route(c2s::send_message_event_route)
|
|
.ruma_route(c2s::send_state_event_for_key_route)
|
|
.ruma_route(c2s::get_state_events_route)
|
|
.ruma_route(c2s::get_state_events_for_key_route)
|
|
.ruma_route(c2s::v3::sync_events_route)
|
|
.ruma_route(c2s::msc3575::sync_events_v4_route)
|
|
.ruma_route(c2s::get_context_route)
|
|
.ruma_route(c2s::get_message_events_route)
|
|
.ruma_route(c2s::search_events_route)
|
|
.ruma_route(c2s::turn_server_route)
|
|
.ruma_route(c2s::send_event_to_device_route);
|
|
|
|
// authenticated media
|
|
let router = router
|
|
.ruma_route(c2s::get_media_config_route)
|
|
.ruma_route(c2s::create_content_route)
|
|
.ruma_route(c2s::get_content_route)
|
|
.ruma_route(c2s::get_content_as_filename_route)
|
|
.ruma_route(c2s::get_content_thumbnail_route);
|
|
|
|
let router = router
|
|
.ruma_route(c2s::get_devices_route)
|
|
.ruma_route(c2s::get_device_route)
|
|
.ruma_route(c2s::update_device_route)
|
|
.ruma_route(c2s::delete_device_route)
|
|
.ruma_route(c2s::delete_devices_route)
|
|
.ruma_route(c2s::get_tags_route)
|
|
.ruma_route(c2s::update_tag_route)
|
|
.ruma_route(c2s::delete_tag_route)
|
|
.ruma_route(c2s::upload_signing_keys_route)
|
|
.ruma_route(c2s::upload_signatures_route)
|
|
.ruma_route(c2s::get_key_changes_route)
|
|
.ruma_route(c2s::get_pushers_route)
|
|
.ruma_route(c2s::set_pushers_route)
|
|
.ruma_route(c2s::upgrade_room_route)
|
|
.ruma_route(c2s::get_threads_route)
|
|
.ruma_route(c2s::get_relating_events_with_rel_type_and_event_type_route)
|
|
.ruma_route(c2s::get_relating_events_with_rel_type_route)
|
|
.ruma_route(c2s::get_relating_events_route)
|
|
.ruma_route(c2s::get_hierarchy_route);
|
|
|
|
// Ruma doesn't have support for multiple paths for a single endpoint yet,
|
|
// and these routes share one Ruma request / response type pair with
|
|
// {get,send}_state_event_for_key_route. These two endpoints also allow
|
|
// trailing slashes.
|
|
let router = router
|
|
.route(
|
|
"/_matrix/client/r0/rooms/:room_id/state/:event_type",
|
|
get(c2s::get_state_events_for_empty_key_route)
|
|
.put(c2s::send_state_event_for_empty_key_route),
|
|
)
|
|
.route(
|
|
"/_matrix/client/v3/rooms/:room_id/state/:event_type",
|
|
get(c2s::get_state_events_for_empty_key_route)
|
|
.put(c2s::send_state_event_for_empty_key_route),
|
|
)
|
|
.route(
|
|
"/_matrix/client/r0/rooms/:room_id/state/:event_type/",
|
|
get(c2s::get_state_events_for_empty_key_route)
|
|
.put(c2s::send_state_event_for_empty_key_route),
|
|
)
|
|
.route(
|
|
"/_matrix/client/v3/rooms/:room_id/state/:event_type/",
|
|
get(c2s::get_state_events_for_empty_key_route)
|
|
.put(c2s::send_state_event_for_empty_key_route),
|
|
);
|
|
|
|
router
|
|
.route(
|
|
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
|
get(initial_sync),
|
|
)
|
|
.route(
|
|
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
|
get(initial_sync),
|
|
)
|
|
}
|
|
|
|
fn federation_routes(config: &Config) -> Router {
|
|
use server_server as s2s;
|
|
|
|
if config.federation.enable {
|
|
Router::new()
|
|
.ruma_route(s2s::get_server_version_route)
|
|
.route("/_matrix/key/v2/server", get(s2s::get_server_keys_route))
|
|
.route(
|
|
"/_matrix/key/v2/server/:key_id",
|
|
get(s2s::get_server_keys_deprecated_route),
|
|
)
|
|
.ruma_route(s2s::get_public_rooms_route)
|
|
.ruma_route(s2s::get_public_rooms_filtered_route)
|
|
.ruma_route(s2s::send_transaction_message_route)
|
|
.ruma_route(s2s::get_event_route)
|
|
.ruma_route(s2s::get_backfill_route)
|
|
.ruma_route(s2s::get_missing_events_route)
|
|
.ruma_route(s2s::get_event_authorization_route)
|
|
.ruma_route(s2s::get_room_state_route)
|
|
.ruma_route(s2s::get_room_state_ids_route)
|
|
.ruma_route(s2s::create_join_event_template_route)
|
|
.ruma_route(s2s::create_join_event_v1_route)
|
|
.ruma_route(s2s::create_join_event_v2_route)
|
|
.ruma_route(s2s::create_invite_route)
|
|
.ruma_route(s2s::get_devices_route)
|
|
.ruma_route(s2s::get_room_information_route)
|
|
.ruma_route(s2s::get_profile_information_route)
|
|
.ruma_route(s2s::get_keys_route)
|
|
.ruma_route(s2s::claim_keys_route)
|
|
.ruma_route(s2s::media_download_route)
|
|
.ruma_route(s2s::media_thumbnail_route)
|
|
} else {
|
|
Router::new()
|
|
.route("/_matrix/federation/*path", any(federation_disabled))
|
|
.route("/_matrix/key/*path", any(federation_disabled))
|
|
}
|
|
}
|
|
|
|
fn metrics_routes(config: &Config) -> Router {
|
|
if config.observability.metrics.enable {
|
|
Router::new().route(
|
|
"/metrics",
|
|
get(|| async { observability::METRICS.export() }),
|
|
)
|
|
} else {
|
|
Router::new()
|
|
}
|
|
}
|
|
|
|
fn well_known_routes() -> Router {
|
|
Router::new()
|
|
.route("/.well-known/matrix/client", get(well_known::client))
|
|
.route("/.well-known/matrix/server", get(well_known::server))
|
|
}
|
|
|
|
fn routes(config: &Config, components: &HashSet<ListenComponent>) -> Router {
|
|
let mut router = Router::new();
|
|
for &component in components {
|
|
router = router.merge(match component {
|
|
ListenComponent::Client => client_routes(),
|
|
ListenComponent::Federation => federation_routes(config),
|
|
ListenComponent::Metrics => metrics_routes(config),
|
|
ListenComponent::WellKnown => well_known_routes(),
|
|
});
|
|
}
|
|
|
|
if components.contains(&ListenComponent::Client)
|
|
|| components.contains(&ListenComponent::Federation)
|
|
{
|
|
router = router.merge(legacy_media_routes(config));
|
|
}
|
|
|
|
router.route("/", get(it_works)).fallback(not_found)
|
|
}
|
|
|
|
async fn reload_tls_config(
|
|
tls_config: &RustlsConfig,
|
|
) -> Result<(), error::Serve> {
|
|
let config = services()
|
|
.globals
|
|
.config
|
|
.tls
|
|
.as_ref()
|
|
.expect("TLS config should exist if TLS listener exists");
|
|
|
|
tls_config.reload_from_pem_file(&config.certs, &config.key).await.map_err(
|
|
|err| error::Serve::LoadCerts {
|
|
certs: config.certs.clone(),
|
|
key: config.key.clone(),
|
|
err,
|
|
},
|
|
)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_signals(
|
|
tls_config: Option<RustlsConfig>,
|
|
handles: Vec<Box<dyn ServerHandle>>,
|
|
) {
|
|
#[cfg(unix)]
|
|
async fn wait_signal(sig: signal::unix::SignalKind) {
|
|
signal::unix::signal(sig)
|
|
.expect("failed to install signal handler")
|
|
.recv()
|
|
.await;
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
let terminate = || wait_signal(signal::unix::SignalKind::terminate());
|
|
#[cfg(not(unix))]
|
|
let terminate = || std::future::pending::<()>();
|
|
|
|
#[cfg(unix)]
|
|
let sighup = || wait_signal(signal::unix::SignalKind::hangup());
|
|
#[cfg(not(unix))]
|
|
let sighup = || std::future::pending::<()>();
|
|
|
|
let ctrl_c = || async {
|
|
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
|
|
};
|
|
|
|
let sig = loop {
|
|
tokio::select! {
|
|
() = sighup() => {
|
|
info!("Received reload request");
|
|
|
|
set_application_state(ApplicationState::Reloading);
|
|
|
|
if let Some(tls_config) = tls_config.as_ref() {
|
|
if let Err(error) = reload_tls_config(tls_config).await {
|
|
error!(?error, "Failed to reload TLS config");
|
|
}
|
|
}
|
|
|
|
set_application_state(ApplicationState::Ready);
|
|
},
|
|
() = terminate() => { break "SIGTERM"; },
|
|
() = ctrl_c() => { break "Ctrl+C"; },
|
|
}
|
|
};
|
|
|
|
warn!(signal = %sig, "Shutting down due to signal");
|
|
|
|
services().globals.shutdown();
|
|
|
|
for handle in handles {
|
|
handle.shutdown(Some(Duration::from_secs(30)));
|
|
}
|
|
|
|
set_application_state(ApplicationState::Stopping);
|
|
}
|
|
|
|
async fn federation_disabled(_: Uri) -> impl IntoResponse {
|
|
Error::bad_config("Federation is disabled.")
|
|
}
|
|
|
|
async fn unauthenticated_media_disabled(_: Uri) -> impl IntoResponse {
|
|
Error::BadRequest(
|
|
ErrorKind::NotFound,
|
|
"Unauthenticated media access is disabled",
|
|
)
|
|
}
|
|
|
|
async fn not_found(method: Method, uri: Uri) -> impl IntoResponse {
|
|
debug!(%method, %uri, "Unknown route");
|
|
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
|
|
}
|
|
|
|
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
|
Error::BadRequest(
|
|
ErrorKind::GuestAccessForbidden,
|
|
"Guest access not implemented",
|
|
)
|
|
}
|
|
|
|
/// Handler for `GET /`: returns a fixed plain-text banner.
async fn it_works() -> &'static str {
    const BANNER: &str = "Hello from Grapevine!";
    BANNER
}
|
|
|
|
trait RouterExt {
|
|
fn ruma_route<H, T>(self, handler: H) -> Self
|
|
where
|
|
H: RumaHandler<T>,
|
|
T: 'static;
|
|
}
|
|
|
|
impl RouterExt for Router {
|
|
fn ruma_route<H, T>(self, handler: H) -> Self
|
|
where
|
|
H: RumaHandler<T>,
|
|
T: 'static,
|
|
{
|
|
handler.add_to_router(self)
|
|
}
|
|
}
|
|
|
|
pub(crate) trait RumaHandler<T> {
|
|
// Can't transform to a handler without boxing or relying on the
|
|
// nightly-only impl-trait-in-traits feature. Moving a small amount of
|
|
// extra logic into the trait allows bypassing both.
|
|
fn add_to_router(self, router: Router) -> Router;
|
|
}
|
|
|
|
macro_rules! impl_ruma_handler {
|
|
( $($ty:ident),* $(,)? ) => {
|
|
#[axum::async_trait]
|
|
#[allow(non_snake_case)]
|
|
impl<Req, Resp, E, F, Fut, $($ty,)*>
|
|
RumaHandler<($($ty,)* Ar<Req>,)> for F
|
|
where
|
|
Req: IncomingRequest + Send + 'static,
|
|
Resp: IntoResponse,
|
|
F: FnOnce($($ty,)* Ar<Req>) -> Fut + Clone + Send + 'static,
|
|
Fut: Future<Output = Result<Resp, E>>
|
|
+ Send,
|
|
E: IntoResponse,
|
|
$( $ty: FromRequestParts<()> + Send + 'static, )*
|
|
{
|
|
fn add_to_router(self, mut router: Router) -> Router {
|
|
let meta = Req::METADATA;
|
|
let method_filter = method_to_filter(meta.method);
|
|
|
|
for path in meta.history.all_paths() {
|
|
let handler = self.clone();
|
|
|
|
router = router.route(
|
|
path,
|
|
on(
|
|
method_filter,
|
|
|$( $ty: $ty, )* req: Ar<Req>| async move {
|
|
let span = info_span!(
|
|
"run_ruma_handler",
|
|
auth.user = ?req.sender_user,
|
|
auth.device = ?req.sender_device,
|
|
auth.servername = ?req.sender_servername,
|
|
auth.appservice_id = ?req.appservice_info
|
|
.as_ref()
|
|
.map(|i| &i.registration.id)
|
|
);
|
|
handler($($ty,)* req).instrument(span).await
|
|
}
|
|
)
|
|
)
|
|
}
|
|
|
|
router
|
|
}
|
|
}
|
|
};
|
|
}
|
|
|
|
impl_ruma_handler!();
|
|
impl_ruma_handler!(T1);
|
|
impl_ruma_handler!(T1, T2);
|
|
impl_ruma_handler!(T1, T2, T3);
|
|
impl_ruma_handler!(T1, T2, T3, T4);
|
|
impl_ruma_handler!(T1, T2, T3, T4, T5);
|
|
impl_ruma_handler!(T1, T2, T3, T4, T5, T6);
|
|
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7);
|
|
impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8);
|
|
|
|
fn method_to_filter(method: Method) -> MethodFilter {
|
|
match method {
|
|
Method::DELETE => MethodFilter::DELETE,
|
|
Method::GET => MethodFilter::GET,
|
|
Method::HEAD => MethodFilter::HEAD,
|
|
Method::OPTIONS => MethodFilter::OPTIONS,
|
|
Method::PATCH => MethodFilter::PATCH,
|
|
Method::POST => MethodFilter::POST,
|
|
Method::PUT => MethodFilter::PUT,
|
|
Method::TRACE => MethodFilter::TRACE,
|
|
m => panic!("Unsupported HTTP method: {m:?}"),
|
|
}
|
|
}
|