Skip to content

Commit 3b001c2

Browse files
committed
WIP: Hyper 1.0 migration - RPC server body type fixes
- Added TonicServiceWrapper to convert Incoming to BoxBody - Updated run_tls_server and run_plain_server signatures - Still working on body type trait mismatches - h2c temporarily disabled
1 parent 6ab7088 commit 3b001c2

3 files changed

Lines changed: 104 additions & 117 deletions

File tree

libsql-server/src/h2c.rs

Lines changed: 63 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,4 @@
11
//! Module that provides `h2c` server adapters.
2-
//!
3-
//! # What is `h2c`?
4-
//!
5-
//! `h2c` is a http1.1 upgrade token that allows us to accept http2 without
6-
//! going through tls/alpn while also accepting regular http1.1 requests. Since,
7-
//! our server does not do TLS there is no way to negotiate that an incoming
8-
//! connection is going to speak http2 or http1.1 so we must default to http1.1.
9-
//!
10-
//! # How does it work?
11-
//!
12-
//! The `H2c` service gets called on every http request that arrives to the
13-
//! server and checks if the request has an `upgrade` header set. If this
14-
//! header is set to `h2c` then it will start the upgrade process. If this
15-
//! header is not set the request continues normally without any upgrades.
16-
//!
17-
//! The upgrade process is quite simple, if the correct header value is set
18-
//! the server will spawn a background task, return status code `101`
19-
//! (switching protocols) and will set the same upgrade header with `h2c` as
20-
//! the value.
21-
//!
22-
//! The background task will wait for `hyper::upgrade::on` to complete. At this
23-
//! point when `on` completes it returns an `IO` object that we can read/write from.
24-
//! We then pass this into hyper's low level server connection type and force http2.
25-
//! This means from the point that the client gets back the upgrade headers and correct
26-
//! status code the connection will be immediealty speaking http2 and thus the upgrade
27-
//! is complete.
28-
//!
29-
//! ┌───────────────┐ upgrade:h2c ┌──────────────────┐
30-
//! │ http::request ├────────────────────────►│ upgrade to http2 │
31-
//! └─────┬─────────┘ └────────┬─────────┘
32-
//! │ │
33-
//! │ │
34-
//! │ │
35-
//! │ │
36-
//! │ │
37-
//! │ ┌─────────────────┐ │
38-
//! └────────────►│call axum router │◄───────────┘
39-
//! └─────────────────┘
402
413
use std::marker::PhantomData;
424
use std::pin::Pin;
@@ -45,44 +7,49 @@ use axum::body::Body;
457
use bytes::Bytes;
468
use http::header;
479
use http::{Request, Response};
10+
use http_body_util::BodyExt;
4811
use hyper_util::rt::{TokioExecutor, TokioIo};
4912
use hyper::server::conn::http2::Builder as Http2Builder;
5013
use tonic::transport::server::TcpConnectInfo;
5114
use tower::Service;
5215

16+
type BoxBody = http_body_util::combinators::BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>>;
5317
type BoxError = Box<dyn std::error::Error + Send + Sync>;
5418

