@@ -5,29 +5,119 @@ use std::pin::Pin;
55use std:: task:: { ready, Context , Poll } ;
66
77use http:: Uri ;
8+ use hyper:: rt:: { Read , Write } ;
89use hyper_util:: client:: legacy:: connect:: Connection ;
10+ use hyper_util:: rt:: TokioIo ;
911use pin_project_lite:: pin_project;
1012use tokio:: io:: { AsyncRead , AsyncWrite } ;
1113use tokio_rustls:: server:: TlsStream ;
1214use tonic:: transport:: server:: { Connected , TcpConnectInfo } ;
1315use tower:: Service ;
1416
17+ pin_project ! {
18+ /// A wrapper that adds hyper 1.0's Read/Write traits to any tokio AsyncRead/AsyncWrite type.
19+ /// This uses TokioIo internally to bridge between tokio and hyper traits.
20+ pub struct HyperStream <S > {
21+ #[ pin]
22+ inner: TokioIo <S >,
23+ }
24+ }
25+
26+ impl < S > HyperStream < S > {
27+ pub fn new ( stream : S ) -> Self {
28+ Self {
29+ inner : TokioIo :: new ( stream) ,
30+ }
31+ }
32+
33+ pub fn into_inner ( self ) -> S {
34+ self . inner . into_inner ( )
35+ }
36+ }
37+
38+ impl < S : AsyncRead + AsyncWrite + Unpin > AsyncRead for HyperStream < S > {
39+ fn poll_read (
40+ self : Pin < & mut Self > ,
41+ cx : & mut Context < ' _ > ,
42+ buf : & mut tokio:: io:: ReadBuf < ' _ > ,
43+ ) -> Poll < std:: io:: Result < ( ) > > {
44+ // SAFETY: HyperStream is Unpin if S is Unpin
45+ let this = unsafe { self . get_unchecked_mut ( ) } ;
46+ Pin :: new ( & mut this. inner ) . poll_read ( cx, buf)
47+ }
48+ }
49+
50+ impl < S : AsyncRead + AsyncWrite + Unpin > AsyncWrite for HyperStream < S > {
51+ fn poll_write (
52+ self : Pin < & mut Self > ,
53+ cx : & mut Context < ' _ > ,
54+ buf : & [ u8 ] ,
55+ ) -> Poll < std:: io:: Result < usize > > {
56+ let this = unsafe { self . get_unchecked_mut ( ) } ;
57+ Pin :: new ( & mut this. inner ) . poll_write ( cx, buf)
58+ }
59+
60+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
61+ let this = unsafe { self . get_unchecked_mut ( ) } ;
62+ Pin :: new ( & mut this. inner ) . poll_flush ( cx)
63+ }
64+
65+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
66+ let this = unsafe { self . get_unchecked_mut ( ) } ;
67+ Pin :: new ( & mut this. inner ) . poll_shutdown ( cx)
68+ }
69+ }
70+
71+ impl < S : AsyncRead + AsyncWrite + Unpin > Read for HyperStream < S > {
72+ fn poll_read (
73+ self : Pin < & mut Self > ,
74+ cx : & mut Context < ' _ > ,
75+ buf : hyper:: rt:: ReadBufCursor < ' _ > ,
76+ ) -> Poll < std:: io:: Result < ( ) > > {
77+ self . project ( ) . inner . poll_read ( cx, buf)
78+ }
79+ }
80+
81+ impl < S : AsyncRead + AsyncWrite + Unpin > Write for HyperStream < S > {
82+ fn poll_write (
83+ self : Pin < & mut Self > ,
84+ cx : & mut Context < ' _ > ,
85+ buf : & [ u8 ] ,
86+ ) -> Poll < std:: io:: Result < usize > > {
87+ self . project ( ) . inner . poll_write ( cx, buf)
88+ }
89+
90+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
91+ self . project ( ) . inner . poll_flush ( cx)
92+ }
93+
94+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
95+ self . project ( ) . inner . poll_shutdown ( cx)
96+ }
97+ }
98+
99+ impl < S : AsyncRead + AsyncWrite + Connection + Unpin > Connection for HyperStream < S > {
100+ fn connected ( & self ) -> hyper_util:: client:: legacy:: connect:: Connected {
101+ self . inner . inner ( ) . connected ( )
102+ }
103+ }
104+
15105pub trait Connector :
16106 Service < Uri , Response = Self :: Conn , Future = Self :: Fut , Error = Self :: Err >
17107 + Send
18108 + Sync
19109 + ' static
20110 + Clone
21111{
22- type Conn : Unpin + Send + ' static + AsyncRead + AsyncWrite + Connection ;
112+ type Conn : Unpin + Send + ' static + AsyncRead + AsyncWrite + Read + Write + Connection ;
23113 type Fut : Send + ' static + Unpin ;
24114 type Err : Into < Box < dyn StdError + Send + Sync > > + Send + Sync ;
25115}
26116
27117impl < T > Connector for T
28118where
29119 T : Service < Uri > + Send + Sync + ' static + Clone ,
30- T :: Response : Unpin + Send + ' static + AsyncRead + AsyncWrite + Connection ,
120+ T :: Response : Unpin + Send + ' static + AsyncRead + AsyncWrite + Read + Write + Connection ,
31121 T :: Future : Send + ' static + Unpin ,
32122 T :: Error : Into < Box < dyn StdError + Send + Sync > > + Send + Sync ,
33123{
36126 type Err = Self :: Error ;
37127}
38128
39- pub trait Conn : AsyncRead + AsyncWrite + Unpin + Send + ' static {
129+ pub trait Conn : AsyncRead + AsyncWrite + Read + Write + Unpin + Send + ' static {
40130 fn connect_info ( & self ) -> TcpConnectInfo ;
41131}
42132
@@ -107,11 +197,8 @@ where
107197 }
108198}
109199
110- impl < C : Conn > Conn for TlsStream < C > {
111- fn connect_info ( & self ) -> TcpConnectInfo {
112- self . get_ref ( ) . 0 . connect_info ( )
113- }
114- }
200+ // Note: TlsStream doesn't implement Conn directly because it doesn't implement hyper::rt::Read/Write.
201+ // Use HyperStream<TlsStream<C>> when you need a connection that implements Conn.
115202
116203impl < S > AsyncRead for AddrStream < S >
117204where
@@ -153,6 +240,54 @@ where
153240 }
154241}
155242
243+ impl < S > Read for AddrStream < S >
244+ where
245+ S : AsyncRead + AsyncWrite + Unpin ,
246+ {
247+ fn poll_read (
248+ self : Pin < & mut Self > ,
249+ cx : & mut Context < ' _ > ,
250+ mut buf : hyper:: rt:: ReadBufCursor < ' _ > ,
251+ ) -> Poll < std:: io:: Result < ( ) > > {
252+ // SAFETY: We're creating a tokio ReadBuf from the hyper ReadBufCursor
253+ let slice = unsafe {
254+ std:: slice:: from_raw_parts_mut ( buf. as_mut ( ) . as_mut_ptr ( ) , buf. as_mut ( ) . len ( ) )
255+ } ;
256+ let mut read_buf = tokio:: io:: ReadBuf :: new ( slice) ;
257+
258+ match self . project ( ) . stream . poll_read ( cx, & mut read_buf) {
259+ Poll :: Ready ( Ok ( ( ) ) ) => {
260+ let filled = read_buf. filled ( ) . len ( ) ;
261+ unsafe { buf. advance ( filled) } ;
262+ Poll :: Ready ( Ok ( ( ) ) )
263+ }
264+ Poll :: Ready ( Err ( e) ) => Poll :: Ready ( Err ( e) ) ,
265+ Poll :: Pending => Poll :: Pending ,
266+ }
267+ }
268+ }
269+
270+ impl < S > Write for AddrStream < S >
271+ where
272+ S : AsyncRead + AsyncWrite + Unpin ,
273+ {
274+ fn poll_write (
275+ self : Pin < & mut Self > ,
276+ cx : & mut Context < ' _ > ,
277+ buf : & [ u8 ] ,
278+ ) -> Poll < std:: io:: Result < usize > > {
279+ self . project ( ) . stream . poll_write ( cx, buf)
280+ }
281+
282+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
283+ self . project ( ) . stream . poll_flush ( cx)
284+ }
285+
286+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < std:: io:: Result < ( ) > > {
287+ self . project ( ) . stream . poll_shutdown ( cx)
288+ }
289+ }
290+
156291impl < S > Connected for AddrStream < S > {
157292 type ConnectInfo = TcpConnectInfo ;
158293
0 commit comments