Skip to content

Commit 865154e

Browse files
committed
Add query timeout option to interrupt long-running queries
A single background tokio task with a min-heap manages all query deadlines efficiently. When a query starts, a TimeoutGuard is acquired; if the deadline expires before the guard is dropped, the connection is interrupted via sqlite3_interrupt().
1 parent fe57952 commit 865154e

5 files changed

Lines changed: 271 additions & 5 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte
2222
- `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds.
2323
- `authToken`: authentication token for the provider URL (optional).
2424
- `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error
25+
- `queryTimeout`: maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error
2526

2627
The function returns a `Database` object.
2728

integration-tests/tests/async.test.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,35 @@ test.serial("Timeout option", async (t) => {
393393
fs.unlinkSync(path);
394394
});
395395

396+
test.serial("Query timeout option interrupts long-running query", async (t) => {
397+
const queryTimeout = 100;
398+
const path = genDatabaseFilename();
399+
const [db, errorType] = await connect(path, { queryTimeout });
400+
const stmt = await db.prepare(
401+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
402+
);
403+
404+
await t.throwsAsync(async () => {
405+
await stmt.all();
406+
}, {
407+
instanceOf: errorType,
408+
message: "interrupted",
409+
code: "SQLITE_INTERRUPT",
410+
});
411+
412+
db.close();
413+
fs.unlinkSync(path);
414+
});
415+
416+
test.serial("Query timeout option allows short-running query", async (t) => {
417+
const path = genDatabaseFilename();
418+
const [db] = await connect(path, { queryTimeout: 100 });
419+
const stmt = await db.prepare("SELECT 1 AS value");
420+
t.deepEqual(await stmt.get(), { value: 1 });
421+
db.close();
422+
fs.unlinkSync(path);
423+
});
424+
396425
test.serial("Concurrent writes over same connection", async (t) => {
397426
const db = t.context.db;
398427
await db.exec(`

integration-tests/tests/sync.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,44 @@ test.serial("Timeout option", async (t) => {
457457
fs.unlinkSync(path);
458458
});
459459

460+
test.serial("Query timeout option interrupts long-running query", async (t) => {
461+
if (t.context.provider === "sqlite") {
462+
t.assert(true);
463+
return;
464+
}
465+
466+
const path = genDatabaseFilename();
467+
const [db, errorType] = await connect(path, { queryTimeout: 100 });
468+
const stmt = db.prepare(
469+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
470+
);
471+
472+
t.throws(() => {
473+
stmt.all();
474+
}, {
475+
instanceOf: errorType,
476+
message: "interrupted",
477+
code: "SQLITE_INTERRUPT",
478+
});
479+
480+
db.close();
481+
fs.unlinkSync(path);
482+
});
483+
484+
test.serial("Query timeout option allows short-running query", async (t) => {
485+
if (t.context.provider === "sqlite") {
486+
t.assert(true);
487+
return;
488+
}
489+
490+
const path = genDatabaseFilename();
491+
const [db] = await connect(path, { queryTimeout: 100 });
492+
const stmt = db.prepare("SELECT 1 AS value");
493+
t.deepEqual(stmt.get(), { value: 1 });
494+
db.close();
495+
fs.unlinkSync(path);
496+
});
497+
460498
test.serial("Statement.reader [SELECT is true]", async (t) => {
461499
const db = t.context.db;
462500
const stmt = db.prepare("SELECT * FROM users WHERE id = ?");

src/lib.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#![allow(deprecated)]
2222

2323
mod auth;
24+
mod query_timeout;
2425

2526
use napi::{
2627
bindgen_prelude::{Array, FromNapiValue, ToNapiValue},
@@ -200,6 +201,8 @@ pub struct Options {
200201
pub encryptionKey: Option<String>,
201202
// Encryption key for remote encryption at rest.
202203
pub remoteEncryptionKey: Option<String>,
204+
// Maximum time in milliseconds that a query is allowed to run.
205+
pub queryTimeout: Option<f64>,
203206
}
204207

205208
/// Access mode.
@@ -224,6 +227,10 @@ pub struct Database {
224227
default_safe_integers: AtomicBool,
225228
// Whether to use memory-only mode.
226229
memory: bool,
230+
// Maximum time in milliseconds that a query is allowed to run.
231+
query_timeout: Option<Duration>,
232+
// Shared timeout manager for efficient query timeout handling.
233+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
227234
}
228235

229236
impl Drop for Database {
@@ -320,11 +327,19 @@ pub async fn connect(path: String, opts: Option<Options>) -> Result<Database> {
320327
conn.busy_timeout(Duration::from_millis(timeout as u64))
321328
.map_err(Error::from)?
322329
}
330+
let query_timeout = opts
331+
.as_ref()
332+
.and_then(|o| o.queryTimeout)
333+
.filter(|&t| t > 0.0)
334+
.map(|t| Duration::from_millis(t as u64));
335+
let timeout_manager = Arc::new(query_timeout::QueryTimeoutManager::new());
323336
Ok(Database {
324337
db,
325338
conn: Some(Arc::new(conn)),
326339
default_safe_integers,
327340
memory,
341+
query_timeout,
342+
timeout_manager,
328343
})
329344
}
330345

@@ -387,7 +402,7 @@ impl Database {
387402
pluck: false.into(),
388403
timing: false.into(),
389404
};
390-
Ok(Statement::new(conn, stmt, mode))
405+
Ok(Statement::new(conn, stmt, mode, self.query_timeout, self.timeout_manager.clone()))
391406
}
392407

393408
/// Sets the authorizer for the database.
@@ -515,6 +530,7 @@ impl Database {
515530
));
516531
}
517532
};
533+
let _guard = self.query_timeout.map(|t| self.timeout_manager.register(&conn, t));
518534
conn.execute_batch(&sql).await.map_err(Error::from)?;
519535
Ok(())
520536
}
@@ -620,6 +636,10 @@ pub struct Statement {
620636
column_names: Vec<std::ffi::CString>,
621637
// The access mode.
622638
mode: AccessMode,
639+
// Maximum time in milliseconds that a query is allowed to run.
640+
query_timeout: Option<Duration>,
641+
// Shared timeout manager.
642+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
623643
}
624644

625645
#[napi]
@@ -635,6 +655,8 @@ impl Statement {
635655
conn: Arc<libsql::Connection>,
636656
stmt: libsql::Statement,
637657
mode: AccessMode,
658+
query_timeout: Option<Duration>,
659+
timeout_manager: Arc<query_timeout::QueryTimeoutManager>,
638660
) -> Self {
639661
let column_names: Vec<std::ffi::CString> = stmt
640662
.columns()
@@ -647,6 +669,8 @@ impl Statement {
647669
stmt,
648670
column_names,
649671
mode,
672+
query_timeout,
673+
timeout_manager,
650674
}
651675
}
652676

@@ -663,8 +687,10 @@ impl Statement {
663687
let start = std::time::Instant::now();
664688
let stmt = self.stmt.clone();
665689
let conn = self.conn.clone();
690+
let guard = self.start_timeout_guard();
666691

667692
let future = async move {
693+
let _guard = guard;
668694
stmt.run(params).await.map_err(Error::from)?;
669695
let changes = if conn.total_changes() == total_changes_before {
670696
0
@@ -707,7 +733,9 @@ impl Statement {
707733
};
708734

709735
let stmt_fut = stmt.clone();
736+
let guard = self.start_timeout_guard();
710737
let future = async move {
738+
let _guard = guard;
711739
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
712740
let row = rows.next().await.map_err(Error::from)?;
713741
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
@@ -771,6 +799,7 @@ impl Statement {
771799
stmt.reset();
772800
let params = map_params(&stmt, params).unwrap();
773801
let stmt = self.stmt.clone();
802+
let guard = self.start_timeout_guard();
774803
let future = async move {
775804
let rows = stmt.query(params).await.map_err(Error::from)?;
776805
Ok::<_, napi::Error>(rows)
@@ -783,6 +812,7 @@ impl Statement {
783812
safe_ints,
784813
raw,
785814
pluck,
815+
guard,
786816
))
787817
})
788818
}
@@ -864,6 +894,13 @@ impl Statement {
864894
self.stmt.interrupt().map_err(Error::from)?;
865895
Ok(())
866896
}
897+
898+
}
899+
900+
impl Statement {
901+
fn start_timeout_guard(&self) -> Option<query_timeout::TimeoutGuard> {
902+
self.query_timeout.map(|t| self.timeout_manager.register(&self.conn, t))
903+
}
867904
}
868905

869906
/// Gets first row from statement in blocking mode.
@@ -885,6 +922,7 @@ pub fn statement_get_sync(
885922
};
886923

887924
let rt = runtime()?;
925+
let _guard = stmt.start_timeout_guard();
888926
rt.block_on(async move {
889927
let params = map_params(&stmt.stmt, params)?;
890928
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
@@ -909,6 +947,7 @@ pub fn statement_get_sync(
909947
pub fn statement_run_sync(stmt: &Statement, params: Option<napi::JsUnknown>) -> Result<RunResult> {
910948
stmt.stmt.reset();
911949
let rt = runtime()?;
950+
let _guard = stmt.start_timeout_guard();
912951
rt.block_on(async move {
913952
let params = map_params(&stmt.stmt, params)?;
914953
let total_changes_before = stmt.conn.total_changes();
@@ -940,11 +979,12 @@ pub fn statement_iterate_sync(
940979
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
941980
let raw = stmt.mode.raw.load(Ordering::SeqCst);
942981
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
943-
let stmt = stmt.stmt.clone();
982+
let guard = stmt.start_timeout_guard();
983+
let inner_stmt = stmt.stmt.clone();
944984
let (rows, column_names) = rt.block_on(async move {
945-
stmt.reset();
946-
let params = map_params(&stmt, params)?;
947-
let rows = stmt.query(params).await.map_err(Error::from)?;
985+
inner_stmt.reset();
986+
let params = map_params(&inner_stmt, params)?;
987+
let rows = inner_stmt.query(params).await.map_err(Error::from)?;
948988
let mut column_names = Vec::new();
949989
for i in 0..rows.column_count() {
950990
column_names
@@ -958,6 +998,7 @@ pub fn statement_iterate_sync(
958998
safe_ints,
959999
raw,
9601000
pluck,
1001+
guard,
9611002
))
9621003
}
9631004

@@ -1104,6 +1145,7 @@ pub struct RowsIterator {
11041145
safe_ints: bool,
11051146
raw: bool,
11061147
pluck: bool,
1148+
_timeout_guard: Option<query_timeout::TimeoutGuard>,
11071149
}
11081150

11091151
#[napi]
@@ -1114,13 +1156,15 @@ impl RowsIterator {
11141156
safe_ints: bool,
11151157
raw: bool,
11161158
pluck: bool,
1159+
timeout_guard: Option<query_timeout::TimeoutGuard>,
11171160
) -> Self {
11181161
Self {
11191162
rows,
11201163
column_names,
11211164
safe_ints,
11221165
raw,
11231166
pluck,
1167+
_timeout_guard: timeout_guard,
11241168
}
11251169
}
11261170

0 commit comments

Comments
 (0)