@@ -149,9 +149,7 @@ const V3_PREFERRED_PROTOCOL_HEADER: http::HeaderValue =
149149#[ cfg( not( feature = "browser" ) ) ]
150150const MAX_V3_OUTBOUND_FRAME_BYTES : usize = 256 * 1024 ;
151151#[ cfg( not( feature = "browser" ) ) ]
152- const BSATN_SUM_TAG_BYTES : usize = 1 ;
153- #[ cfg( not( feature = "browser" ) ) ]
154- const BSATN_LENGTH_PREFIX_BYTES : usize = 4 ;
152+ const EMPTY_V3_SERVER_PAYLOAD_ERR : & str = "v3 websocket binary payload must contain at least one v2 server message" ;
155153
156154fn parse_scheme ( scheme : Option < Scheme > ) -> Result < Scheme , UriError > {
157155 Ok ( match scheme {
@@ -305,48 +303,39 @@ fn decode_v2_server_message(bytes: &[u8]) -> Result<ws::v2::ServerMessage, WsErr
305303 bsatn:: from_slice ( bytes) . map_err ( |source| WsError :: DeserializeMessage { source } )
306304}
307305
308- /// Expands a v3 server frame into the ordered sequence of encoded inner v2
309- /// server messages it carries.
310306#[ cfg( not( feature = "browser" ) ) ]
311- fn flatten_server_frame ( frame : ws:: v3:: ServerFrame ) -> Box < [ Bytes ] > {
312- match frame {
313- ws:: v3:: ServerFrame :: Single ( message) => Box :: new ( [ message] ) ,
314- ws:: v3:: ServerFrame :: Batch ( messages) => messages,
307+ fn parse_v3_server_messages ( bytes : & [ u8 ] ) -> Result < Vec < ws:: v2:: ServerMessage > , WsError > {
308+ let mut remaining = bytes;
309+ if remaining. is_empty ( ) {
310+ return Err ( WsError :: DeserializeMessage {
311+ source : bsatn:: DecodeError :: Other ( EMPTY_V3_SERVER_PAYLOAD_ERR . into ( ) ) ,
312+ } ) ;
313+ }
314+
315+ let mut messages = Vec :: new ( ) ;
316+ while !remaining. is_empty ( ) {
317+ messages. push ( bsatn:: from_reader ( & mut remaining) . map_err ( |source| WsError :: DeserializeMessage { source } ) ?) ;
315318 }
319+ Ok ( messages)
316320}
317321
318322/// Encodes one logical v2 client message into raw BSATN bytes.
319323fn encode_v2_client_message_bytes ( msg : & ws:: v2:: ClientMessage ) -> Bytes {
320324 Bytes :: from ( bsatn:: to_vec ( msg) . expect ( "should be able to bsatn encode v2 client message" ) )
321325}
322326
323- /// Wraps one or more encoded v2 client messages in a v3 transport frame.
324- #[ cfg( not( feature = "browser" ) ) ]
325- fn encode_v3_client_frame ( messages : Vec < Bytes > ) -> Bytes {
326- let frame = if messages. len ( ) == 1 {
327- ws:: v3:: ClientFrame :: Single ( messages. into_iter ( ) . next ( ) . unwrap ( ) )
328- } else {
329- ws:: v3:: ClientFrame :: Batch ( messages. into_boxed_slice ( ) )
330- } ;
331- Bytes :: from ( bsatn:: to_vec ( & frame) . expect ( "should be able to bsatn encode v3 client frame" ) )
332- }
333-
334- /// Returns the encoded size of a v3 `Single` frame carrying `message`.
335- #[ cfg( not( feature = "browser" ) ) ]
336- fn encoded_v3_single_frame_size ( message : & Bytes ) -> usize {
337- BSATN_SUM_TAG_BYTES + BSATN_LENGTH_PREFIX_BYTES + message. len ( )
338- }
339-
340- /// Returns the encoded size of a v3 `Batch` frame containing only its first logical message.
341327#[ cfg( not( feature = "browser" ) ) ]
342- fn encoded_v3_batch_frame_size_for_first_message ( message : & Bytes ) -> usize {
343- BSATN_SUM_TAG_BYTES + BSATN_LENGTH_PREFIX_BYTES + BSATN_LENGTH_PREFIX_BYTES + message. len ( )
344- }
328+ fn concatenate_v3_client_messages ( messages : Vec < Bytes > ) -> Bytes {
329+ if messages. len ( ) == 1 {
330+ return messages. into_iter ( ) . next ( ) . unwrap ( ) ;
331+ }
345332
346- /// Returns the encoded contribution of one additional logical message inside a v3 `Batch` frame.
347- #[ cfg( not( feature = "browser" ) ) ]
348- fn encoded_v3_batch_element_size ( message : & Bytes ) -> usize {
349- BSATN_LENGTH_PREFIX_BYTES + message. len ( )
333+ let total_len = messages. iter ( ) . map ( Bytes :: len) . sum ( ) ;
334+ let mut encoded = Vec :: with_capacity ( total_len) ;
335+ for message in messages {
336+ encoded. extend_from_slice ( & message) ;
337+ }
338+ encoded. into ( )
350339}
351340
352341/// Builds one bounded v3 transport frame from `first_message` and as many
@@ -363,25 +352,25 @@ where
363352 let first_message = encode_v2_client_message_bytes ( & first_message) ;
364353 // Oversized logical messages are still sent alone so they cannot block the
365354 // queue forever behind the frame-size limit.
366- if encoded_v3_single_frame_size ( & first_message) > MAX_V3_OUTBOUND_FRAME_BYTES {
367- if pending_outgoing. is_empty ( )
368- && let Some ( next_message) = try_next_outgoing_now ( )
369- {
370- pending_outgoing . push_front ( next_message ) ;
355+ if first_message. len ( ) > MAX_V3_OUTBOUND_FRAME_BYTES {
356+ if pending_outgoing. is_empty ( ) {
357+ if let Some ( next_message) = try_next_outgoing_now ( ) {
358+ pending_outgoing . push_front ( next_message ) ;
359+ }
371360 }
372361
373- return encode_v3_client_frame ( vec ! [ first_message] ) ;
362+ return first_message;
374363 }
375364
376365 let mut messages = vec ! [ first_message] ;
377- let mut batch_size = encoded_v3_batch_frame_size_for_first_message ( messages. first ( ) . unwrap ( ) ) ;
366+ let mut batch_size = messages. first ( ) . unwrap ( ) . len ( ) ;
378367
379368 loop {
380369 let Some ( next_message) = pending_outgoing. pop_front ( ) . or_else ( & mut try_next_outgoing_now) else {
381370 break ;
382371 } ;
383372 let next_message_bytes = encode_v2_client_message_bytes ( & next_message) ;
384- let next_batch_size = batch_size + encoded_v3_batch_element_size ( & next_message_bytes) ;
373+ let next_batch_size = batch_size + next_message_bytes. len ( ) ;
385374 if next_batch_size > MAX_V3_OUTBOUND_FRAME_BYTES {
386375 pending_outgoing. push_front ( next_message) ;
387376 break ;
@@ -390,7 +379,7 @@ where
390379 messages. push ( next_message_bytes) ;
391380 }
392381
393- encode_v3_client_frame ( messages)
382+ concatenate_v3_client_messages ( messages)
394383}
395384
396385/// Encodes the next outbound logical message according to the negotiated
@@ -583,15 +572,7 @@ impl WsConnection {
583572 let bytes = & * decompress_server_message ( bytes) ?;
584573 match protocol {
585574 NegotiatedWsProtocol :: V2 => Ok ( vec ! [ decode_v2_server_message( bytes) ?] ) ,
586- NegotiatedWsProtocol :: V3 => {
587- let frame: ws:: v3:: ServerFrame =
588- bsatn:: from_slice ( bytes) . map_err ( |source| WsError :: DeserializeMessage { source } ) ?;
589- flatten_server_frame ( frame)
590- . into_vec ( )
591- . into_iter ( )
592- . map ( |message| decode_v2_server_message ( & message) )
593- . collect ( )
594- }
575+ NegotiatedWsProtocol :: V3 => parse_v3_server_messages ( bytes) ,
595576 }
596577 }
597578
@@ -898,12 +879,22 @@ mod tests {
898879 encoded
899880 }
900881
901- fn encode_server_frame ( frame : & ws:: v3 :: ServerFrame ) -> Vec < u8 > {
882+ fn encode_server_messages ( messages : & [ ws:: v2 :: ServerMessage ] ) -> Vec < u8 > {
902883 let mut encoded = vec ! [ ws:: common:: SERVER_MSG_COMPRESSION_TAG_NONE ] ;
903- encoded. extend ( bsatn:: to_vec ( frame) . unwrap ( ) ) ;
884+ for message in messages {
885+ encoded. extend ( bsatn:: to_vec ( message) . unwrap ( ) ) ;
886+ }
904887 encoded
905888 }
906889
890+ fn decode_client_messages ( mut bytes : & [ u8 ] ) -> Vec < ws:: v2:: ClientMessage > {
891+ let mut messages = Vec :: new ( ) ;
892+ while !bytes. is_empty ( ) {
893+ messages. push ( bsatn:: from_reader ( & mut bytes) . unwrap ( ) ) ;
894+ }
895+ messages
896+ }
897+
907898 #[ test]
908899 fn negotiated_protocol_defaults_to_v2 ( ) {
909900 assert_eq ! (
@@ -940,11 +931,14 @@ mod tests {
940931 assert ! ( !has_leftover_pending_outgoing) ;
941932 assert ! ( pending. is_empty( ) ) ;
942933
943- let frame: ws:: v3:: ClientFrame = bsatn:: from_slice ( & raw ) . unwrap ( ) ;
944- let ws:: v3:: ClientFrame :: Batch ( messages) = frame else {
945- panic ! ( "expected batched v3 client frame" ) ;
946- } ;
934+ let messages = decode_client_messages ( & raw ) ;
947935 assert_eq ! ( messages. len( ) , 2 ) ;
936+ for ( expected_request_id, message) in [ 1 , 2 ] . into_iter ( ) . zip ( messages) {
937+ match message {
938+ ws:: v2:: ClientMessage :: CallReducer ( call) => assert_eq ! ( call. request_id, expected_request_id) ,
939+ _ => panic ! ( "expected CallReducer v3 client message" ) ,
940+ }
941+ }
948942 }
949943
950944 #[ test]
@@ -960,12 +954,9 @@ mod tests {
960954 assert ! ( has_leftover_pending_outgoing) ;
961955 assert_eq ! ( pending. len( ) , 1 ) ;
962956
963- let frame: ws:: v3:: ClientFrame = bsatn:: from_slice ( & raw ) . unwrap ( ) ;
964- let ws:: v3:: ClientFrame :: Single ( message) = frame else {
965- panic ! ( "expected single v3 client frame" ) ;
966- } ;
967- let inner: ws:: v2:: ClientMessage = bsatn:: from_slice ( & message) . unwrap ( ) ;
968- match inner {
957+ let messages = decode_client_messages ( & raw ) ;
958+ assert_eq ! ( messages. len( ) , 1 ) ;
959+ match & messages[ 0 ] {
969960 ws:: v2:: ClientMessage :: CallReducer ( call) => assert_eq ! ( call. request_id, 1 ) ,
970961 _ => panic ! ( "expected CallReducer inner message" ) ,
971962 }
@@ -994,14 +985,7 @@ mod tests {
994985 fn parse_response_unwraps_v3_batches ( ) {
995986 let first = procedure_result ( 1 ) ;
996987 let second = procedure_result ( 2 ) ;
997- let frame = ws:: v3:: ServerFrame :: Batch (
998- vec ! [
999- Bytes :: from( bsatn:: to_vec( & first) . unwrap( ) ) ,
1000- Bytes :: from( bsatn:: to_vec( & second) . unwrap( ) ) ,
1001- ]
1002- . into_boxed_slice ( ) ,
1003- ) ;
1004- let encoded = encode_server_frame ( & frame) ;
988+ let encoded = encode_server_messages ( & [ first, second] ) ;
1005989
1006990 let messages = WsConnection :: parse_responses ( NegotiatedWsProtocol :: V3 , & encoded) . unwrap ( ) ;
1007991 assert_eq ! ( messages. len( ) , 2 ) ;
0 commit comments