@@ -7,10 +7,13 @@ use anyhow::Context as _;
77use bottomless:: replicator:: Options ;
88use bytes:: Bytes ;
99use enclose:: enclose;
10+ use fallible_iterator:: FallibleIterator ;
1011use futures:: Stream ;
1112use libsql_sys:: EncryptionConfig ;
1213use rusqlite:: hooks:: { AuthAction , AuthContext , Authorization } ;
13- use tokio:: io:: AsyncBufReadExt as _;
14+ use sqlite3_parser:: ast:: { Cmd , Stmt } ;
15+ use sqlite3_parser:: lexer:: sql:: { Parser , ParserError } ;
16+ use tokio:: io:: AsyncReadExt ;
1417use tokio:: task:: JoinSet ;
1518use tokio_util:: io:: StreamReader ;
1619
@@ -33,9 +36,6 @@ use crate::{StatsSender, BLOCKING_RT, DB_CREATE_TIMEOUT, DEFAULT_AUTO_CHECKPOINT
3336
3437use super :: { BaseNamespaceConfig , PrimaryConfig } ;
3538
36- const WASM_TABLE_CREATE : & str =
37- "CREATE TABLE libsql_wasm_func_table (name text PRIMARY KEY, body text) WITHOUT ROWID;" ;
38-
3939#[ tracing:: instrument( skip_all) ]
4040pub ( super ) async fn make_primary_connection_maker (
4141 primary_config : & PrimaryConfig ,
@@ -290,184 +290,94 @@ async fn run_periodic_compactions(logger: Arc<ReplicationLogger>) -> anyhow::Res
290290 }
291291}
292292
293- fn tokenize_sql_keywords ( text : & str ) -> Vec < String > {
294- let mut tokens = Vec :: new ( ) ;
295- let mut chars = text. chars ( ) . peekable ( ) ;
296- let mut current_token = String :: new ( ) ;
297- let mut in_string_literal = false ;
298- let mut string_delimiter = '\0' ;
299-
300- while let Some ( ch) = chars. next ( ) {
301- match ch {
302- '\'' | '"' => {
303- if !in_string_literal {
304- in_string_literal = true ;
305- string_delimiter = ch;
306- } else if ch == string_delimiter {
307- in_string_literal = false ;
308- }
309- }
310- c if c. is_whitespace ( ) || "(){}[];," . contains ( c) => {
311- if in_string_literal {
312- continue ;
313- }
314- if !current_token. is_empty ( ) {
315- tokens. push ( current_token. to_uppercase ( ) ) ;
316- current_token. clear ( ) ;
317- }
318- }
319- // Regular characters
320- _ => {
321- if !in_string_literal {
322- current_token. push ( ch) ;
323- }
324- }
325- }
326- }
327-
328- if !current_token. is_empty ( ) && !in_string_literal {
329- tokens. push ( current_token. to_uppercase ( ) ) ;
330- }
331-
332- tokens
333- }
334-
335- fn is_complete_sql_statement ( sql : & str ) -> bool {
336- let tokens = tokenize_sql_keywords ( sql) ;
337- let mut begin_end_depth = 0 ;
338- let mut case_depth = 0 ;
339-
340- for ( i, token) in tokens. iter ( ) . enumerate ( ) {
341- match token. as_str ( ) {
342- "CASE" => {
343- case_depth += 1 ;
344- }
345- "BEGIN" => {
346- let next_token = tokens. get ( i + 1 ) . map ( |s| s. as_str ( ) ) ;
347- let is_transaction_keyword = matches ! (
348- next_token,
349- Some ( "TRANSACTION" ) | Some ( "IMMEDIATE" ) | Some ( "EXCLUSIVE" ) | Some ( "DEFERRED" )
350- ) ;
351-
352- if !is_transaction_keyword {
353- begin_end_depth += 1 ;
354- }
355- }
356- "END" => {
357- if case_depth > 0 {
358- case_depth -= 1 ;
359- } else {
360- // This is a block-ending END (BEGIN/END, IF/END IF, etc.)
361- let is_control_flow_end = tokens
362- . get ( i + 1 )
363- . map ( |next| matches ! ( next. as_str( ) , "IF" | "LOOP" | "WHILE" ) )
364- . unwrap_or ( false ) ;
365-
366- if !is_control_flow_end {
367- begin_end_depth -= 1 ;
368- }
369- }
370- }
371- _ => { }
372- }
373-
374- if begin_end_depth < 0 {
375- return false ;
376- }
377- }
378-
379- begin_end_depth == 0 && case_depth == 0
380- }
381-
382293async fn load_dump < S > ( dump : S , conn : PrimaryConnection ) -> crate :: Result < ( ) , LoadDumpError >
383294where
384295 S : Stream < Item = std:: io:: Result < Bytes > > + Unpin ,
385296{
386297 let mut reader = tokio:: io:: BufReader :: new ( StreamReader :: new ( dump) ) ;
387- let mut curr = String :: new ( ) ;
388- let mut line = String :: new ( ) ;
298+ let mut dump_content = String :: new ( ) ;
299+ reader
300+ . read_to_string ( & mut dump_content)
301+ . await
302+ . map_err ( |e| LoadDumpError :: Internal ( format ! ( "Failed to read dump content: {}" , e) ) ) ?;
303+
304+ if dump_content. to_lowercase ( ) . contains ( "attach" ) {
305+ return Err ( LoadDumpError :: InvalidSqlInput (
306+ "attach statements are not allowed in dumps" . to_string ( ) ,
307+ ) ) ;
308+ }
309+
310+ let mut parser = Box :: new ( Parser :: new ( dump_content. as_bytes ( ) ) ) ;
389311 let mut skipped_wasm_table = false ;
390312 let mut n_stmt = 0 ;
391- let mut line_id = 0 ;
392313
393- while let Ok ( n) = reader. read_line ( & mut curr) . await {
394- line_id += 1 ;
395- if n == 0 {
396- break ;
397- }
398- let trimmed = curr. trim ( ) ;
399- if trimmed. is_empty ( ) || trimmed. starts_with ( "--" ) {
400- curr. clear ( ) ;
401- continue ;
402- }
314+ loop {
315+ match parser. next ( ) {
316+ Ok ( Some ( cmd) ) => {
317+ n_stmt += 1 ;
318+
319+ if !skipped_wasm_table {
320+ if let Cmd :: Stmt ( Stmt :: CreateTable { tbl_name, .. } ) = & cmd {
321+ if tbl_name. name . 0 == "libsql_wasm_func_table" {
322+ skipped_wasm_table = true ;
323+ tracing:: debug!( "Skipping WASM table creation" ) ;
324+ continue ;
325+ }
326+ }
327+ }
403328
404- // we want to concat original(non-trimmed) lines as trimming will join all them in one
405- // single-line statement which is incorrect if comments in the end are present
406- line. push_str ( & curr) ;
407- let statement_end = trimmed. ends_with ( ';' ) && is_complete_sql_statement ( & line) ;
408- curr. clear ( ) ;
409-
410- // This is a hack to ignore the libsql_wasm_func_table table because it is already created
411- // by the system.
412- if !skipped_wasm_table && line. trim ( ) == WASM_TABLE_CREATE {
413- skipped_wasm_table = true ;
414- line. clear ( ) ;
415- continue ;
416- }
329+ if n_stmt > 2 && conn. is_autocommit ( ) . await . unwrap ( ) {
330+ return Err ( LoadDumpError :: NoTxn ) ;
331+ }
417332
418- if statement_end {
419- n_stmt += 1 ;
420- // dump must be performd within a txn
421- if n_stmt > 2 && conn. is_autocommit ( ) . await . unwrap ( ) {
422- return Err ( LoadDumpError :: NoTxn ) ;
333+ let stmt_sql = cmd. to_string ( ) ;
334+ tokio:: task:: spawn_blocking ( {
335+ let conn = conn. clone ( ) ;
336+ move || -> crate :: Result < ( ) , LoadDumpError > {
337+ conn. with_raw ( |conn| {
338+ conn. authorizer ( Some ( |auth : AuthContext < ' _ > | match auth. action {
339+ AuthAction :: Attach { filename : _ } => Authorization :: Deny ,
340+ _ => Authorization :: Allow ,
341+ } ) ) ;
342+ conn. execute ( & stmt_sql, ( ) )
343+ } )
344+ . map_err ( |e| match e {
345+ rusqlite:: Error :: SqlInputError {
346+ msg, sql, offset, ..
347+ } => LoadDumpError :: InvalidSqlInput ( format ! (
348+ "msg: {}, sql: {}, offset: {}" ,
349+ msg, sql, offset
350+ ) ) ,
351+ e => LoadDumpError :: Internal ( format ! (
352+ "statement: {}, error: {}" ,
353+ n_stmt, e
354+ ) ) ,
355+ } ) ?;
356+ Ok ( ( ) )
357+ }
358+ } )
359+ . await ??;
423360 }
361+ Ok ( None ) => break ,
362+ Err ( e) => {
363+ let error_msg = match e {
364+ sqlite3_parser:: lexer:: sql:: Error :: ParserError (
365+ ParserError :: SyntaxError { token_type, found } ,
366+ Some ( ( line, col) ) ,
367+ ) => {
368+ let near_token = found. as_deref ( ) . unwrap_or ( & token_type) ;
369+ format ! (
370+ "syntax error near '{}' at line {}, column {}" ,
371+ near_token, line, col
372+ )
373+ }
374+ _ => format ! ( "parse error: {}" , e) ,
375+ } ;
424376
425- line = tokio:: task:: spawn_blocking ( {
426- let conn = conn. clone ( ) ;
427- move || -> crate :: Result < String , LoadDumpError > {
428- conn. with_raw ( |conn| {
429- conn. authorizer ( Some ( |auth : AuthContext < ' _ > | match auth. action {
430- AuthAction :: Attach { filename : _ } => Authorization :: Deny ,
431- _ => Authorization :: Allow ,
432- } ) ) ;
433- conn. execute ( & line, ( ) )
434- } )
435- . map_err ( |e| match e {
436- rusqlite:: Error :: SqlInputError {
437- msg, sql, offset, ..
438- } => {
439- let msg = if sql. to_lowercase ( ) . contains ( "attach" ) {
440- format ! (
441- "attach statements are not allowed in dumps, msg: {}, sql: {}, offset: {}" ,
442- msg,
443- sql,
444- offset
445- )
446- } else {
447- format ! ( "msg: {}, sql: {}, offset: {}" , msg, sql, offset)
448- } ;
449-
450- LoadDumpError :: InvalidSqlInput ( msg)
451- }
452- e => LoadDumpError :: Internal ( format ! ( "line: {}, error: {}" , line_id, e) ) ,
453- } ) ?;
454- Ok ( line)
455- }
456- } )
457- . await ??;
458- line. clear ( ) ;
459- } else {
460- line. push ( ' ' ) ;
377+ return Err ( LoadDumpError :: InvalidSqlInput ( error_msg) ) ;
378+ }
461379 }
462380 }
463- tracing:: debug!( "loaded {} lines from dump" , line_id) ;
464-
465- if !line. trim ( ) . is_empty ( ) {
466- return Err ( LoadDumpError :: InvalidSqlInput ( format ! (
467- "Incomplete SQL statement at end of dump: {}" ,
468- line. trim( )
469- ) ) ) ;
470- }
471381
472382 if !conn. is_autocommit ( ) . await . unwrap ( ) {
473383 tokio:: task:: spawn_blocking ( {
0 commit comments