From 62dd097f49a0c6d597391194724bdd1e058da487 Mon Sep 17 00:00:00 2001 From: Lambda Date: Mon, 3 Jun 2024 16:17:43 +0000 Subject: [PATCH] Use Ruma XMatrix type instead of rolling our own Both the hand-rolled parser and serialization were wrong in countless ways. The current Ruma parser is much better, and the Ruma serialization will be fixed by https://github.com/ruma/ruma/pull/1830. --- Cargo.lock | 18 +++++++++ Cargo.toml | 2 +- src/api/ruma_wrapper/axum.rs | 71 ++---------------------------------- src/api/server_server.rs | 42 +++++++++++---------- 4 files changed, 45 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1f913bef..1f3762d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2129,6 +2129,7 @@ dependencies = [ "ruma-federation-api", "ruma-identity-service-api", "ruma-push-gateway-api", + "ruma-server-util", "ruma-signatures", "ruma-state-res", "web-time", @@ -2279,6 +2280,17 @@ dependencies = [ "serde_json", ] +[[package]] +name = "ruma-server-util" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012" +dependencies = [ + "headers", + "ruma-common", + "tracing", + "yap", +] + [[package]] name = "ruma-signatures" version = "0.15.0" @@ -3676,6 +3688,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yap" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe269e7b803a5e8e20cbd97860e136529cd83bf2c9c6d37b142467e7e1f051f" + [[package]] name = "zerocopy" version = "0.7.34" diff --git a/Cargo.toml b/Cargo.toml index ba259e43..7b9a6872 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -117,7 +117,7 @@ regex = "1.10.4" reqwest = { version = "0.12.4", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } ring = "0.17.8" rocksdb = { package = "rust-rocksdb", version = "0.26.0", features = ["lz4", "multi-threaded-cf", "zstd"], optional = true } -ruma = { git = "https://github.com/ruma/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } +ruma = { git = "https://github.com/ruma/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "server-util", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } rusqlite = { version = "0.31.0", optional = true, features = ["bundled"] } rust-argon2 = "2.1.0" sd-notify = { version = "0.4.1", optional = true } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 4abef7fb..451b6fda 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -8,10 +8,7 @@ use axum::{ RequestExt, RequestPartsExt, }; use axum_extra::{ - headers::{ - authorization::{Bearer, Credentials}, - Authorization, - }, + headers::{authorization::Bearer, Authorization}, typed_header::TypedHeaderRejectionReason, TypedHeader, }; @@ -23,6 +20,7 @@ use ruma::{ client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata, OutgoingResponse, }, + server_util::authorization::XMatrix, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; use serde::Deserialize; @@ -199,7 +197,7 @@ async fn ar_from_request_inner( })?; let origin_signatures = BTreeMap::from_iter([( - x_matrix.key.clone(), + x_matrix.key.to_string(), CanonicalJsonValue::String(x_matrix.sig), )]); @@ -248,7 +246,7 @@ async fn ar_from_request_inner( .event_handler .fetch_signing_keys( &x_matrix.origin, - vec![x_matrix.key.clone()], + vec![x_matrix.key.to_string()], ) .await; @@ -400,67 +398,6 @@ where }) } } - -struct XMatrix { - origin: OwnedServerName, - // KeyName? - key: String, - sig: String, -} - -impl Credentials for XMatrix { - const SCHEME: &'static str = "X-Matrix"; - - fn decode(value: &http::HeaderValue) -> Option { - debug_assert!( - value.as_bytes().starts_with(b"X-Matrix "), - "HeaderValue to decode should start with \"X-Matrix ..\", \ - received = {value:?}", - ); - - let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) - .ok()? - .trim_start(); - - let mut origin = None; - let mut key = None; - let mut sig = None; - - for entry in parameters.split_terminator(',') { - let (name, value) = entry.split_once('=')?; - - // It's not at all clear why some fields are quoted and others not - // in the spec, let's simply accept either form for - // every field. - let value = value - .strip_prefix('"') - .and_then(|rest| rest.strip_suffix('"')) - .unwrap_or(value); - - // FIXME: Catch multiple fields of the same name - match name { - "origin" => origin = Some(value.try_into().ok()?), - "key" => key = Some(value.to_owned()), - "sig" => sig = Some(value.to_owned()), - _ => debug!( - "Unexpected field `{}` in X-Matrix Authorization header", - name - ), - } - } - - Some(Self { - origin: origin?, - key: key?, - sig: sig?, - }) - } - - fn encode(&self) -> http::HeaderValue { - todo!() - } -} - impl IntoResponse for Ra { fn into_response(self) -> Response { match self.0.try_into_http_response::() { diff --git a/src/api/server_server.rs b/src/api/server_server.rs index e69a52fc..b936b1e7 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -10,8 +10,8 @@ use std::{ }; use axum::{response::IntoResponse, Json}; +use axum_extra::headers::{Authorization, HeaderMapExt}; use get_profile_information::v1::ProfileField; -use http::header::{HeaderValue, AUTHORIZATION}; use ruma::{ api::{ client::error::{Error as RumaError, ErrorKind}, @@ -54,10 +54,12 @@ use ruma::{ StateEventType, TimelineEventType, }, serde::{Base64, JsonObject, Raw}, + server_util::authorization::XMatrix, to_device::DeviceIdOrAllDevices, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedServerName, - OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, + OwnedServerSigningKeyId, OwnedSigningKeyId, OwnedUserId, RoomId, + ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::RwLock; @@ -225,25 +227,25 @@ where serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()) .unwrap(); - let signatures = - request_json["signatures"].as_object().unwrap().values().map(|v| { - v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap())) - }); + // There's exactly the one signature we just created, fish it back out again + let (key_id, signature) = request_json["signatures"] + .get(services().globals.server_name().as_str()) + .unwrap() + .as_object() + .unwrap() + .iter() + .next() + .unwrap(); - for signature_server in signatures { - for s in signature_server { - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - services().globals.server_name(), - s.0, - s.1 - )) - .unwrap(), - ); - } - } + let key_id = OwnedSigningKeyId::try_from(key_id.clone()).unwrap(); + let signature = signature.as_str().unwrap().to_owned(); + + http_request.headers_mut().typed_insert(Authorization(XMatrix::new( + services().globals.server_name().to_owned(), + None, + key_id, + signature, + ))); let reqwest_request = reqwest::Request::try_from(http_request)?;