5519
/// A `MakeService` adapter for [`H2c`] that injects connection
5620
/// info into the request extensions.
57-
#[derive(Debug, Clone)]
58-
pub struct H2cMaker<S, B> {
21+
#[derive(Debug)]
22+
pub struct H2cMaker<S> {
5923
s: S,
60-
_pd: PhantomData<fn(B)>,
6124
}
6225

63-
impl<S, B> H2cMaker<S, B> {
64-
pub fn new(s: S) -> Self {
26+
impl<S> Clone for H2cMaker<S>
27+
where
28+
S: Clone,
29+
{
30+
fn clone(&self) -> Self {
6531
Self {
66-
s,
67-
_pd: PhantomData,
32+
s: self.s.clone(),
6833
}
6934
}
7035
}
7136

72-
impl<S, C, B> Service<&C> for H2cMaker<S, B>
37+
impl<S> H2cMaker<S> {
38+
pub fn new(s: S) -> Self {
39+
Self { s }
40+
}
41+
}
42+
43+
impl<S, C> Service<&C> for H2cMaker<S>
7344
where
74-
S: Service<Request<Body>, Response = Response<B>> + Clone + Send + 'static,
45+
S: Service<Request<Body>> + Clone + Send + 'static,
7546
S::Future: Send + 'static,
7647
S::Error: Into<BoxError> + Sync + Send + 'static,
7748
S::Response: Send + 'static,
7849
C: crate::net::Conn,
79-
B: http_body::Body<Data = Bytes> + Send + 'static,
80-
B::Error: Into<BoxError> + Sync + Send + 'static,
8150
{
82-
type Response = H2c<S, B>;
83-
51+
type Response = H2c<S>;
8452
type Error = BoxError;
85-
8653
type Future =
8754
Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
8855

@@ -100,32 +67,41 @@ where
10067
Ok(H2c {
10168
s,
10269
connect_info,
103-
_pd: PhantomData,
10470
})
10571
})
10672
}
10773
}
10874

