1use crate::db::types::UnixTimestamp;
2use crate::helpers::Fp;
3use crate::version::GitVersion;
4use anyhow::{Context, Error};
5use serde::{Deserialize, Serialize};
6use sqlx::SqlitePool;
7use std::str::FromStr;
8
9include!(concat!(env!("OUT_DIR"), "/migration_hash.rs"));
10
11#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
16#[serde(transparent)]
17pub struct MigrationHash(Fp);
18
19impl MigrationHash {
20 pub fn current() -> Self {
22 MIGRATION_HASH
23 .parse()
24 .expect("generated migration hash should be valid")
25 }
26}
27
28impl std::fmt::Display for MigrationHash {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 self.0.fmt(f)
31 }
32}
33
34impl FromStr for MigrationHash {
35 type Err = Error;
36
37 fn from_str(value: &str) -> Result<Self, Self::Err> {
38 value.parse::<Fp>().map(Self).map_err(Error::from)
39 }
40}
41
42impl sqlx::Type<sqlx::Sqlite> for MigrationHash {
43 fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
44 <Fp as sqlx::Type<sqlx::Sqlite>>::type_info()
45 }
46
47 fn compatible(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool {
48 <Fp as sqlx::Type<sqlx::Sqlite>>::compatible(ty)
49 }
50}
51
52impl sqlx::Encode<'_, sqlx::Sqlite> for MigrationHash {
53 fn encode_by_ref(
54 &self,
55 buf: &mut sqlx::sqlite::SqliteArgumentsBuffer,
56 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
57 self.0.encode_by_ref(buf)
58 }
59}
60
61impl<'r> sqlx::Decode<'r, sqlx::Sqlite> for MigrationHash {
62 fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
63 Fp::decode(value).map(Self)
64 }
65}
66
67pub fn current_migration_hash() -> MigrationHash {
69 MigrationHash::current()
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct DbVersionStamp {
75 pub version: GitVersion,
77 pub migration_hash: MigrationHash,
79}
80
81#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
86pub async fn read_db_version(pool: &SqlitePool) -> Result<Option<GitVersion>, Error> {
87 Ok(read_db_version_stamp(pool)
88 .await?
89 .map(|stamp| stamp.version))
90}
91
92#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
96pub async fn read_db_version_stamp(pool: &SqlitePool) -> Result<Option<DbVersionStamp>, Error> {
97 let result = sqlx::query!(
98 r#"SELECT version AS "version: GitVersion", migration_hash AS "migration_hash: MigrationHash"
99 FROM _cadmus_version
100 WHERE id = 1"#,
101 )
102 .fetch_optional(pool)
103 .await;
104
105 match result {
106 Ok(Some(row)) => Ok(Some(DbVersionStamp {
107 version: row.version,
108 migration_hash: row.migration_hash,
109 })),
110 Ok(None) => Ok(None),
111 Err(sqlx::Error::Database(e)) if e.message().contains("no such table") => Ok(None),
112 Err(e) => Err(Error::from(e).context("failed to read _cadmus_version")),
113 }
114}
115
116#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
118pub async fn stamp_db_version(
119 pool: &SqlitePool,
120 version: &GitVersion,
121 migration_hash: &MigrationHash,
122) -> Result<(), Error> {
123 let migrated_at = UnixTimestamp::now();
124 let version_str = version.to_string();
125 let migration_hash_str = migration_hash.to_string();
126 sqlx::query!(
127 "INSERT INTO _cadmus_version (id, version, migration_hash, migrated_at)
128 VALUES (1, ?, ?, ?)
129 ON CONFLICT(id) DO UPDATE
130 SET version = excluded.version,
131 migration_hash = excluded.migration_hash,
132 migrated_at = excluded.migrated_at",
133 version_str,
134 migration_hash_str,
135 migrated_at,
136 )
137 .execute(pool)
138 .await
139 .context("failed to stamp _cadmus_version")?;
140
141 Ok(())
142}
143
/// Outcome of comparing the database's version stamp with the running app
/// (see [`check_version_gate`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VersionGateResult {
    /// The stored migration hash differs from this build's. The stored version
    /// is either newer than or equal to the app version.
    Downgrade,
    /// The stored version is older than the app version (the hash is not consulted).
    Upgrade,
    /// The stored version and migration hash both match this build.
    Current,
    /// The stored version is newer than the app version, but the migration hash
    /// matches this build's.
    CompatibleDowngrade,
    /// No stamp was found: `_cadmus_version` is missing or has no `id = 1` row.
    Unknown,
}
161
162#[cfg_attr(feature = "tracing", tracing::instrument(skip(pool)))]
168pub async fn check_version_gate(
169 pool: &SqlitePool,
170 app_version: &GitVersion,
171) -> Result<VersionGateResult, Error> {
172 match read_db_version_stamp(pool).await? {
173 None => Ok(VersionGateResult::Unknown),
174 Some(db_stamp) => match db_stamp.version.cmp(app_version) {
175 std::cmp::Ordering::Greater => {
176 if db_stamp.migration_hash == current_migration_hash() {
177 return Ok(VersionGateResult::CompatibleDowngrade);
178 }
179
180 Ok(VersionGateResult::Downgrade)
181 }
182 std::cmp::Ordering::Less => Ok(VersionGateResult::Upgrade),
183 std::cmp::Ordering::Equal => {
184 if db_stamp.migration_hash == current_migration_hash() {
185 Ok(VersionGateResult::Current)
186 } else {
187 Ok(VersionGateResult::Downgrade)
188 }
189 }
190 },
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Database;
    use crate::db::runtime::RUNTIME;
    use crate::version::get_current_version;

    /// Produces a fresh random migration hash. It differs from
    /// [`current_migration_hash`] with overwhelming probability.
    fn different_migration_hash() -> MigrationHash {
        blake3::hash(uuid::Uuid::now_v7().as_bytes())
            .to_hex()
            .to_string()
            .parse()
            .unwrap()
    }

    /// Opens an in-memory database and runs the migrations on it.
    fn setup_db() -> Database {
        let mut db = Database::new(":memory:").expect("failed to create in-memory database");
        db.init(0).expect("failed to run migrations");
        db
    }

    /// Parses a version literal used by these tests.
    fn parse_version(text: &str) -> GitVersion {
        GitVersion::from_str(text).unwrap()
    }

    /// Stamps a freshly migrated database with `db_version` / `db_hash`.
    /// Then returns the gate result an app running `app_version` would observe.
    fn gate_after_stamp(
        db_version: &str,
        db_hash: MigrationHash,
        app_version: &str,
    ) -> VersionGateResult {
        let db = setup_db();
        let db_version = parse_version(db_version);
        let app_version = parse_version(app_version);

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &db_version, &db_hash)
                .await
                .unwrap();
            check_version_gate(db.pool(), &app_version).await.unwrap()
        })
    }

    #[test]
    fn read_db_version_returns_none_before_table_exists() {
        let db = Database::new(":memory:").expect("failed to create in-memory database");
        let stored = RUNTIME.block_on(async { read_db_version(db.pool()).await.unwrap() });
        assert!(stored.is_none());
    }

    #[test]
    fn check_version_gate_unknown_before_table_exists() {
        let db = Database::new(":memory:").expect("failed to create in-memory database");
        let gate = RUNTIME.block_on(async {
            check_version_gate(db.pool(), &get_current_version())
                .await
                .unwrap()
        });
        assert_eq!(gate, VersionGateResult::Unknown);
    }

    #[test]
    fn stamp_and_read_db_version_roundtrip() {
        let db = setup_db();
        let stamped_version = parse_version("v0.10.0");
        let migration_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &stamped_version, &migration_hash)
                .await
                .unwrap();
            let read = read_db_version(db.pool()).await.unwrap();
            let stamp = read_db_version_stamp(db.pool()).await.unwrap().unwrap();
            assert_eq!(read, Some(stamped_version.clone()));
            assert_eq!(stamp.version, stamped_version);
            assert_eq!(stamp.migration_hash, current_migration_hash());
        });
    }

    #[test]
    fn stamp_db_version_overwrites_previous_stamp() {
        let db = setup_db();
        let first_version = parse_version("v0.9.0");
        let second_version = parse_version("v0.10.0");
        let first_hash = different_migration_hash();
        let second_hash = current_migration_hash();

        RUNTIME.block_on(async {
            stamp_db_version(db.pool(), &first_version, &first_hash)
                .await
                .unwrap();
            stamp_db_version(db.pool(), &second_version, &second_hash)
                .await
                .unwrap();
            let stamp = read_db_version_stamp(db.pool()).await.unwrap().unwrap();
            assert_eq!(stamp.version, second_version);
            assert_eq!(stamp.migration_hash, second_hash);
        });
    }

    #[test]
    fn check_version_gate_detects_upgrade() {
        assert_eq!(
            gate_after_stamp("v0.9.0", current_migration_hash(), "v0.10.0"),
            VersionGateResult::Upgrade
        );
    }

    #[test]
    fn check_version_gate_detects_upgrade_even_when_hash_differs() {
        assert_eq!(
            gate_after_stamp("v0.9.0", different_migration_hash(), "v0.10.0"),
            VersionGateResult::Upgrade
        );
    }

    #[test]
    fn check_version_gate_allows_compatible_downgrade() {
        assert_eq!(
            gate_after_stamp("v0.10.0", current_migration_hash(), "v0.9.0"),
            VersionGateResult::CompatibleDowngrade
        );
    }

    #[test]
    fn check_version_gate_detects_incompatible_downgrade() {
        assert_eq!(
            gate_after_stamp("v0.10.0", different_migration_hash(), "v0.9.0"),
            VersionGateResult::Downgrade
        );
    }

    #[test]
    fn check_version_gate_detects_current() {
        assert_eq!(
            gate_after_stamp("v0.10.0", current_migration_hash(), "v0.10.0"),
            VersionGateResult::Current
        );
    }

    #[test]
    fn check_version_gate_detects_downgrade_on_equal_version_different_hash() {
        assert_eq!(
            gate_after_stamp("v0.10.0", different_migration_hash(), "v0.10.0"),
            VersionGateResult::Downgrade
        );
    }

    #[test]
    fn check_version_gate_unknown_when_table_is_empty() {
        let db = setup_db();

        RUNTIME.block_on(async {
            sqlx::query!("DELETE FROM _cadmus_version")
                .execute(db.pool())
                .await
                .unwrap();
            let gate = check_version_gate(db.pool(), &get_current_version())
                .await
                .unwrap();
            assert_eq!(gate, VersionGateResult::Unknown);
        });
    }
}