Skip to main content

cadmus_core/db/version.rs

1use crate::db::types::UnixTimestamp;
2use crate::helpers::Fp;
3use crate::version::GitVersion;
4use anyhow::{Context, Error};
5use serde::{Deserialize, Serialize};
6use sqlx::SqlitePool;
7use std::str::FromStr;
8
9include!(concat!(env!("OUT_DIR"), "/migration_hash.rs"));
10
11/// BLAKE3 hash of all schema migration file paths and contents.
12///
13/// Backed by [`Fp`] so it shares the same hex encoding, parsing, and sqlx
14/// serialisation behaviour as book content fingerprints.
15#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(transparent)]
17pub struct MigrationHash(Fp);
18
19impl MigrationHash {
20    /// Returns the migration hash embedded in the running build.
21    pub fn current() -> Self {
22        MIGRATION_HASH
23            .parse()
24            .expect("generated migration hash should be valid")
25    }
26}
27
28impl std::fmt::Display for MigrationHash {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        self.0.fmt(f)
31    }
32}
33
34impl FromStr for MigrationHash {
35    type Err = Error;
36
37    fn from_str(value: &str) -> Result<Self, Self::Err> {
38        value.parse::<Fp>().map(Self).map_err(Error::from)
39    }
40}
41
42impl sqlx::Type<sqlx::Sqlite> for MigrationHash {
43    fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
44        <Fp as sqlx::Type<sqlx::Sqlite>>::type_info()
45    }
46
47    fn compatible(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool {
48        <Fp as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
49    }
50}
51
52impl sqlx::Encode<'_, sqlx::Sqlite> for MigrationHash {
53    fn encode_by_ref(
54        &self,
55        buf: &mut sqlx::sqlite::SqliteArgumentsBuffer,
56    ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
57        self.0.encode_by_ref(buf)
58    }
59}
60
61impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for MigrationHash {
62    fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
63        Fp::decode(value).map(Self)
64    }
65}
66
67/// Returns the schema migration hash embedded in the running build.
68pub fn current_migration_hash() -> MigrationHash {
69    MigrationHash::current()
70}
71
72/// Version and schema migration state stored in `_cadmus_version`.
73#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct DbVersionStamp {
75    /// Cadmus version that last stamped the database.
76    pub version: GitVersion,
77    /// Migration hash that last stamped the database.
78    pub migration_hash: MigrationHash,
79}
80
81/// Reads the Cadmus version stored in `_cadmus_version`.
82///
83/// Returns `None` if the table does not exist (database predates migration 012)
84/// or if the row is missing.
85#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
86pub async fn read_db_version(pool: &SqlitePool) -> Result<Option<GitVersion>, Error> {
87    Ok(read_db_version_stamp(pool)
88        .await?
89        .map(|stamp| stamp.version))
90}
91
92/// Reads the Cadmus version stamp stored in `_cadmus_version`.
93///
94/// Returns `None` if the table does not exist or the singleton row is missing.
95#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
96pub async fn read_db_version_stamp(pool: &SqlitePool) -> Result<Option<DbVersionStamp>, Error> {
97    let result = sqlx::query!(
98        r#"SELECT version AS "version: GitVersion", migration_hash AS "migration_hash: MigrationHash"
99           FROM _cadmus_version
100           WHERE id = 1"#,
101    )
102    .fetch_optional(pool)
103    .await;
104
105    match result {
106        Ok(Some(row)) => Ok(Some(DbVersionStamp {
107            version: row.version,
108            migration_hash: row.migration_hash,
109        })),
110        Ok(None) => Ok(None),
111        Err(sqlx::Error::Database(e)) if e.message().contains("no such table") => Ok(None),
112        Err(e) => Err(Error::from(e).context("failed to read _cadmus_version")),
113    }
114}
115
116/// Stamps the database with an explicit migration hash.
117#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
118pub async fn stamp_db_version(
119    pool: &SqlitePool,
120    version: &GitVersion,
121    migration_hash: &MigrationHash,
122) -> Result<(), Error> {
123    let migrated_at = UnixTimestamp::now();
124    let version_str = version.to_string();
125    let migration_hash_str = migration_hash.to_string();
126    sqlx::query!(
127        "INSERT INTO _cadmus_version (id, version, migration_hash, migrated_at)
128         VALUES (1, ?, ?, ?)
129         ON CONFLICT(id) DO UPDATE
130         SET version = excluded.version,
131             migration_hash = excluded.migration_hash,
132             migrated_at = excluded.migrated_at",
133        version_str,
134        migration_hash_str,
135        migrated_at,
136    )
137    .execute(pool)
138    .await
139    .context("failed to stamp _cadmus_version")?;
140
141    Ok(())
142}
143
/// Outcome of comparing the database's version stamp with the running app.
///
/// A database stamped by a newer app is still usable by an older app when both
/// were built from the same set of schema migration files (an identical
/// [`MigrationHash`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionGateResult {
    /// The stamp's migration hash differs from this build's, and the stamped
    /// version is newer than or equal to the app version. The on-disk schema
    /// comes from a migration set this build does not know, so the database
    /// should not be used as-is.
    Downgrade,
    /// The database was stamped by an older Cadmus build; normal upgrade path.
    /// The migration hash is not consulted in this case.
    Upgrade,
    /// The stamped version and migration hash both match this build.
    Current,
    /// The database was stamped by a newer build that shares this build's
    /// migration hash, so the schema is identical.
    CompatibleDowngrade,
    /// No stamp was found: the `_cadmus_version` table is missing (pre-012
    /// database) or its singleton row is absent.
    Unknown,
}
161
162/// Checks whether the database version is compatible with the running app.
163///
164/// A `Downgrade` result means the database was touched by a newer Cadmus
165/// version with different schema migrations, so a backup from the current app
166/// version should be restored.
167#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
168pub async fn check_version_gate(
169    pool: &SqlitePool,
170    app_version: &GitVersion,
171) -> Result<VersionGateResult, Error> {
172    match read_db_version_stamp(pool).await? {
173        None => Ok(VersionGateResult::Unknown),
174        Some(db_stamp) => match db_stamp.version.cmp(app_version) {
175            std::cmp::Ordering::Greater => {
176                if db_stamp.migration_hash == current_migration_hash() {
177                    return Ok(VersionGateResult::CompatibleDowngrade);
178                }
179
180                Ok(VersionGateResult::Downgrade)
181            }
182            std::cmp::Ordering::Less => Ok(VersionGateResult::Upgrade),
183            std::cmp::Ordering::Equal => {
184                if db_stamp.migration_hash == current_migration_hash() {
185                    Ok(VersionGateResult::Current)
186                } else {
187                    Ok(VersionGateResult::Downgrade)
188                }
189            }
190        },
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Database;
    use crate::db::runtime::RUNTIME;
    use crate::version::get_current_version;

    /// Produces a random, well-formed migration hash that practically never
    /// equals the build's hash.
    fn different_migration_hash() -> MigrationHash {
        blake3::hash(uuid::Uuid::now_v7().as_bytes())
            .to_hex()
            .to_string()
            .parse()
            .unwrap()
    }

    /// Creates an in-memory database and runs its migrations.
    fn setup_db() -> Database {
        let mut db = Database::new(":memory:").expect("failed to create in-memory database");
        db.init(0).expect("failed to run migrations");
        db
    }

    #[test]
    fn read_db_version_returns_none_before_table_exists() {
        // Database::new creates the pool but does not run migrations, so the
        // _cadmus_version table does not exist yet.
        let db = Database::new(":memory:").expect("failed to create in-memory database");
        let (version, stamp) = RUNTIME.block_on(async {
            (
                read_db_version(db.pool()).await.unwrap(),
                read_db_version_stamp(db.pool()).await.unwrap(),
            )
        });
        assert!(version.is_none());
        assert!(stamp.is_none());
    }

    #[test]
    fn stamp_and_read_db_version_roundtrip() {
        let db = setup_db();
        let version = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &version, &migration_hash)
                .await
                .unwrap();
            let read = read_db_version(db.pool()).await.unwrap();
            let stamp = read_db_version_stamp(db.pool()).await.unwrap().unwrap();
            assert_eq!(read, Some(version));
            assert_eq!(stamp.migration_hash, current_migration_hash());
        });
    }

    #[test]
    fn stamp_db_version_overwrites_previous_stamp() {
        let db = setup_db();
        let first = GitVersion::from_str("v0.9.0").unwrap();
        let second = GitVersion::from_str("v0.10.0").unwrap();
        let stale_hash = different_migration_hash();
        let fresh_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &first, &stale_hash)
                .await
                .unwrap();
            stamp_db_version(db.pool(), &second, &fresh_hash)
                .await
                .unwrap();
            let stamp = read_db_version_stamp(db.pool()).await.unwrap().unwrap();
            assert_eq!(stamp.version, second);
            assert_eq!(stamp.migration_hash, fresh_hash);
        });
    }

    #[test]
    fn check_version_gate_detects_upgrade() {
        let db = setup_db();
        let older = GitVersion::from_str("v0.9.0").unwrap();
        let newer = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &older, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &newer).await.unwrap();
            assert_eq!(gate, VersionGateResult::Upgrade);
        });
    }

    #[test]
    fn check_version_gate_detects_upgrade_even_with_different_hash() {
        let db = setup_db();
        let older = GitVersion::from_str("v0.9.0").unwrap();
        let newer = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = different_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &older, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &newer).await.unwrap();
            assert_eq!(gate, VersionGateResult::Upgrade);
        });
    }

    #[test]
    fn check_version_gate_allows_compatible_downgrade() {
        let db = setup_db();
        let older = GitVersion::from_str("v0.9.0").unwrap();
        let newer = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &newer, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &older).await.unwrap();
            assert_eq!(gate, VersionGateResult::CompatibleDowngrade);
        });
    }

    #[test]
    fn check_version_gate_detects_incompatible_downgrade() {
        let db = setup_db();
        let older = GitVersion::from_str("v0.9.0").unwrap();
        let newer = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = different_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &newer, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &older).await.unwrap();
            assert_eq!(gate, VersionGateResult::Downgrade);
        });
    }

    #[test]
    fn check_version_gate_detects_current() {
        let db = setup_db();
        let version = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &version, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &version).await.unwrap();
            assert_eq!(gate, VersionGateResult::Current);
        });
    }

    #[test]
    fn check_version_gate_detects_downgrade_on_equal_version_different_hash() {
        let db = setup_db();
        let version = GitVersion::from_str("v0.10.0").unwrap();
        let migration_hash = different_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &version, &migration_hash)
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &version).await.unwrap();
            assert_eq!(gate, VersionGateResult::Downgrade);
        });
    }

    #[test]
    fn check_version_gate_unknown_when_table_is_empty() {
        let db = setup_db();

        RUNTIME.block_on(async {
            sqlx::query!("DELETE FROM _cadmus_version")
                .execute(db.pool())
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &get_current_version())
                .await
                .unwrap();
            assert_eq!(gate, VersionGateResult::Unknown);
        });
    }

    #[test]
    fn check_version_gate_unknown_before_table_exists() {
        // No migrations have run, so `_cadmus_version` does not exist.
        let db = Database::new(":memory:").expect("failed to create in-memory database");
        let gate = RUNTIME.block_on(async {
            check_version_gate(db.pool(), &get_current_version())
                .await
                .unwrap()
        });
        assert_eq!(gate, VersionGateResult::Unknown);
    }
}