109-
/// A service that can perform `h2c` upgrades and will
110-
/// delegate calls to the inner service once a protocol
111-
/// has been selected.
112-
#[derive(Debug, Clone)]
113-
pub struct H2c<S, B> {
75+
/// A service that can perform `h2c` upgrades.
76+
#[derive(Debug)]
77+
pub struct H2c<S> {
11478
s: S,
11579
connect_info: TcpConnectInfo,
116-
_pd: PhantomData<fn(B)>,
11780
}
11881

119-
impl<S, B> Service<Request<Body>> for H2c<S, B>
82+
impl<S> Clone for H2c<S>
83+
where
84+
S: Clone,
85+
{
86+
fn clone(&self) -> Self {
87+
Self {
88+
s: self.s.clone(),
89+
connect_info: self.connect_info.clone(),
90+
}
91+
}
92+
}
93+
94+
// Service implementation for hyper 1.0's Incoming body type
95+
impl<S, B> Service<Request<hyper::body::Incoming>> for H2c<S>
12096
where
12197
S: Service<Request<Body>, Response = Response<B>> + Clone + Send + 'static,
12298
S::Future: Send + 'static,
12399
S::Error: Into<BoxError> + Sync + Send + 'static,
124100
S::Response: Send + 'static,
125101
B: http_body::Body<Data = Bytes> + Send + 'static,
126-
B::Error: Into<BoxError> + Sync + Send + 'static,
102+
B::Error: Into<BoxError> + Send + Sync + 'static,
127103
{
128-
type Response = Response<Body>;
104+
type Response = Response<BoxBody>;
129105
type Error = BoxError;
130106
type Future =
131107
Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
@@ -137,27 +113,29 @@ where
137113
std::task::Poll::Ready(Ok(()))
138114
}
139115

140-
fn call(&mut self, mut req: Request<Body>) -> Self::Future {
116+
fn call(&mut self, mut req: Request<hyper::body::Incoming>) -> Self::Future {
141117
let mut svc = self.s.clone();
142118
let connect_info = self.connect_info.clone();
143119

144120
Box::pin(async move {
145121
req.extensions_mut().insert(connect_info.clone());
146122

147-
// Check if this request is a `h2c` upgrade, if it is not pass
148-
// the request to the inner service, which in our case is the
149-
// axum router.
123+
// Check if this request is a `h2c` upgrade
150124
if req.headers().get(header::UPGRADE) != Some(&http::HeaderValue::from_static("h2c")) {
151-
return svc
152-
.call(req)
153-
.await
154-
.map_err(Into::into);
125+
// Convert Incoming body to axum Body
126+
let (parts, incoming) = req.into_parts();
127+
let body = Body::from_stream(incoming);
128+
let req = Request::from_parts(parts, body);
129+
130+
let res = svc.call(req).await.map_err(Into::into)?;
131+
// Box the body to erase type
132+
let (parts, body) = res.into_parts();
133+
return Ok(Response::from_parts(parts, body.boxed()));
155134
}
156135

157136
tracing::debug!("Got a h2c upgrade request");
158137

159-
// We got a h2c header so lets spawn a task that will wait for the
160-
// upgrade to complete and start a http2 connection.
138+
// Spawn the upgrade handling
161139
tokio::spawn(async move {
162140
let upgraded_io = match hyper::upgrade::on(&mut req).await {
163141
Ok(io) => TokioIo::new(io),
@@ -172,23 +150,20 @@ where
172150
let executor = TokioExecutor::new();
173151
let conn = Http2Builder::new(executor);
174152

175-
// Create a service that handles incoming HTTP/2 requests
176-
let svc = hyper::service::service_fn(move |mut r: Request<hyper::body::Incoming>| {
177-
r.extensions_mut().insert(connect_info.clone());
178-
// Convert the axum service response
153+
// Create a service for HTTP/2
154+
let svc = hyper::service::service_fn(move |r: Request<hyper::body::Incoming>| {
179155
let svc_clone = svc.clone();
156+
let connect_info = connect_info.clone();
180157
async move {
181-
// Convert Request<Incoming> to Request<Body> for axum
182-
let (parts, body) = r.into_parts();
183-
let body = Body::from_stream(body);
184-
let req = Request::from_parts(parts, body);
158+
// Convert Request<Incoming> to Request<Body>
159+
let (parts, incoming) = r.into_parts();
160+
let mut req = Request::from_parts(parts, Body::from_stream(incoming));
161+
req.extensions_mut().insert(connect_info);
185162

186-
svc_clone.call(req).await.map(|res| {
187-
// Convert Response<B> to Response<BoxBody>
188-
let (parts, body) = res.into_parts();
189-
let body = body.boxed_unsync();
190-
Response::from_parts(parts, body)
191-
}).map_err(|e| Box::new(e) as BoxError)
163+
let res = svc_clone.call(req).await.map_err(|e| Box::new(e) as BoxError)?;
164+
// Box the body
165+
let (parts, body) = res.into_parts();
166+
Ok::<_, BoxError>(Response::from_parts(parts, body.boxed()))
192167
}
193168
});
194169

@@ -197,8 +172,8 @@ where
197172
}
198173
});
199174

200-
// Reply that we are switching protocols to h2
201-
let mut res = Response::new(Body::empty());
175+
// Return 101 Switching Protocols
176+
let mut res = Response::new(BoxBody::default());
202177
*res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS;
203178
res.headers_mut()
204179
.insert(header::UPGRADE, http::HeaderValue::from_static("h2c"));

libsql-server/src/http/admin/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,10 @@ where
253253
Ok(())
254254
}
255255

256-
async fn auth_middleware<B>(
256+
async fn auth_middleware(
257257
State(auth): State<Option<Arc<str>>>,
258-
request: Request<B>,
259-
next: Next<B>,
258+
request: Request,
259+
next: Next,
260260
) -> Result<axum::response::Response, StatusCode> {
261261
if let Some(ref auth) = auth {
262262
let Some(auth_header) = request.headers().get("authorization") else {

libsql-server/src/rpc/mod.rs

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustls::RootCertStore;
1111
use tokio_rustls::TlsAcceptor;
1212
use tonic::Status;
1313
use tower::util::option_layer;
14+
use tower::Service;
1415
use tower::ServiceBuilder;
1516
use tower_http::trace::DefaultOnResponse;
1617
use tracing::Span;
@@ -60,19 +61,48 @@ pub async fn run_rpc_server<A: Accept>(
6061
}
6162
}
6263

64+
/// Wrapper service that converts hyper 1.0's Incoming body to tonic's BoxBody
65+
#[derive(Clone)]
66+
struct TonicServiceWrapper<S> {
67+
inner: S,
68+
}
69+
70+
impl<S, B> Service<hyper::Request<hyper::body::Incoming>> for TonicServiceWrapper<S>
71+
where
72+
S: Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible> + Clone + Send + 'static,
73+
S::Future: Send + 'static,
74+
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
75+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
76+
{
77+
type Response = hyper::Response<B>;
78+
type Error = std::convert::Infallible;
79+
type Future = S::Future;
80+
81+
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
82+
self.inner.poll_ready(cx)
83+
}
84+
85+
fn call(&mut self, req: hyper::Request<hyper::body::Incoming>) -> Self::Future {
86+
// Convert Incoming body to tonic's BoxBody
87+
let (parts, body) = req.into_parts();
88+
let body = tonic::body::BoxBody::new(body);
89+
let req = hyper::Request::from_parts(parts, body);
90+
self.inner.call(req)
91+
}
92+
}
93+
6394
async fn run_tls_server<A, S, B>(
6495
acceptor: &mut A,
6596
svc: S,
6697
tls_config: TlsConfig,
6798
) -> anyhow::Result<()>
6899
where
69100
A: Accept,
70-
S: tower::Service<http::Request<axum::body::Body>, Response = http::Response<B>>
101+
S: tower::Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible>
71102
+ Clone
72103
+ Send
73104
+ 'static,
74105
S::Future: Send + 'static,
75-
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
76106
S::Response: Send + 'static,
77107
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
78108
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
@@ -105,7 +135,8 @@ where
105135
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
106136

107137
tracing::info!("serving internal rpc server with tls");
108-
let h2c_maker = crate::h2c::H2cMaker::new(svc);
138+
139+
let wrapped_svc = TonicServiceWrapper { inner: svc };
109140

110141
// Drive the acceptor stream manually for hyper 1.0+ compatibility
111142
loop {
@@ -119,7 +150,7 @@ where
119150
};
120151

121152
let tls_acceptor = tls_acceptor.clone();
122-
let mut h2c_maker = h2c_maker.clone();
153+
let svc = wrapped_svc.clone();
123154

124155
tokio::spawn(async move {
125156
let tls_stream = match tls_acceptor.accept(conn).await {
@@ -132,15 +163,6 @@ where
132163

133164
let io = TokioIo::new(tls_stream);
134165

135-
// Get the service for this connection
136-
let svc = match h2c_maker.call(&conn).await {
137-
Ok(svc) => svc,
138-
Err(e) => {
139-
tracing::error!("failed to create h2c service: {:#}", e);
140-
return;
141-
}
142-
};
143-
144166
if let Err(err) = ConnBuilder::new(TokioExecutor::new())
145167
.serve_connection(io, svc)
146168
.await
@@ -159,18 +181,17 @@ async fn run_plain_server<A, S, B>(
159181
) -> anyhow::Result<()>
160182
where
161183
A: Accept,
162-
S: tower::Service<http::Request<axum::body::Body>, Response = http::Response<B>>
184+
S: tower::Service<hyper::Request<tonic::body::BoxBody>, Response = hyper::Response<B>, Error = std::convert::Infallible>
163185
+ Clone
164186
+ Send
165187
+ 'static,
166188
S::Future: Send + 'static,
167-
S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
168189
S::Response: Send + 'static,
169190
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
170191
B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync + 'static,
171192
{
172193
tracing::info!("serving internal rpc server without tls");
173-
let h2c_maker = crate::h2c::H2cMaker::new(svc);
194+
let wrapped_svc = TonicServiceWrapper { inner: svc };
174195

175196
// Drive the acceptor stream manually for hyper 1.0+ compatibility
176197
loop {
@@ -183,20 +204,11 @@ where
183204
None => break,
184205
};
185206

186-
let mut h2c_maker = h2c_maker.clone();
207+
let svc = wrapped_svc.clone();
187208

188209
tokio::spawn(async move {
189210
let io = TokioIo::new(conn);
190211

191-
// Get the service for this connection
192-
let svc = match h2c_maker.call(&conn).await {
193-
Ok(svc) => svc,
194-
Err(e) => {
195-
tracing::error!("failed to create h2c service: {:#}", e);
196-
return;
197-
}
198-
};
199-
200212
if let Err(err) = ConnBuilder::new(TokioExecutor::new())
201213
.serve_connection(io, svc)
202214
.await

0 commit comments

Comments
 (0)