Use Ruma XMatrix type instead of rolling our own

Both the hand-rolled parser and serialization were wrong in countless
ways. The current Ruma parser is much better, and the Ruma serialization
will be fixed by https://github.com/ruma/ruma/pull/1830.
This commit is contained in:
Lambda 2024-06-03 16:17:43 +00:00
parent aec314ce85
commit 62dd097f49
4 changed files with 45 additions and 88 deletions

18
Cargo.lock generated
View file

@ -2129,6 +2129,7 @@ dependencies = [
"ruma-federation-api", "ruma-federation-api",
"ruma-identity-service-api", "ruma-identity-service-api",
"ruma-push-gateway-api", "ruma-push-gateway-api",
"ruma-server-util",
"ruma-signatures", "ruma-signatures",
"ruma-state-res", "ruma-state-res",
"web-time", "web-time",
@ -2279,6 +2280,17 @@ dependencies = [
"serde_json", "serde_json",
] ]
[[package]]
name = "ruma-server-util"
version = "0.3.0"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"headers",
"ruma-common",
"tracing",
"yap",
]
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.15.0" version = "0.15.0"
@ -3676,6 +3688,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049"
[[package]]
name = "yap"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe269e7b803a5e8e20cbd97860e136529cd83bf2c9c6d37b142467e7e1f051f"
[[package]] [[package]]
name = "zerocopy" name = "zerocopy"
version = "0.7.34" version = "0.7.34"

View file

@ -117,7 +117,7 @@ regex = "1.10.4"
reqwest = { version = "0.12.4", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] } reqwest = { version = "0.12.4", default-features = false, features = ["http2", "rustls-tls-native-roots", "socks"] }
ring = "0.17.8" ring = "0.17.8"
rocksdb = { package = "rust-rocksdb", version = "0.26.0", features = ["lz4", "multi-threaded-cf", "zstd"], optional = true } rocksdb = { package = "rust-rocksdb", version = "0.26.0", features = ["lz4", "multi-threaded-cf", "zstd"], optional = true }
ruma = { git = "https://github.com/ruma/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] } ruma = { git = "https://github.com/ruma/ruma", branch = "main", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "server-util", "state-res", "unstable-msc2448", "unstable-msc3575", "unstable-exhaustive-types", "ring-compat", "unstable-unspecified" ] }
rusqlite = { version = "0.31.0", optional = true, features = ["bundled"] } rusqlite = { version = "0.31.0", optional = true, features = ["bundled"] }
rust-argon2 = "2.1.0" rust-argon2 = "2.1.0"
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }

View file

@ -8,10 +8,7 @@ use axum::{
RequestExt, RequestPartsExt, RequestExt, RequestPartsExt,
}; };
use axum_extra::{ use axum_extra::{
headers::{ headers::{authorization::Bearer, Authorization},
authorization::{Bearer, Credentials},
Authorization,
},
typed_header::TypedHeaderRejectionReason, typed_header::TypedHeaderRejectionReason,
TypedHeader, TypedHeader,
}; };
@ -23,6 +20,7 @@ use ruma::{
client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata, client::error::ErrorKind, AuthScheme, IncomingRequest, Metadata,
OutgoingResponse, OutgoingResponse,
}, },
server_util::authorization::XMatrix,
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
@ -199,7 +197,7 @@ async fn ar_from_request_inner(
})?; })?;
let origin_signatures = BTreeMap::from_iter([( let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.clone(), x_matrix.key.to_string(),
CanonicalJsonValue::String(x_matrix.sig), CanonicalJsonValue::String(x_matrix.sig),
)]); )]);
@ -248,7 +246,7 @@ async fn ar_from_request_inner(
.event_handler .event_handler
.fetch_signing_keys( .fetch_signing_keys(
&x_matrix.origin, &x_matrix.origin,
vec![x_matrix.key.clone()], vec![x_matrix.key.to_string()],
) )
.await; .await;
@ -400,67 +398,6 @@ where
}) })
} }
} }
struct XMatrix {
origin: OwnedServerName,
// KeyName?
key: String,
sig: String,
}
impl Credentials for XMatrix {
const SCHEME: &'static str = "X-Matrix";
fn decode(value: &http::HeaderValue) -> Option<Self> {
debug_assert!(
value.as_bytes().starts_with(b"X-Matrix "),
"HeaderValue to decode should start with \"X-Matrix ..\", \
received = {value:?}",
);
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
.ok()?
.trim_start();
let mut origin = None;
let mut key = None;
let mut sig = None;
for entry in parameters.split_terminator(',') {
let (name, value) = entry.split_once('=')?;
// It's not at all clear why some fields are quoted and others not
// in the spec, let's simply accept either form for
// every field.
let value = value
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))
.unwrap_or(value);
// FIXME: Catch multiple fields of the same name
match name {
"origin" => origin = Some(value.try_into().ok()?),
"key" => key = Some(value.to_owned()),
"sig" => sig = Some(value.to_owned()),
_ => debug!(
"Unexpected field `{}` in X-Matrix Authorization header",
name
),
}
}
Some(Self {
origin: origin?,
key: key?,
sig: sig?,
})
}
fn encode(&self) -> http::HeaderValue {
todo!()
}
}
impl<T: OutgoingResponse> IntoResponse for Ra<T> { impl<T: OutgoingResponse> IntoResponse for Ra<T> {
fn into_response(self) -> Response { fn into_response(self) -> Response {
match self.0.try_into_http_response::<BytesMut>() { match self.0.try_into_http_response::<BytesMut>() {

View file

@ -10,8 +10,8 @@ use std::{
}; };
use axum::{response::IntoResponse, Json}; use axum::{response::IntoResponse, Json};
use axum_extra::headers::{Authorization, HeaderMapExt};
use get_profile_information::v1::ProfileField; use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION};
use ruma::{ use ruma::{
api::{ api::{
client::error::{Error as RumaError, ErrorKind}, client::error::{Error as RumaError, ErrorKind},
@ -54,10 +54,12 @@ use ruma::{
StateEventType, TimelineEventType, StateEventType, TimelineEventType,
}, },
serde::{Base64, JsonObject, Raw}, serde::{Base64, JsonObject, Raw},
server_util::authorization::XMatrix,
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId,
MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedServerName, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedServerName,
OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, OwnedServerSigningKeyId, OwnedSigningKeyId, OwnedUserId, RoomId,
ServerName,
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -225,25 +227,25 @@ where
serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()) serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap())
.unwrap(); .unwrap();
let signatures = // There's exactly the one signature we just created, fish it back out again
request_json["signatures"].as_object().unwrap().values().map(|v| { let (key_id, signature) = request_json["signatures"]
v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap())) .get(services().globals.server_name().as_str())
}); .unwrap()
.as_object()
.unwrap()
.iter()
.next()
.unwrap();
for signature_server in signatures { let key_id = OwnedSigningKeyId::try_from(key_id.clone()).unwrap();
for s in signature_server { let signature = signature.as_str().unwrap().to_owned();
http_request.headers_mut().insert(
AUTHORIZATION, http_request.headers_mut().typed_insert(Authorization(XMatrix::new(
HeaderValue::from_str(&format!( services().globals.server_name().to_owned(),
"X-Matrix origin={},key=\"{}\",sig=\"{}\"", None,
services().globals.server_name(), key_id,
s.0, signature,
s.1 )));
))
.unwrap(),
);
}
}
let reqwest_request = reqwest::Request::try_from(http_request)?; let reqwest_request = reqwest::Request::try_from(http_request)?;