diff --git a/Cargo.lock b/Cargo.lock index f8e40fb3c6..4cc520d8ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3227,6 +3227,7 @@ dependencies = [ "rusqlite", "sha2", "thiserror 2.0.18", + "tokio", "tracing", ] @@ -3738,7 +3739,6 @@ dependencies = [ "aws-sdk-kms", "build-rs", "clap", - "diesel", "fs-err", "hex", "miden-node-db", diff --git a/Cargo.toml b/Cargo.toml index 8cc12b5316..0f1e2ce814 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,7 +107,7 @@ rand_chacha = { default-features = false, version = "0.9" } rayon = { version = "1.10" } reqwest = { version = "0.13" } rstest = { version = "0.26" } -rusqlite = { features = ["bundled"], version = "0.37" } +rusqlite = { features = ["array", "bundled"], version = "0.37" } serde = { features = ["derive"], version = "1" } serial_test = { version = "3.2" } sha2 = { version = "0.10" } diff --git a/bin/validator/Cargo.toml b/bin/validator/Cargo.toml index 17ddb44db1..0eb4002a6a 100644 --- a/bin/validator/Cargo.toml +++ b/bin/validator/Cargo.toml @@ -22,7 +22,6 @@ anyhow = { workspace = true } aws-config = { version = "1.8.14" } aws-sdk-kms = { version = "1.100" } clap = { features = ["env", "string"], workspace = true } -diesel = { workspace = true } fs-err = { workspace = true } hex = { workspace = true } miden-node-db = { workspace = true } diff --git a/bin/validator/diesel.toml b/bin/validator/diesel.toml deleted file mode 100644 index bdce9175fa..0000000000 --- a/bin/validator/diesel.toml +++ /dev/null @@ -1,5 +0,0 @@ -# For documentation on how to configure this file, -# see https://diesel.rs/guides/configuring-diesel-cli - -[print_schema] -file = "src/db/schema.rs" diff --git a/bin/validator/src/commands/bootstrap.rs b/bin/validator/src/commands/bootstrap.rs index 30d3484f0a..37a568ffd4 100644 --- a/bin/validator/src/commands/bootstrap.rs +++ b/bin/validator/src/commands/bootstrap.rs @@ -88,8 +88,8 @@ async fn build_and_write_genesis( ) .await .context("failed to initialize validator database during bootstrap")?; - db.transact("upsert_block_header", move |conn| { - miden_validator::db::upsert_block_header(conn, &genesis_header) + db.write("upsert_block_header", move |tx| { + miden_validator::db::upsert_block_header(tx, &genesis_header) }) .await .context("failed to persist genesis block header as chain tip")?; diff --git a/bin/validator/src/db/migrations.rs b/bin/validator/src/db/migrations.rs index 062b6fe165..c72eec2795 100644 --- a/bin/validator/src/db/migrations.rs +++ b/bin/validator/src/db/migrations.rs @@ -50,9 +50,7 @@ pub fn verify_latest_schema(database_filepath: &Path) -> std::result::Result<(), #[cfg(test)] mod tests { - use std::process::Command; - - use anyhow::{Context, Result, ensure}; + use anyhow::Result; use miden_node_db::migration::{SchemaHash, SchemaHashes}; use super::*; @@ -68,34 +66,4 @@ mod tests { assert_eq!(migrator.schema_hashes(), SchemaHashes(&EXPECTED_SCHEMA_HASHES)); Ok(()) } - - #[test] - #[ignore = "requires diesel CLI; CI runs this in the diesel-schema job"] - fn diesel_schema_is_in_sync_with_migrations() -> Result<()> { - let temp_dir = tempfile::tempdir()?; - let database_filepath = temp_dir.path().join("validator.sqlite3"); - bootstrap_database(&database_filepath)?; - - let output = Command::new("diesel") - .arg("print-schema") - .arg("--database-url") - .arg(&database_filepath) - .current_dir(env!("CARGO_MANIFEST_DIR")) - .output() - .context( - "failed to run diesel CLI; install it with \ - `cargo install diesel_cli --no-default-features --features sqlite`", - )?; - - ensure!( - output.status.success(), - "diesel print-schema failed: {}", - String::from_utf8_lossy(&output.stderr) - ); - - let generated = - String::from_utf8(output.stdout).context("diesel CLI output is not UTF-8")?; - assert_eq!(generated, include_str!("schema.rs")); - Ok(()) - } } diff --git a/bin/validator/src/db/mod.rs b/bin/validator/src/db/mod.rs index a2246cfed6..0bed78006e 100644 --- a/bin/validator/src/db/mod.rs +++ b/bin/validator/src/db/mod.rs @@ -1,27 +1,22 @@ mod migrations; -mod models; -mod schema; use std::num::NonZeroUsize; use std::path::{Path, PathBuf}; -use diesel::SqliteConnection; -use diesel::dsl::{count_star, exists}; -use diesel::prelude::*; -use miden_node_db::{DatabaseError, Db, SqlTypeConvert}; +use miden_node_db::DatabaseError; +use miden_node_db::sqlite::{Database, ReadTx, WriteTx}; use miden_protocol::block::{BlockHeader, BlockNumber}; use miden_protocol::transaction::TransactionId; -use miden_protocol::utils::serde::{Deserializable, Serializable}; +use miden_protocol::utils::serde::Serializable; use tracing::instrument; use crate::COMPONENT; use crate::db::migrations::{bootstrap_database, migrate_database, verify_latest_schema}; -use crate::db::models::{BlockHeaderRowInsert, ValidatedTransactionRowInsert}; use crate::tx_validation::ValidatedTransaction; /// Open a connection to the DB after verifying that it is at the latest schema version. #[instrument(target = COMPONENT, skip_all)] -pub async fn load(database_filepath: PathBuf) -> Result { +pub async fn load(database_filepath: PathBuf) -> Result { load_with_pool_size(database_filepath, miden_node_db::default_connection_pool_size()).await } @@ -31,7 +26,7 @@ pub async fn load(database_filepath: PathBuf) -> Result { pub async fn load_with_pool_size( database_filepath: PathBuf, connection_pool_size: NonZeroUsize, -) -> Result { +) -> Result { verify_latest_schema(&database_filepath)?; open_with_pool_size(&database_filepath, connection_pool_size) @@ -39,7 +34,7 @@ pub async fn load_with_pool_size( /// Creates a new database, applies all migrations, and opens a connection pool. #[instrument(target = COMPONENT, skip_all)] -pub async fn setup(database_filepath: PathBuf) -> Result { +pub async fn setup(database_filepath: PathBuf) -> Result { setup_with_pool_size(database_filepath, miden_node_db::default_connection_pool_size()).await } @@ -48,7 +43,7 @@ pub async fn setup(database_filepath: PathBuf) -> Result { pub async fn setup_with_pool_size( database_filepath: PathBuf, connection_pool_size: NonZeroUsize, -) -> Result { +) -> Result { bootstrap_database(&database_filepath)?; open_with_pool_size(&database_filepath, connection_pool_size) @@ -64,8 +59,8 @@ pub fn migrate(database_filepath: impl AsRef) -> Result<(), DatabaseError> fn open_with_pool_size( database_filepath: &Path, connection_pool_size: NonZeroUsize, -) -> Result { - let db = Db::new_with_pool_size(database_filepath, connection_pool_size)?; +) -> Result { + let db = Database::new_with_pool_size(database_filepath, connection_pool_size)?; tracing::info!( target: COMPONENT, sqlite= %database_filepath.display(), @@ -78,15 +73,37 @@ fn open_with_pool_size( /// Inserts a new validated transaction into the database. #[instrument(target = COMPONENT, skip_all, fields(tx_id = %tx_info.tx_id()), err)] pub(crate) fn insert_transaction( - conn: &mut SqliteConnection, + tx: &WriteTx<'_>, tx_info: &ValidatedTransaction, ) -> Result { - let row = ValidatedTransactionRowInsert::new(tx_info); - let count = diesel::insert_into(schema::validated_transactions::table) - .values(row) - .on_conflict_do_nothing() - .execute(conn)?; - Ok(count) + let id = tx_info.tx_id().to_bytes(); + let block_num = i64::from(tx_info.block_num().as_u32()); + let account_id = tx_info.account_id().to_bytes(); + let account_delta = tx_info.account_delta().to_bytes(); + let input_notes = tx_info.input_notes().to_bytes(); + let output_notes = tx_info.output_notes().to_bytes(); + let initial_account_hash = tx_info.initial_account_hash().to_bytes(); + let final_account_hash = tx_info.final_account_hash().to_bytes(); + let fee = tx_info.fee().amount().as_u64().to_le_bytes().to_vec(); + + tx.execute( + "INSERT INTO validated_transactions \ + (id, block_num, account_id, account_delta, input_notes, output_notes, \ + initial_account_hash, final_account_hash, fee) \ + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) \ + ON CONFLICT DO NOTHING", + &[ + &id, + &block_num, + &account_id, + &account_delta, + &input_notes, + &output_notes, + &initial_account_hash, + &final_account_hash, + &fee, + ], + ) } /// Scans the database for transaction Ids that do not exist. @@ -102,19 +119,24 @@ pub(crate) fn insert_transaction( /// WHERE id = ? /// ); /// ``` -#[instrument(target = COMPONENT, skip(conn), err)] +#[instrument(target = COMPONENT, skip(tx), err)] pub(crate) fn find_unvalidated_transactions( - conn: &mut SqliteConnection, + tx: &ReadTx<'_>, tx_ids: &[TransactionId], ) -> Result, DatabaseError> { let mut unvalidated_tx_ids = Vec::new(); for tx_id in tx_ids { // Check whether each transaction id exists in the database. - let exists = diesel::select(exists( - schema::validated_transactions::table - .filter(schema::validated_transactions::id.eq(tx_id.to_bytes())), - )) - .get_result::(conn)?; + let exists = tx + .query( + "SELECT EXISTS(SELECT 1 FROM validated_transactions WHERE id = ?1)", + &[&tx_id.to_bytes()], + |row| row.get::(0), + )? + .first() + .copied() + .unwrap_or(0) + != 0; // Record any transaction ids that do not exist. if !exists { unvalidated_tx_ids.push(*tx_id); @@ -127,70 +149,68 @@ pub(crate) fn find_unvalidated_transactions( /// /// Inserts a new row if no block header exists at the given block number, or replaces the /// existing block header if one already exists. -#[instrument(target = COMPONENT, skip(conn, header), err)] -pub fn upsert_block_header( - conn: &mut SqliteConnection, - header: &BlockHeader, -) -> Result<(), DatabaseError> { - let row = BlockHeaderRowInsert { - block_num: header.block_num().to_raw_sql(), - block_header: header.to_bytes(), - }; - diesel::replace_into(schema::block_headers::table).values(row).execute(conn)?; +#[instrument(target = COMPONENT, skip(tx, header), err)] +pub fn upsert_block_header(tx: &WriteTx<'_>, header: &BlockHeader) -> Result<(), DatabaseError> { + let block_num = i64::from(header.block_num().as_u32()); + let block_header = header.to_bytes(); + tx.execute( + "REPLACE INTO block_headers (block_num, block_header) VALUES (?1, ?2)", + &[&block_num, &block_header], + )?; Ok(()) } /// Loads the chain tip (block header with the highest block number) from the database. /// /// Returns `None` if no block headers have been persisted (i.e. bootstrap has not been run). -#[instrument(target = COMPONENT, skip(conn), err)] -pub fn load_chain_tip(conn: &mut SqliteConnection) -> Result, DatabaseError> { - let row = schema::block_headers::table - .order(schema::block_headers::block_num.desc()) - .select(schema::block_headers::block_header) - .first::>(conn) - .optional()?; - - row.map(|bytes| { - BlockHeader::read_from_bytes(&bytes) - .map_err(|err| DatabaseError::deserialization("BlockHeader", err)) - }) - .transpose() +#[instrument(target = COMPONENT, skip(tx), err)] +pub fn load_chain_tip(tx: &ReadTx<'_>) -> Result, DatabaseError> { + Ok(tx + .query( + "SELECT block_header FROM block_headers ORDER BY block_num DESC LIMIT 1", + &[], + |row| row.get::(0), + )? + .into_iter() + .next()) } /// Loads a block header by its block number. /// /// Returns `None` if no block header exists at the given block number. -#[instrument(target = COMPONENT, skip(conn), err)] +#[instrument(target = COMPONENT, skip(tx), err)] pub fn load_block_header( - conn: &mut SqliteConnection, + tx: &ReadTx<'_>, block_num: BlockNumber, ) -> Result, DatabaseError> { - let row = schema::block_headers::table - .filter(schema::block_headers::block_num.eq(block_num.to_raw_sql())) - .select(schema::block_headers::block_header) - .first::>(conn) - .optional()?; - - row.map(|bytes| { - BlockHeader::read_from_bytes(&bytes) - .map_err(|err| DatabaseError::deserialization("BlockHeader", err)) - }) - .transpose() + Ok(tx + .query( + "SELECT block_header FROM block_headers WHERE block_num = ?1", + &[&i64::from(block_num.as_u32())], + |row| row.get::(0), + )? + .into_iter() + .next()) } /// Returns the total number of validated transactions in the database. -#[instrument(target = COMPONENT, skip(conn), err)] -pub fn count_validated_transactions(conn: &mut SqliteConnection) -> Result { - let count = schema::validated_transactions::table.select(count_star()).first::(conn)?; - Ok(count) +#[instrument(target = COMPONENT, skip(tx), err)] +pub fn count_validated_transactions(tx: &ReadTx<'_>) -> Result { + Ok(tx + .query("SELECT COUNT(*) FROM validated_transactions", &[], |row| row.get::(0))? + .into_iter() + .next() + .unwrap_or(0)) } /// Returns the total number of signed blocks in the database. -#[instrument(target = COMPONENT, skip(conn), err)] -pub fn count_signed_blocks(conn: &mut SqliteConnection) -> Result { - let count = schema::block_headers::table.select(count_star()).first::(conn)?; - Ok(count) +#[instrument(target = COMPONENT, skip(tx), err)] +pub fn count_signed_blocks(tx: &ReadTx<'_>) -> Result { + Ok(tx + .query("SELECT COUNT(*) FROM block_headers", &[], |row| row.get::(0))? + .into_iter() + .next() + .unwrap_or(0)) } #[cfg(test)] diff --git a/bin/validator/src/db/models.rs b/bin/validator/src/db/models.rs deleted file mode 100644 index 85f1e7354d..0000000000 --- a/bin/validator/src/db/models.rs +++ /dev/null @@ -1,45 +0,0 @@ -use diesel::prelude::*; -use miden_node_db::SqlTypeConvert; -use miden_protocol::utils::serde::Serializable; - -use crate::db::schema; -use crate::tx_validation::ValidatedTransaction; - -#[derive(Debug, Clone, Insertable)] -#[diesel(table_name = schema::block_headers)] -#[diesel(check_for_backend(diesel::sqlite::Sqlite))] -pub struct BlockHeaderRowInsert { - pub block_num: i64, - pub block_header: Vec, -} - -#[derive(Debug, Clone, PartialEq, Insertable)] -#[diesel(table_name = schema::validated_transactions)] -#[diesel(check_for_backend(diesel::sqlite::Sqlite))] -pub struct ValidatedTransactionRowInsert { - pub id: Vec, - pub block_num: i64, - pub account_id: Vec, - pub account_delta: Vec, - pub input_notes: Vec, - pub output_notes: Vec, - pub initial_account_hash: Vec, - pub final_account_hash: Vec, - pub fee: Vec, -} - -impl ValidatedTransactionRowInsert { - pub fn new(tx: &ValidatedTransaction) -> Self { - Self { - id: tx.tx_id().to_bytes(), - block_num: tx.block_num().to_raw_sql(), - account_id: tx.account_id().to_bytes(), - account_delta: tx.account_delta().to_bytes(), - input_notes: tx.input_notes().to_bytes(), - output_notes: tx.output_notes().to_bytes(), - initial_account_hash: tx.initial_account_hash().to_bytes(), - final_account_hash: tx.final_account_hash().to_bytes(), - fee: tx.fee().amount().as_u64().to_le_bytes().to_vec(), - } - } -} diff --git a/bin/validator/src/db/schema.rs b/bin/validator/src/db/schema.rs deleted file mode 100644 index e78833e43b..0000000000 --- a/bin/validator/src/db/schema.rs +++ /dev/null @@ -1,24 +0,0 @@ -// @generated automatically by Diesel CLI. - -diesel::table! { - block_headers (block_num) { - block_num -> BigInt, - block_header -> Binary, - } -} - -diesel::table! { - validated_transactions (id) { - id -> Binary, - block_num -> BigInt, - account_id -> Binary, - account_delta -> Binary, - input_notes -> Nullable, - output_notes -> Nullable, - initial_account_hash -> Binary, - final_account_hash -> Binary, - fee -> Binary, - } -} - -diesel::allow_tables_to_appear_in_same_query!(block_headers, validated_transactions,); diff --git a/bin/validator/src/server/mod.rs b/bin/validator/src/server/mod.rs index 030d2cd3f5..6383d0a13c 100644 --- a/bin/validator/src/server/mod.rs +++ b/bin/validator/src/server/mod.rs @@ -71,10 +71,10 @@ impl ValidatorServer { // Load initial metrics from the database for the in-memory counters. let (initial_chain_tip, initial_tx_count, initial_block_count) = db - .query("load_initial_metrics", |conn| { - let tip = load_chain_tip(conn)?.map_or(0, |h| h.block_num().as_u32()); - let tx_count = u64::try_from(count_validated_transactions(conn)?).unwrap_or(0); - let block_count = u64::try_from(count_signed_blocks(conn)?).unwrap_or(0); + .read("load_initial_metrics", |tx| { + let tip = load_chain_tip(tx)?.map_or(0, |h| h.block_num().as_u32()); + let tx_count = u64::try_from(count_validated_transactions(tx)?).unwrap_or(0); + let block_count = u64::try_from(count_signed_blocks(tx)?).unwrap_or(0); Ok::<_, miden_node_db::DatabaseError>((tip, tx_count, block_count)) }) .await diff --git a/bin/validator/src/server/validator_service/mod.rs b/bin/validator/src/server/validator_service/mod.rs index b88a3a9533..789e07c954 100644 --- a/bin/validator/src/server/validator_service/mod.rs +++ b/bin/validator/src/server/validator_service/mod.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use std::sync::atomic::AtomicU64; -use miden_node_db::{DatabaseError, Db}; +use miden_node_db::DatabaseError; +use miden_node_db::sqlite::Database; use miden_node_store::BlockStore; use miden_node_utils::tracing::OpenTelemetrySpanExt; use miden_protocol::block::{BlockHeader, BlockNumber, ProposedBlock, SignedBlock}; @@ -63,7 +64,7 @@ pub enum ValidatorError { /// Implements the gRPC API for the validator. pub(crate) struct ValidatorService { signer: ValidatorSigner, - db: Arc, + db: Arc, block_store: BlockStore, /// Serializes `sign_block` requests so that concurrent calls are processed sequentially, /// ensuring consistent chain tip reads and preventing race conditions. @@ -80,7 +81,7 @@ pub(crate) struct ValidatorService { impl ValidatorService { pub(crate) async fn new( signer: ValidatorSigner, - db: Db, + db: Database, block_store: BlockStore, initial_chain_tip: u32, initial_tx_count: u64, @@ -90,7 +91,7 @@ impl ValidatorService { // the signing key must match the chain's validator key for this validator's lifetime. // Reject a misconfigured key here. let chain_tip = db - .query("load_chain_tip", load_chain_tip) + .read("load_chain_tip", load_chain_tip) .await .map_err(ValidatorError::DatabaseError)? .ok_or(ValidatorError::NoChainTip)?; @@ -132,8 +133,8 @@ impl ValidatorService { proposed_block.transactions().map(TransactionHeader::id).collect::>(); let unvalidated_txs = self .db - .transact("find_unvalidated_transactions", move |conn| { - find_unvalidated_transactions(conn, &proposed_tx_ids) + .read("find_unvalidated_transactions", move |tx| { + find_unvalidated_transactions(tx, &proposed_tx_ids) }) .await .map_err(ValidatorError::DatabaseError)?; @@ -159,7 +160,7 @@ impl ValidatorService { let prev_block_num = chain_tip.block_num().parent().ok_or(ValidatorError::NoPrevBlockHeader)?; self.db - .query("load_block_header", move |conn| load_block_header(conn, prev_block_num)) + .read("load_block_header", move |tx| load_block_header(tx, prev_block_num)) .await .map_err(ValidatorError::DatabaseError)? .ok_or(ValidatorError::NoPrevBlockHeader)? diff --git a/bin/validator/src/server/validator_service/sign_block.rs b/bin/validator/src/server/validator_service/sign_block.rs index 699c3400ad..7152096e06 100644 --- a/bin/validator/src/server/validator_service/sign_block.rs +++ b/bin/validator/src/server/validator_service/sign_block.rs @@ -41,7 +41,7 @@ impl grpc::server::validator_api::SignBlock for ValidatorService { // Load the current chain tip from the database. let chain_tip = self .db - .query("load_chain_tip", load_chain_tip) + .read("load_chain_tip", load_chain_tip) .await .map_err(|err| { tonic::Status::internal(format!("Failed to load chain tip: {}", err.as_report())) @@ -64,7 +64,7 @@ impl grpc::server::validator_api::SignBlock for ValidatorService { // Persist the validated block header. let new_block_num = header.block_num().as_u32(); self.db - .transact("upsert_block_header", move |conn| upsert_block_header(conn, &header)) + .write("upsert_block_header", move |tx| upsert_block_header(tx, &header)) .await .map_err(|err| { tonic::Status::internal(format!( diff --git a/bin/validator/src/server/validator_service/submit_proven_transaction.rs b/bin/validator/src/server/validator_service/submit_proven_transaction.rs index d981fd90a6..e865c8f222 100644 --- a/bin/validator/src/server/validator_service/submit_proven_transaction.rs +++ b/bin/validator/src/server/validator_service/submit_proven_transaction.rs @@ -27,7 +27,7 @@ impl grpc::server::validator_api::SubmitProvenTransaction for ValidatorService { // Store the validated transaction. let count = self .db - .transact("insert_transaction", move |conn| insert_transaction(conn, &tx_info)) + .write("insert_transaction", move |tx| insert_transaction(tx, &tx_info)) .await .map_err(|err| { Status::internal(err.as_report_context("Failed to insert transaction")) diff --git a/bin/validator/src/server/validator_service/tests.rs b/bin/validator/src/server/validator_service/tests.rs index 80ecb28210..a35cb868f9 100644 --- a/bin/validator/src/server/validator_service/tests.rs +++ b/bin/validator/src/server/validator_service/tests.rs @@ -24,6 +24,9 @@ struct TestValidator { server: ValidatorService, chain: PartialBlockchain, chain_tip: BlockHeader, + // Keeps the database's temp directory alive for the validator's lifetime: the reader pool opens + // connections lazily, so the file must still exist when the first read runs. + _temp_dir: tempfile::TempDir, } impl TestValidator { @@ -32,12 +35,13 @@ impl TestValidator { async fn new() -> Self { let key = random_secret_key(); let signer = ValidatorSigner::new_local(key.clone()); - let (db, block_store, genesis_header) = setup_db_with_genesis(&key).await; + let (temp_dir, db, block_store, genesis_header) = setup_db_with_genesis(&key).await; Self { server: ValidatorService::new(signer, db, block_store, 0, 0, 0).await.unwrap(), chain: PartialBlockchain::default(), chain_tip: genesis_header, + _temp_dir: temp_dir, } } @@ -73,7 +77,7 @@ impl TestValidator { async fn load_chain_tip(&self) -> BlockHeader { self.server .db - .query("load_chain_tip", load_chain_tip) + .read("load_chain_tip", load_chain_tip) .await .unwrap() .expect("chain tip should exist") @@ -94,7 +98,9 @@ impl TestValidator { /// Creates a validator database seeded with a genesis block whose `validator_key` is the public key /// of `key`. Returns the database handle and the genesis block header. -async fn setup_db_with_genesis(key: &SigningKey) -> (miden_node_db::Db, BlockStore, BlockHeader) { +async fn setup_db_with_genesis( + key: &SigningKey, +) -> (tempfile::TempDir, miden_node_db::sqlite::Database, BlockStore, BlockHeader) { let genesis_state = GenesisState::new(vec![], test_fee_params(), 1, 0, key.public_key()); let genesis_block = genesis_state.into_block(key).unwrap(); let genesis_header = genesis_block.inner().header().clone(); @@ -104,14 +110,14 @@ async fn setup_db_with_genesis(key: &SigningKey) -> (miden_node_db::Db, BlockSto let block_store = BlockStore::bootstrap(dir.path().join("blocks").clone(), &genesis_block).unwrap(); - db.transact("upsert_genesis", { + db.write("upsert_genesis", { let h = genesis_header.clone(); - move |conn| upsert_block_header(conn, &h) + move |tx| upsert_block_header(tx, &h) }) .await .unwrap(); - (db, block_store, genesis_header) + (dir, db, block_store, genesis_header) } /// Builds an empty [`ProposedBlock`] that extends the given parent block header using the provided @@ -137,7 +143,7 @@ fn empty_block(parent_header: &BlockHeader, chain: &PartialBlockchain) -> Propos async fn signing_key_mismatch_rejected() { // Seed a database whose genesis designates `genesis_key` as the validator key. let genesis_key = random_secret_key(); - let (db, block_store, genesis_header) = setup_db_with_genesis(&genesis_key).await; + let (_temp_dir, db, block_store, genesis_header) = setup_db_with_genesis(&genesis_key).await; // Start a validator with a different key, modelling a validator configured with the wrong key. let rogue_signer = ValidatorSigner::new_local(random_secret_key()); diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index e3908822bf..487930ef4d 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -29,5 +29,8 @@ sha2 = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } +[dev-dependencies] +tokio = { features = ["macros", "rt-multi-thread"], workspace = true } + [lib] doctest = false diff --git a/crates/db/src/errors.rs b/crates/db/src/errors.rs index 3737f50693..955d4a1151 100644 --- a/crates/db/src/errors.rs +++ b/crates/db/src/errors.rs @@ -45,6 +45,8 @@ pub enum DatabaseError { }, #[error(transparent)] Diesel(#[from] diesel::result::Error), + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), #[error("failed to apply database migrations")] Migration(#[source] Box), #[error("schema verification failed")] diff --git a/crates/db/src/lib.rs b/crates/db/src/lib.rs index 2da3b5eb6f..29c2af3bbb 100644 --- a/crates/db/src/lib.rs +++ b/crates/db/src/lib.rs @@ -2,6 +2,7 @@ mod conv; mod errors; mod manager; pub mod migration; +pub mod sqlite; use std::num::NonZeroUsize; use std::path::Path; diff --git a/crates/db/src/sqlite/codec.rs b/crates/db/src/sqlite/codec.rs new file mode 100644 index 0000000000..9c2c4ac224 --- /dev/null +++ b/crates/db/src/sqlite/codec.rs @@ -0,0 +1,250 @@ +//! Column codec for the rusqlite-based SQLite framework. +//! +//! [`ToSqlValue`] and [`FromSqlValue`] are the per-column write/read codec for our domain types. +//! They operate on [`DbValue`]/[`DbValueRef`], thin wrappers over rusqlite's value types, so that +//! crates implementing a codec for their own types never have to name `rusqlite` directly. +//! +//! Most node types are stored as a BLOB via their `Serializable`/`Deserializable` impls; the +//! [`impl_blob_codec!`](crate::impl_blob_codec) macro generates both traits for such a type. Scalar +//! types map onto an SQLite `INTEGER`/`TEXT` and implement the traits directly (see the impls ported +//! from the legacy `SqlTypeConvert` below). + +use std::rc::Rc; + +use rusqlite::ToSql; +use rusqlite::types::{ToSqlOutput, Value, ValueRef}; + +use crate::DatabaseError; + +// DB VALUE WRAPPERS +// ================================================================================================= + +/// An owned SQL value produced when binding a Rust value as a query parameter. +/// +/// Wraps `rusqlite`'s value types so codec implementors never name `rusqlite`. A value is either a +/// single column value or a list bound for a `rarray(?)` table-valued parameter (used by the +/// cacheable IN-list helpers in [`in_list`](crate::sqlite::in_list)). +#[derive(Debug, Clone)] +pub enum DbValue { + /// A single SQL column value. + Single(Value), + /// A list of values bound via rusqlite's `array` extension for use with `rarray(?)`. + Array(Rc>), +} + +impl DbValue { + /// Builds an `INTEGER` value. + pub fn integer(value: i64) -> Self { + Self::Single(Value::Integer(value)) + } + + /// Builds a `REAL` value. + pub fn real(value: f64) -> Self { + Self::Single(Value::Real(value)) + } + + /// Builds a `TEXT` value. + pub fn text(value: String) -> Self { + Self::Single(Value::Text(value)) + } + + /// Builds a `BLOB` value. + pub fn blob(value: Vec) -> Self { + Self::Single(Value::Blob(value)) + } + + /// Builds a `NULL` value. + pub fn null() -> Self { + Self::Single(Value::Null) + } + + /// Builds a list value bound for a `rarray(?)` table-valued parameter. + pub(crate) fn array(values: Vec) -> Self { + Self::Array(Rc::new(values)) + } +} + +impl ToSql for DbValue { + fn to_sql(&self) -> rusqlite::Result> { + match self { + Self::Single(value) => value.to_sql(), + Self::Array(values) => values.to_sql(), + } + } +} + +/// A borrowed SQL value handed to [`FromSqlValue`] when reading a column. +/// +/// Wraps `rusqlite::types::ValueRef` so codec implementors never name `rusqlite`. +#[derive(Debug, Clone, Copy)] +pub struct DbValueRef<'a>(ValueRef<'a>); + +impl<'a> DbValueRef<'a> { + pub(crate) fn new(value: ValueRef<'a>) -> Self { + Self(value) + } + + /// Reads the value as an `i64`. + pub fn as_i64(self) -> Result { + self.0.as_i64().map_err(|err| DatabaseError::deserialization("i64", err)) + } + + /// Reads the value as a borrowed BLOB. + pub fn as_blob(self) -> Result<&'a [u8], DatabaseError> { + self.0.as_blob().map_err(|err| DatabaseError::deserialization("blob", err)) + } + + /// Reads the value as a borrowed string. + pub fn as_str(self) -> Result<&'a str, DatabaseError> { + self.0.as_str().map_err(|err| DatabaseError::deserialization("str", err)) + } + + /// Returns `true` if the value is SQL `NULL`. + pub fn is_null(self) -> bool { + matches!(self.0, ValueRef::Null) + } +} + +// CODEC TRAITS +// ================================================================================================= + +/// Converts a Rust value into its SQL parameter representation (the write side of the codec). +pub trait ToSqlValue { + /// Returns the SQL value bound for this Rust value. + fn to_sql_value(&self) -> DbValue; +} + +/// Builds a Rust value from a SQL column value (the read side of the codec). +pub trait FromSqlValue: Sized { + /// Reads `Self` from a SQL column value. + fn from_sql_value(value: DbValueRef<'_>) -> Result; +} + +// Forward `ToSqlValue` through references so callers can pass `&value` in a parameter slice. +impl ToSqlValue for &T { + fn to_sql_value(&self) -> DbValue { + (**self).to_sql_value() + } +} + +// PRIMITIVE IMPLS +// ================================================================================================= + +impl ToSqlValue for i64 { + fn to_sql_value(&self) -> DbValue { + DbValue::integer(*self) + } +} + +impl FromSqlValue for i64 { + fn from_sql_value(value: DbValueRef<'_>) -> Result { + value.as_i64() + } +} + +impl ToSqlValue for bool { + fn to_sql_value(&self) -> DbValue { + DbValue::integer(i64::from(*self)) + } +} + +impl FromSqlValue for bool { + fn from_sql_value(value: DbValueRef<'_>) -> Result { + Ok(value.as_i64()? != 0) + } +} + +impl ToSqlValue for Vec { + fn to_sql_value(&self) -> DbValue { + DbValue::blob(self.clone()) + } +} + +impl FromSqlValue for Vec { + fn from_sql_value(value: DbValueRef<'_>) -> Result { + Ok(value.as_blob()?.to_vec()) + } +} + +impl ToSqlValue for str { + fn to_sql_value(&self) -> DbValue { + DbValue::text(self.to_owned()) + } +} + +impl ToSqlValue for String { + fn to_sql_value(&self) -> DbValue { + DbValue::text(self.clone()) + } +} + +impl FromSqlValue for String { + fn from_sql_value(value: DbValueRef<'_>) -> Result { + Ok(value.as_str()?.to_owned()) + } +} + +impl ToSqlValue for Option { + fn to_sql_value(&self) -> DbValue { + match self { + Some(value) => value.to_sql_value(), + None => DbValue::null(), + } + } +} + +impl FromSqlValue for Option { + fn from_sql_value(value: DbValueRef<'_>) -> Result { + if value.is_null() { + Ok(None) + } else { + Ok(Some(T::from_sql_value(value)?)) + } + } +} + +// BLOB CODEC MACRO +// ================================================================================================= + +/// Generates [`ToSqlValue`](crate::sqlite::ToSqlValue) and +/// [`FromSqlValue`](crate::sqlite::FromSqlValue) for types stored as a BLOB via their +/// `Serializable`/`Deserializable` impls. +/// +/// The generated impls call the exact same `to_bytes()`/`read_from_bytes()` used elsewhere, so the +/// on-disk byte layout is unchanged. +#[macro_export] +macro_rules! impl_blob_codec { + ($($t:ty),+ $(,)?) => { + $( + impl $crate::sqlite::ToSqlValue for $t { + fn to_sql_value(&self) -> $crate::sqlite::DbValue { + $crate::sqlite::DbValue::blob( + ::miden_protocol::utils::serde::Serializable::to_bytes(self), + ) + } + } + + impl $crate::sqlite::FromSqlValue for $t { + fn from_sql_value( + value: $crate::sqlite::DbValueRef<'_>, + ) -> ::core::result::Result { + let bytes = value.as_blob()?; + <$t as ::miden_protocol::utils::serde::Deserializable>::read_from_bytes(bytes) + .map_err(|err| { + $crate::DatabaseError::deserialization(::core::stringify!($t), err) + }) + } + } + )+ + }; +} + +// Codec for the common protocol types stored as BLOBs. Shared by all node crates so that the orphan +// rule does not force each consumer to redeclare them. +impl_blob_codec!( + miden_protocol::block::BlockHeader, + miden_protocol::account::AccountId, + miden_protocol::transaction::TransactionId, + miden_protocol::note::Nullifier, + miden_protocol::Word, +); diff --git a/crates/db/src/sqlite/in_list.rs b/crates/db/src/sqlite/in_list.rs new file mode 100644 index 0000000000..67e4fd5386 --- /dev/null +++ b/crates/db/src/sqlite/in_list.rs @@ -0,0 +1,65 @@ +//! Variable-length `IN (...)` lists that keep the SQL text constant. +//! +//! Binding a list as `IN (?, ?, ...)` produces a different SQL string per list length, so SQLite +//! cannot cache the prepared statement. Instead, bind the list as a single array parameter via +//! rusqlite's [`array`](https://docs.rs/rusqlite/latest/rusqlite/vtab/array/index.html) extension +//! and expand it with `rarray`, keeping the SQL text constant and the comparison on the raw column +//! (so an index on the column can be used): +//! +//! ```sql +//! ... WHERE col IN (SELECT value FROM rarray(?1)) +//! ``` +//! +//! The same idiom works for both integer and BLOB keys: the values are bound natively, so there is +//! no per-row `hex()`/`unhex()` conversion and no JSON serialization. + +use rusqlite::types::Value; + +use crate::sqlite::codec::{DbValue, ToSqlValue}; + +/// A list bound as an array parameter for use with `rarray`. +#[derive(Debug, Clone, PartialEq)] +pub struct InList(Vec); + +impl ToSqlValue for InList { + fn to_sql_value(&self) -> DbValue { + DbValue::array(self.0.clone()) + } +} + +/// Builds an integer-keyed `IN` list. Pair with `... IN (SELECT value FROM rarray(?))`. +pub fn in_list_i64(items: impl IntoIterator) -> InList { + InList(items.into_iter().map(Value::Integer).collect()) +} + +/// Builds a BLOB-keyed `IN` list. Pair with `... IN (SELECT value FROM rarray(?))`; the column is +/// compared directly against the bound blobs, with no hex conversion. +pub fn in_list_blob<'a>(items: impl IntoIterator) -> InList { + InList(items.into_iter().map(|bytes| Value::Blob(bytes.to_vec())).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn in_list_i64_collects_integer_values() { + // Different list lengths produce the same SQL template (`rarray(?1)`); only the bound + // parameter contents differ. + assert_eq!(in_list_i64([1]).0, vec![Value::Integer(1)]); + assert_eq!( + in_list_i64([1, 2, 3]).0, + vec![Value::Integer(1), Value::Integer(2), Value::Integer(3)] + ); + assert_eq!(in_list_i64(std::iter::empty()).0, Vec::::new()); + } + + #[test] + fn in_list_blob_collects_blob_values() { + assert_eq!(in_list_blob([[0x0a, 0xff].as_slice()]).0, vec![Value::Blob(vec![0x0a, 0xff])]); + assert_eq!( + in_list_blob([[0x01].as_slice(), [0x02].as_slice()]).0, + vec![Value::Blob(vec![0x01]), Value::Blob(vec![0x02])] + ); + } +} diff --git a/crates/db/src/sqlite/mod.rs b/crates/db/src/sqlite/mod.rs new file mode 100644 index 0000000000..91c66ba418 --- /dev/null +++ b/crates/db/src/sqlite/mod.rs @@ -0,0 +1,11 @@ +//! A thin, additive SQLite framework over raw `rusqlite`. + +mod codec; +mod in_list; +mod pool; +mod tx; + +pub use codec::{DbValue, DbValueRef, FromSqlValue, ToSqlValue}; +pub use in_list::{InList, in_list_blob, in_list_i64}; +pub use pool::{Database, ReadTransaction, WriteTransaction}; +pub use tx::{ReadTx, Row, WriteTx}; diff --git a/crates/db/src/sqlite/pool.rs b/crates/db/src/sqlite/pool.rs new file mode 100644 index 0000000000..bc57e47bb1 --- /dev/null +++ b/crates/db/src/sqlite/pool.rs @@ -0,0 +1,506 @@ +//! Async connection pool over raw `rusqlite`. +//! +//! SQLite permits only a single writer at a time, so the pool is split into a **single** writer +//! connection and a pool of read-only connections. Writes (`write`/`begin_write`) serialize on the +//! one writer; reads (`read`/`begin_read`) run concurrently on the reader pool. This makes the +//! single-writer model structural (rather than relying on lock contention) and lets a held write +//! transaction stay open without starving readers. + +use std::num::NonZeroUsize; +use std::path::{Path, PathBuf}; + +use deadpool::Runtime; +use deadpool::managed::{Manager, Metrics, Object, Pool, RecycleError, RecycleResult}; +use deadpool_sync::SyncWrapper; +use rusqlite::{Connection, OpenFlags, TransactionBehavior}; +use tracing::Instrument; + +use crate::sqlite::tx::{ReadTx, WriteTx}; +use crate::{DatabaseError, default_connection_pool_size}; + +/// Per-connection prepared-statement cache capacity. Raised well above rusqlite's default of 16 +/// because we keep a large set of distinct statements; the bounded connection pools cap total +/// cached-statement memory. +const STATEMENT_CACHE_CAPACITY: usize = 512; + +// CONNECTION MANAGER +// ================================================================================================= + +/// Errors raised while creating or recycling a pooled connection. +/// +/// Internal to the pool: callers only ever observe a [`DatabaseError`] (pool failures are boxed into +/// [`DatabaseError::ConnectionPoolObtainError`]), so this type is not part of the public API. +#[derive(Debug, thiserror::Error)] +pub(crate) enum SqliteManagerError { + /// Opening the database file failed. + #[error("failed to open the sqlite database")] + Open(#[source] rusqlite::Error), + /// Applying the per-connection PRAGMAs failed. + #[error("failed to configure the sqlite connection")] + Configure(#[source] rusqlite::Error), + /// The pooled connection's mutex was poisoned by a panic during a previous interaction. + #[error("the pooled sqlite connection is poisoned")] + Poisoned, +} + +struct SqliteManager { + path: PathBuf, + /// When set, connections are configured `PRAGMA query_only = ON` and skip the writer-only + /// `journal_mode` setup — used for the reader pool. + query_only: bool, +} + +impl Manager for SqliteManager { + type Type = SyncWrapper; + type Error = SqliteManagerError; + + async fn create(&self) -> Result { + let path = self.path.clone(); + let query_only = self.query_only; + SyncWrapper::new(Runtime::Tokio1, move || { + let conn = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE) + .map_err(SqliteManagerError::Open)?; + configure_connection(&conn, query_only).map_err(SqliteManagerError::Configure)?; + Ok(conn) + }) + .await + } + + async fn recycle( + &self, + conn: &mut Self::Type, + _metrics: &Metrics, + ) -> RecycleResult { + if conn.is_mutex_poisoned() { + return Err(RecycleError::Backend(SqliteManagerError::Poisoned)); + } + // Safety net for a held transaction handle dropped without `commit`/`rollback`: roll back + // any still-open transaction so the next user gets a clean connection. + conn.interact(|conn| { + if !conn.is_autocommit() { + let _ = conn.execute_batch("ROLLBACK"); + } + }) + .await + .map_err(|_| RecycleError::Backend(SqliteManagerError::Poisoned))?; + Ok(()) + } +} + +/// Applies the per-connection PRAGMAs and statement-cache sizing. +/// +/// Both pools open the file `READ_WRITE`; reader connections are made read-only at runtime with +/// `PRAGMA query_only = ON` (which, unlike opening `READ_ONLY`, still lets them create the WAL +/// `-shm` file and read a WAL database). +fn configure_connection(conn: &Connection, query_only: bool) -> rusqlite::Result<()> { + // busy_timeout makes concurrent writers wait instead of failing immediately; foreign keys + // enforce referential integrity. + if query_only { + // A query_only connection cannot set `journal_mode` (it is a write); WAL is already + // persisted in the file header by the writer / migration path. + conn.execute_batch( + "PRAGMA busy_timeout = 5000; + PRAGMA foreign_keys = ON; + PRAGMA query_only = ON;", + )?; + } else { + // WAL allows concurrent readers while the writer holds the lock. + conn.execute_batch( + "PRAGMA busy_timeout = 5000; + PRAGMA journal_mode = WAL; + PRAGMA foreign_keys = ON;", + )?; + } + conn.set_prepared_statement_cache_capacity(STATEMENT_CACHE_CAPACITY); + // Register the `array` extension so the cacheable IN-list helpers can bind lists via + // `rarray(?)` (see `crate::sqlite::in_list`). + rusqlite::vtab::array::load_module(conn)?; + Ok(()) +} + +// DATABASE +// ================================================================================================= + +/// A rusqlite-backed connection pool. Cloning shares the underlying pools. +/// +/// Holds a single writer connection and a pool of reader connections (see the module docs). +#[derive(Clone)] +pub struct Database { + writer: Pool, + readers: Pool, +} + +impl Database { + /// Opens a database over `database_filepath` with the default reader-pool size. + pub fn new(database_filepath: &Path) -> Result { + Self::new_with_pool_size(database_filepath, default_connection_pool_size()) + } + + /// Opens a database over `database_filepath` with the given reader-pool size. The writer is + /// always a single connection. + pub fn new_with_pool_size( + database_filepath: &Path, + connection_pool_size: NonZeroUsize, + ) -> Result { + let writer = Pool::builder(SqliteManager { + path: database_filepath.to_path_buf(), + query_only: false, + }) + .max_size(1) + .build()?; + let readers = Pool::builder(SqliteManager { + path: database_filepath.to_path_buf(), + query_only: true, + }) + .max_size(connection_pool_size.get()) + .build()?; + Ok(Self { writer, readers }) + } + + /// Checks the single writer connection out of the pool. + async fn checkout_writer(&self) -> Result, DatabaseError> { + self.writer + .get() + .in_current_span() + .await + .map_err(|err| DatabaseError::ConnectionPoolObtainError(Box::new(err))) + } + + /// Checks a reader connection out of the pool. + async fn checkout_reader(&self) -> Result, DatabaseError> { + self.readers + .get() + .in_current_span() + .await + .map_err(|err| DatabaseError::ConnectionPoolObtainError(Box::new(err))) + } + + /// Runs `query` inside a read-only (`DEFERRED`, never committed) transaction on a reader + /// connection. + pub async fn read(&self, msg: impl ToString + Send, query: F) -> Result + where + F: FnOnce(&ReadTx<'_>) -> Result + Send + 'static, + R: Send + 'static, + E: From + Send + 'static, + { + let conn = self.checkout_reader().await.map_err(E::from)?; + let msg = msg.to_string(); + let span = tracing::Span::current(); + conn.interact(move |conn| { + let _guard = span.enter(); + let tx = conn + .transaction_with_behavior(TransactionBehavior::Deferred) + .map_err(|err| E::from(DatabaseError::from(err)))?; + query(&ReadTx::new(&tx)) + // `tx` is dropped here without a commit, rolling back any writes. + }) + .await + .map_err(|err| E::from(DatabaseError::interact(&msg, &err)))? + } + + /// Runs `query` inside a read-write (`IMMEDIATE`) transaction on the single writer connection, + /// committing on `Ok`. + pub async fn write(&self, msg: impl ToString + Send, query: F) -> Result + where + F: FnOnce(&WriteTx<'_>) -> Result + Send + 'static, + R: Send + 'static, + E: From + Send + 'static, + { + let conn = self.checkout_writer().await.map_err(E::from)?; + let msg = msg.to_string(); + let span = tracing::Span::current(); + conn.interact(move |conn| { + let _guard = span.enter(); + let tx = conn + .transaction_with_behavior(TransactionBehavior::Immediate) + .map_err(|err| E::from(DatabaseError::from(err)))?; + let result = query(&WriteTx::new(&tx))?; + tx.commit().map_err(|err| E::from(DatabaseError::from(err)))?; + Ok(result) + }) + .await + .map_err(|err| E::from(DatabaseError::interact(&msg, &err)))? + } + + /// Begins a read-only (`DEFERRED`) transaction on a reader connection and returns a handle held + /// across `.await` points. See [`ReadTransaction`]. + pub async fn begin_read(&self) -> Result { + let conn = self.checkout_reader().await?; + run_tx_stmt(&conn, "BEGIN DEFERRED").await?; + Ok(ReadTransaction { conn }) + } + + /// Begins a read-write (`IMMEDIATE`) transaction on the single writer connection and returns a + /// handle held across `.await` points. The handle must be committed (or it rolls back). See + /// [`WriteTransaction`]. + pub async fn begin_write(&self) -> Result { + let conn = self.checkout_writer().await?; + run_tx_stmt(&conn, "BEGIN IMMEDIATE").await?; + Ok(WriteTransaction { conn }) + } +} + +// HELD TRANSACTIONS +// ================================================================================================= + +/// Runs a transaction-control statement (`BEGIN`/`COMMIT`/`ROLLBACK`) on a checked-out connection. +async fn run_tx_stmt( + conn: &Object, + stmt: &'static str, +) -> Result<(), DatabaseError> { + conn.interact(move |conn| conn.execute_batch(stmt)) + .await + .map_err(|err| DatabaseError::interact(stmt, &err))? + .map_err(DatabaseError::from) +} + +/// A read transaction (`DEFERRED`) held across `.await` points, on a reader connection. +/// +/// Run batches of synchronous queries with [`run`](Self::run); the transaction stays open between +/// calls, so a request handler can interleave queries with async work on a single consistent +/// snapshot. The transaction is read-only and ends (rolls back) when the handle is dropped, or +/// explicitly via [`close`](Self::close). +pub struct ReadTransaction { + conn: Object, +} + +impl ReadTransaction { + /// Runs a batch of read queries against the open transaction. + pub async fn run(&self, msg: impl ToString + Send, query: F) -> Result + where + F: FnOnce(&ReadTx<'_>) -> Result + Send + 'static, + R: Send + 'static, + E: From + Send + 'static, + { + let msg = msg.to_string(); + let span = tracing::Span::current(); + self.conn + .interact(move |conn| { + let _guard = span.enter(); + query(&ReadTx::new(conn)) + }) + .await + .map_err(|err| E::from(DatabaseError::interact(&msg, &err)))? + } + + /// Ends the transaction explicitly (rolls back; a read transaction has nothing to commit). + pub async fn close(self) -> Result<(), DatabaseError> { + run_tx_stmt(&self.conn, "ROLLBACK").await + } +} + +/// A read-write transaction (`IMMEDIATE`) held across `.await` points, on the single writer +/// connection. +/// +/// Run batches of synchronous queries with [`run`](Self::run); the transaction stays open between +/// calls, so a request handler can interleave reads and writes with async work atomically. Finish +/// with [`commit`](Self::commit) to persist, or [`rollback`](Self::rollback) to discard; if the +/// handle is dropped without either, the pool rolls the transaction back when the connection is +/// recycled. +/// +/// The handle holds the sole writer connection for its whole lifetime. +pub struct WriteTransaction { + conn: Object, +} + +impl WriteTransaction { + /// Runs a batch of read/write queries against the open transaction. + pub async fn run(&self, msg: impl ToString + Send, query: F) -> Result + where + F: FnOnce(&WriteTx<'_>) -> Result + Send + 'static, + R: Send + 'static, + E: From + Send + 'static, + { + let msg = msg.to_string(); + let span = tracing::Span::current(); + self.conn + .interact(move |conn| { + let _guard = span.enter(); + query(&WriteTx::new(conn)) + }) + .await + .map_err(|err| E::from(DatabaseError::interact(&msg, &err)))? + } + + /// Commits the transaction, persisting all writes. + pub async fn commit(self) -> Result<(), DatabaseError> { + run_tx_stmt(&self.conn, "COMMIT").await + } + + /// Rolls back the transaction, discarding all writes. + pub async fn rollback(self) -> Result<(), DatabaseError> { + run_tx_stmt(&self.conn, "ROLLBACK").await + } +} + +#[cfg(test)] +mod tests { + use std::num::NonZeroUsize; + use std::path::{Path, PathBuf}; + + use rusqlite::Connection; + + use super::Database; + use crate::DatabaseError; + + /// A throwaway file-backed database; the pools open existing files `READ_WRITE` only, so the + /// file and schema are created up front. + struct TempDb { + path: PathBuf, + } + + impl TempDb { + fn new(name: &str) -> Self { + let path = std::env::temp_dir() + .join(format!("miden-node-db-pool-{name}-{}.sqlite3", std::process::id())); + let db = Self { path }; + db.remove_files(); + let conn = Connection::open(&db.path).expect("create db file"); + conn.execute_batch("CREATE TABLE items (id INTEGER PRIMARY KEY);") + .expect("create table"); + db + } + + fn path(&self) -> &Path { + &self.path + } + + fn remove_files(&self) { + let _ = fs_err::remove_file(&self.path); + let _ = fs_err::remove_file(self.path.with_extension("sqlite3-wal")); + let _ = fs_err::remove_file(self.path.with_extension("sqlite3-shm")); + } + } + + impl Drop for TempDb { + fn drop(&mut self) { + self.remove_files(); + } + } + + fn open_db(temp: &TempDb) -> Database { + Database::new_with_pool_size(temp.path(), NonZeroUsize::new(4).unwrap()).unwrap() + } + + async fn count_items(db: &Database) -> i64 { + db.read::<_, DatabaseError, _>("count", |r| { + Ok(r.query("SELECT COUNT(*) FROM items", &[], |row| row.get::(0))? + .into_iter() + .next() + .unwrap_or(0)) + }) + .await + .unwrap() + } + + async fn insert_committed(db: &Database, id: i64) { + let tx = db.begin_write().await.unwrap(); + tx.run::<_, DatabaseError, _>("insert", move |w| { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&id])?; + Ok(()) + }) + .await + .unwrap(); + tx.commit().await.unwrap(); + } + + #[tokio::test] + async fn held_write_transaction_commits_across_awaits() { + let temp = TempDb::new("commit"); + let db = open_db(&temp); + + let tx = db.begin_write().await.unwrap(); + tx.run::<_, DatabaseError, _>("insert-1", |w| { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&1i64])?; + Ok(()) + }) + .await + .unwrap(); + + // Interleave async work between statements on the same still-open transaction. + tokio::task::yield_now().await; + + tx.run::<_, DatabaseError, _>("insert-2", |w| { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&2i64])?; + Ok(()) + }) + .await + .unwrap(); + + tx.commit().await.unwrap(); + + assert_eq!(count_items(&db).await, 2); + } + + #[tokio::test] + async fn dropped_write_transaction_rolls_back() { + let temp = TempDb::new("rollback"); + let db = open_db(&temp); + + { + let tx = db.begin_write().await.unwrap(); + tx.run::<_, DatabaseError, _>("insert", |w| { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&1i64])?; + Ok(()) + }) + .await + .unwrap(); + // `tx` is dropped here without a commit. + } + + // The sole writer connection is reused; `recycle` must have rolled back the orphaned + // transaction, otherwise this `BEGIN IMMEDIATE` would fail with "cannot start a transaction + // within a transaction". The first insert must not have persisted. + insert_committed(&db, 2).await; + assert_eq!(count_items(&db).await, 1); + } + + #[tokio::test] + async fn reads_proceed_while_write_transaction_is_held() { + let temp = TempDb::new("concurrent"); + let db = open_db(&temp); + insert_committed(&db, 1).await; + + // Hold an open write transaction with an uncommitted insert. + let tx = db.begin_write().await.unwrap(); + tx.run::<_, DatabaseError, _>("insert-uncommitted", |w| { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&2i64])?; + Ok(()) + }) + .await + .unwrap(); + + // A read on the reader pool proceeds (does not block on the writer) and does not see the + // uncommitted row. + assert_eq!(count_items(&db).await, 1); + + tx.commit().await.unwrap(); + assert_eq!(count_items(&db).await, 2); + } + + #[tokio::test] + async fn reader_connections_are_query_only() { + let temp = TempDb::new("query_only"); + let db = open_db(&temp); + + let query_only = db + .read::<_, DatabaseError, _>("pragma", |r| { + Ok(r.query("PRAGMA query_only", &[], |row| row.get::(0))? + .into_iter() + .next() + .unwrap_or(0)) + }) + .await + .unwrap(); + assert_eq!(query_only, 1, "reader connections must be query_only"); + + // A write attempted on a reader connection is rejected. + let result = db + .read::<(), DatabaseError, _>("rejected-write", |r| { + r.query("INSERT INTO items (id) VALUES (99)", &[], |_| Ok(()))?; + Ok(()) + }) + .await; + assert!(result.is_err(), "writes on a reader connection must fail"); + } +} diff --git a/crates/db/src/sqlite/tx.rs b/crates/db/src/sqlite/tx.rs new file mode 100644 index 0000000000..30efda10be --- /dev/null +++ b/crates/db/src/sqlite/tx.rs @@ -0,0 +1,241 @@ +//! Read/write transaction wrappers and the [`Row`] accessor. +//! +//! [`ReadTx`] and [`WriteTx`] are the only handles callers ever touch; the underlying +//! `rusqlite::Connection` is private. They borrow a connection on which a transaction has already +//! been opened (by the pool's `read`/`write` or by a held transaction handle) — they do not begin +//! or end the transaction themselves. +//! +//! A write transaction is a read transaction with the extra ability to mutate: [`WriteTx`] wraps a +//! [`ReadTx`] and derefs to it, so it inherits [`query`](ReadTx::query) and only adds +//! [`execute`](WriteTx::execute). A function that receives `&ReadTx` therefore cannot compile a +//! mutation. Both prepare statements with `prepare_cached`, so prepared statements are always +//! cached. + +use std::ops::Deref; + +use rusqlite::Connection; + +use crate::DatabaseError; +use crate::sqlite::codec::{DbValue, DbValueRef, FromSqlValue, ToSqlValue}; + +// ROW +// ================================================================================================= + +/// A single result row. Wraps `rusqlite::Row` so callers read columns through the codec via +/// [`Row::get`] without naming `rusqlite`. +pub struct Row<'a>(&'a rusqlite::Row<'a>); + +impl<'a> Row<'a> { + fn new(row: &'a rusqlite::Row<'a>) -> Self { + Self(row) + } + + /// Reads column `idx` (zero-based) decoded through [`FromSqlValue`], e.g. + /// `row.get::(0)`. + pub fn get(&self, idx: usize) -> Result { + let value = self.0.get_ref(idx)?; + T::from_sql_value(DbValueRef::new(value)) + } +} + +// TRANSACTION WRAPPERS +// ================================================================================================= + +/// A read-only transaction. Borrows a connection on which a `DEFERRED` transaction is open; the +/// transaction is never committed (changes roll back when it ends). +pub struct ReadTx<'t>(&'t Connection); + +impl<'t> ReadTx<'t> { + pub(crate) fn new(conn: &'t Connection) -> Self { + Self(conn) + } + + /// Runs a query and maps every row, collecting the results. + /// + /// This is the single read primitive: a caller expecting at most one row takes + /// `.into_iter().next()`, and `SELECT EXISTS(...)` / `SELECT COUNT(*)` map the single row's + /// first column. + pub fn query( + &self, + sql: &'static str, + params: &[&dyn ToSqlValue], + mut map: impl FnMut(&Row<'_>) -> Result, + ) -> Result, DatabaseError> { + debug_assert_no_dynamic_in(sql); + let values = to_values(params); + let mut stmt = self.0.prepare_cached(sql)?; + let mut rows = stmt.query(rusqlite::params_from_iter(values))?; + let mut out = Vec::new(); + while let Some(row) = rows.next()? { + out.push(map(&Row::new(row))?); + } + Ok(out) + } +} + +/// A read-write transaction. Borrows a connection on which an `IMMEDIATE` transaction is open; the +/// transaction is committed by the owner when the work returns `Ok`. Derefs to [`ReadTx`] for all +/// read queries and adds [`execute`](Self::execute). +pub struct WriteTx<'t>(ReadTx<'t>); + +impl<'t> WriteTx<'t> { + pub(crate) fn new(conn: &'t Connection) -> Self { + Self(ReadTx::new(conn)) + } + + /// Executes an `INSERT`/`UPDATE`/`DELETE`/`REPLACE` and returns the affected row count. + pub fn execute( + &self, + sql: &'static str, + params: &[&dyn ToSqlValue], + ) -> Result { + debug_assert_no_dynamic_in(sql); + let values = to_values(params); + let mut stmt = self.0.0.prepare_cached(sql)?; + Ok(stmt.execute(rusqlite::params_from_iter(values))?) + } +} + +impl<'t> Deref for WriteTx<'t> { + type Target = ReadTx<'t>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +// SHARED HELPERS +// ================================================================================================= + +fn to_values(params: &[&dyn ToSqlValue]) -> Vec { + params.iter().map(ToSqlValue::to_sql_value).collect() +} + +fn debug_assert_no_dynamic_in(sql: &str) { + debug_assert!( + !(sql.contains(" IN (?") || sql.contains(" IN (:")), + "use in_list() instead of a variable-length `IN (?, ...)` placeholder list to keep the \ + statement cacheable: {sql}" + ); +} + +#[cfg(test)] +mod tests { + use rusqlite::Connection; + + use super::*; + use crate::sqlite::{in_list_blob, in_list_i64}; + + fn in_memory() -> Connection { + let conn = Connection::open_in_memory().expect("open in-memory db"); + // `rarray()` is provided by rusqlite's `array` extension, which must be loaded per + // connection (the pool does this in `configure_connection`). + rusqlite::vtab::array::load_module(&conn).expect("load array module"); + conn.execute_batch( + "CREATE TABLE items (id INTEGER PRIMARY KEY, payload BLOB, label TEXT);", + ) + .expect("create table"); + conn + } + + #[test] + fn write_then_read_roundtrips_through_the_codec() { + let mut conn = in_memory(); + let tx = conn.transaction().unwrap(); + let w = WriteTx::new(&tx); + + let payload = vec![1u8, 2, 3]; + let inserted = w + .execute( + "INSERT INTO items (id, payload, label) VALUES (?1, ?2, ?3)", + &[&1i64, &payload, &"hello".to_string()], + ) + .unwrap(); + assert_eq!(inserted, 1); + + let got: (i64, Vec, String) = w + .query("SELECT id, payload, label FROM items WHERE id = ?1", &[&1i64], |row| { + Ok((row.get::(0)?, row.get::>(1)?, row.get::(2)?)) + }) + .unwrap() + .into_iter() + .next() + .unwrap(); + assert_eq!(got, (1, vec![1, 2, 3], "hello".to_string())); + } + + #[test] + fn null_column_reads_as_none() { + let mut conn = in_memory(); + let tx = conn.transaction().unwrap(); + let w = WriteTx::new(&tx); + + w.execute("INSERT INTO items (id, payload) VALUES (?1, NULL)", &[&1i64]) + .unwrap(); + let payload: Option> = w + .query("SELECT payload FROM items WHERE id = ?1", &[&1i64], |row| { + row.get::>>(0) + }) + .unwrap() + .into_iter() + .next() + .unwrap(); + assert_eq!(payload, None); + } + + #[test] + fn query_returns_empty_for_missing_row() { + let mut conn = in_memory(); + let tx = conn.transaction().unwrap(); + let r = ReadTx::new(&tx); + + let got = r + .query("SELECT id FROM items WHERE id = ?1", &[&404i64], |row| row.get::(0)) + .unwrap(); + assert!(got.is_empty()); + } + + // Regression guard for the cacheable IN-list idiom: the rarray form must run through `query` + // without tripping `debug_assert_no_dynamic_in` (tests run with debug assertions on). + #[test] + fn in_list_i64_rarray_runs_and_matches() { + let mut conn = in_memory(); + let tx = conn.transaction().unwrap(); + let w = WriteTx::new(&tx); + for id in [1i64, 2, 3, 4] { + w.execute("INSERT INTO items (id) VALUES (?1)", &[&id]).unwrap(); + } + + let wanted = in_list_i64([1, 3]); + let mut ids = w + .query( + "SELECT id FROM items WHERE id IN (SELECT value FROM rarray(?1))", + &[&wanted], + |row| row.get::(0), + ) + .unwrap(); + ids.sort_unstable(); + assert_eq!(ids, vec![1, 3]); + } + + #[test] + fn in_list_blob_matches_blob_column() { + let mut conn = in_memory(); + let tx = conn.transaction().unwrap(); + let w = WriteTx::new(&tx); + let a = vec![0xAAu8, 0xBB]; + let b = vec![0x01u8]; + w.execute("INSERT INTO items (id, payload) VALUES (1, ?1)", &[&a]).unwrap(); + w.execute("INSERT INTO items (id, payload) VALUES (2, ?1)", &[&b]).unwrap(); + + let wanted = in_list_blob([a.as_slice()]); + let ids = w + .query( + "SELECT id FROM items WHERE payload IN (SELECT value FROM rarray(?1))", + &[&wanted], + |row| row.get::(0), + ) + .unwrap(); + assert_eq!(ids, vec![1]); + } +}