Skip to content

Commit fc52448

Browse files
committed
Add query timeout option to interrupt long-running queries
A single background tokio task with a min-heap manages all query deadlines efficiently. When a query starts, a TimeoutGuard is acquired; if the deadline expires before the guard is dropped, the connection is interrupted via sqlite3_interrupt().
1 parent 7f95424 commit fc52448

5 files changed

Lines changed: 283 additions & 5 deletions

File tree

docs/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ You can use the `options` parameter to specify various options. Options supporte
2222
- `syncPeriod`: synchronize the database periodically every `syncPeriod` seconds.
2323
- `authToken`: authentication token for the provider URL (optional).
2424
- `timeout`: number of milliseconds to wait on locked database before returning `SQLITE_BUSY` error
25+
- `queryTimeout`: maximum number of milliseconds a query is allowed to run before being interrupted with `SQLITE_INTERRUPT` error
2526

2627
The function returns a `Database` object.
2728

integration-tests/tests/async.test.js

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,35 @@ test.serial("Timeout option", async (t) => {
393393
fs.unlinkSync(path);
394394
});
395395

396+
test.serial("Query timeout option interrupts long-running query", async (t) => {
397+
const queryTimeout = 100;
398+
const path = genDatabaseFilename();
399+
const [db, errorType] = await connect(path, { queryTimeout });
400+
const stmt = await db.prepare(
401+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
402+
);
403+
404+
await t.throwsAsync(async () => {
405+
await stmt.all();
406+
}, {
407+
instanceOf: errorType,
408+
message: "interrupted",
409+
code: "SQLITE_INTERRUPT",
410+
});
411+
412+
db.close();
413+
fs.unlinkSync(path);
414+
});
415+
416+
test.serial("Query timeout option allows short-running query", async (t) => {
417+
const path = genDatabaseFilename();
418+
const [db] = await connect(path, { queryTimeout: 100 });
419+
const stmt = await db.prepare("SELECT 1 AS value");
420+
t.deepEqual(await stmt.get(), { value: 1 });
421+
db.close();
422+
fs.unlinkSync(path);
423+
});
424+
396425
test.serial("Concurrent writes over same connection", async (t) => {
397426
const db = t.context.db;
398427
await db.exec(`

integration-tests/tests/sync.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,44 @@ test.serial("Timeout option", async (t) => {
457457
fs.unlinkSync(path);
458458
});
459459

460+
test.serial("Query timeout option interrupts long-running query", async (t) => {
461+
if (t.context.provider === "sqlite") {
462+
t.assert(true);
463+
return;
464+
}
465+
466+
const path = genDatabaseFilename();
467+
const [db, errorType] = await connect(path, { queryTimeout: 100 });
468+
const stmt = db.prepare(
469+
"WITH RECURSIVE infinite_loop(n) AS (SELECT 1 UNION ALL SELECT n + 1 FROM infinite_loop) SELECT * FROM infinite_loop;"
470+
);
471+
472+
t.throws(() => {
473+
stmt.all();
474+
}, {
475+
instanceOf: errorType,
476+
message: "interrupted",
477+
code: "SQLITE_INTERRUPT",
478+
});
479+
480+
db.close();
481+
fs.unlinkSync(path);
482+
});
483+
484+
test.serial("Query timeout option allows short-running query", async (t) => {
485+
if (t.context.provider === "sqlite") {
486+
t.assert(true);
487+
return;
488+
}
489+
490+
const path = genDatabaseFilename();
491+
const [db] = await connect(path, { queryTimeout: 100 });
492+
const stmt = db.prepare("SELECT 1 AS value");
493+
t.deepEqual(stmt.get(), { value: 1 });
494+
db.close();
495+
fs.unlinkSync(path);
496+
});
497+
460498
test.serial("Statement.reader [SELECT is true]", async (t) => {
461499
const db = t.context.db;
462500
const stmt = db.prepare("SELECT * FROM users WHERE id = ?");

src/lib.rs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
#![allow(deprecated)]
2222

2323
mod auth;
24+
mod query_timeout;
2425

2526
use napi::{
2627
bindgen_prelude::{Array, FromNapiValue, ToNapiValue},
2728
Env, JsUnknown, Result, ValueType,
2829
};
2930
use napi_derive::napi;
3031
use once_cell::sync::OnceCell;
32+
use query_timeout::{QueryTimeoutManager, TimeoutGuard};
3133
use std::{
3234
str::FromStr,
3335
sync::{
@@ -200,6 +202,8 @@ pub struct Options {
200202
pub encryptionKey: Option<String>,
201203
// Encryption key for remote encryption at rest.
202204
pub remoteEncryptionKey: Option<String>,
205+
// Maximum time in milliseconds that a query is allowed to run.
206+
pub queryTimeout: Option<f64>,
203207
}
204208

205209
/// Access mode.
@@ -224,6 +228,10 @@ pub struct Database {
224228
default_safe_integers: AtomicBool,
225229
// Whether to use memory-only mode.
226230
memory: bool,
231+
// Maximum time in milliseconds that a query is allowed to run.
232+
query_timeout: Option<Duration>,
233+
// Shared timeout manager for efficient query timeout handling.
234+
timeout_manager: Arc<QueryTimeoutManager>,
227235
}
228236

229237
impl Drop for Database {
@@ -320,11 +328,19 @@ pub async fn connect(path: String, opts: Option<Options>) -> Result<Database> {
320328
conn.busy_timeout(Duration::from_millis(timeout as u64))
321329
.map_err(Error::from)?
322330
}
331+
let query_timeout = opts
332+
.as_ref()
333+
.and_then(|o| o.queryTimeout)
334+
.filter(|&t| t > 0.0)
335+
.map(|t| Duration::from_millis(t as u64));
336+
let timeout_manager = Arc::new(QueryTimeoutManager::new());
323337
Ok(Database {
324338
db,
325339
conn: Some(Arc::new(conn)),
326340
default_safe_integers,
327341
memory,
342+
query_timeout,
343+
timeout_manager,
328344
})
329345
}
330346

@@ -387,7 +403,13 @@ impl Database {
387403
pluck: false.into(),
388404
timing: false.into(),
389405
};
390-
Ok(Statement::new(conn, stmt, mode))
406+
Ok(Statement::new(
407+
conn,
408+
stmt,
409+
mode,
410+
self.query_timeout,
411+
self.timeout_manager.clone(),
412+
))
391413
}
392414

393415
/// Sets the authorizer for the database.
@@ -515,6 +537,9 @@ impl Database {
515537
));
516538
}
517539
};
540+
let _guard = self
541+
.query_timeout
542+
.map(|t| self.timeout_manager.register(&conn, t));
518543
conn.execute_batch(&sql).await.map_err(Error::from)?;
519544
Ok(())
520545
}
@@ -620,6 +645,10 @@ pub struct Statement {
620645
column_names: Vec<std::ffi::CString>,
621646
// The access mode.
622647
mode: AccessMode,
648+
// Maximum time in milliseconds that a query is allowed to run.
649+
query_timeout: Option<Duration>,
650+
// Shared timeout manager.
651+
timeout_manager: Arc<QueryTimeoutManager>,
623652
}
624653

625654
#[napi]
@@ -635,6 +664,8 @@ impl Statement {
635664
conn: Arc<libsql::Connection>,
636665
stmt: libsql::Statement,
637666
mode: AccessMode,
667+
query_timeout: Option<Duration>,
668+
timeout_manager: Arc<QueryTimeoutManager>,
638669
) -> Self {
639670
let column_names: Vec<std::ffi::CString> = stmt
640671
.columns()
@@ -647,6 +678,8 @@ impl Statement {
647678
stmt,
648679
column_names,
649680
mode,
681+
query_timeout,
682+
timeout_manager,
650683
}
651684
}
652685

@@ -663,8 +696,10 @@ impl Statement {
663696
let start = std::time::Instant::now();
664697
let stmt = self.stmt.clone();
665698
let conn = self.conn.clone();
699+
let guard = self.start_timeout_guard();
666700

667701
let future = async move {
702+
let _guard = guard;
668703
stmt.run(params).await.map_err(Error::from)?;
669704
let changes = if conn.total_changes() == total_changes_before {
670705
0
@@ -707,7 +742,9 @@ impl Statement {
707742
};
708743

709744
let stmt_fut = stmt.clone();
745+
let guard = self.start_timeout_guard();
710746
let future = async move {
747+
let _guard = guard;
711748
let mut rows = stmt_fut.query(params).await.map_err(Error::from)?;
712749
let row = rows.next().await.map_err(Error::from)?;
713750
let duration: Option<f64> = start.map(|start| start.elapsed().as_secs_f64());
@@ -771,6 +808,7 @@ impl Statement {
771808
stmt.reset();
772809
let params = map_params(&stmt, params).unwrap();
773810
let stmt = self.stmt.clone();
811+
let guard = self.start_timeout_guard();
774812
let future = async move {
775813
let rows = stmt.query(params).await.map_err(Error::from)?;
776814
Ok::<_, napi::Error>(rows)
@@ -783,6 +821,7 @@ impl Statement {
783821
safe_ints,
784822
raw,
785823
pluck,
824+
guard,
786825
))
787826
})
788827
}
@@ -866,6 +905,13 @@ impl Statement {
866905
}
867906
}
868907

908+
impl Statement {
909+
fn start_timeout_guard(&self) -> Option<TimeoutGuard> {
910+
self.query_timeout
911+
.map(|t| self.timeout_manager.register(&self.conn, t))
912+
}
913+
}
914+
869915
/// Gets first row from statement in blocking mode.
870916
#[napi]
871917
pub fn statement_get_sync(
@@ -885,6 +931,7 @@ pub fn statement_get_sync(
885931
};
886932

887933
let rt = runtime()?;
934+
let _guard = stmt.start_timeout_guard();
888935
rt.block_on(async move {
889936
let params = map_params(&stmt.stmt, params)?;
890937
let mut rows = stmt.stmt.query(params).await.map_err(Error::from)?;
@@ -909,6 +956,7 @@ pub fn statement_get_sync(
909956
pub fn statement_run_sync(stmt: &Statement, params: Option<napi::JsUnknown>) -> Result<RunResult> {
910957
stmt.stmt.reset();
911958
let rt = runtime()?;
959+
let _guard = stmt.start_timeout_guard();
912960
rt.block_on(async move {
913961
let params = map_params(&stmt.stmt, params)?;
914962
let total_changes_before = stmt.conn.total_changes();
@@ -940,11 +988,12 @@ pub fn statement_iterate_sync(
940988
let safe_ints = stmt.mode.safe_ints.load(Ordering::SeqCst);
941989
let raw = stmt.mode.raw.load(Ordering::SeqCst);
942990
let pluck = stmt.mode.pluck.load(Ordering::SeqCst);
943-
let stmt = stmt.stmt.clone();
991+
let guard = stmt.start_timeout_guard();
992+
let inner_stmt = stmt.stmt.clone();
944993
let (rows, column_names) = rt.block_on(async move {
945-
stmt.reset();
946-
let params = map_params(&stmt, params)?;
947-
let rows = stmt.query(params).await.map_err(Error::from)?;
994+
inner_stmt.reset();
995+
let params = map_params(&inner_stmt, params)?;
996+
let rows = inner_stmt.query(params).await.map_err(Error::from)?;
948997
let mut column_names = Vec::new();
949998
for i in 0..rows.column_count() {
950999
column_names
@@ -958,6 +1007,7 @@ pub fn statement_iterate_sync(
9581007
safe_ints,
9591008
raw,
9601009
pluck,
1010+
guard,
9611011
))
9621012
}
9631013

@@ -1104,6 +1154,7 @@ pub struct RowsIterator {
11041154
safe_ints: bool,
11051155
raw: bool,
11061156
pluck: bool,
1157+
_timeout_guard: Option<TimeoutGuard>,
11071158
}
11081159

11091160
#[napi]
@@ -1114,13 +1165,15 @@ impl RowsIterator {
11141165
safe_ints: bool,
11151166
raw: bool,
11161167
pluck: bool,
1168+
timeout_guard: Option<TimeoutGuard>,
11171169
) -> Self {
11181170
Self {
11191171
rows,
11201172
column_names,
11211173
safe_ints,
11221174
raw,
11231175
pluck,
1176+
_timeout_guard: timeout_guard,
11241177
}
11251178
}
11261179

0 commit comments

Comments
 (0)