Blob Blame History Raw
From 64b6bd89d0faad3274d0b224b1d1c92fcd397a62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C5=A0t=C4=9Bp=C3=A1n=20Hor=C3=A1=C4=8Dek?=
 <shoracek@redhat.com>
Date: Wed, 2 Nov 2022 19:23:13 +0100
Subject: [PATCH 5/6] db: fix upgrade backup
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

During a failed upgrade, the original database was deleted and replaced
with the upgraded one, making it impossible to revert the failed
upgrade.

This commit fixes this problem by keeping the old version of the
database as a separate file for upgrades that finished successfully and
keeping the original database for those that did not.

Signed-off-by: Štěpán Horáček <shoracek@redhat.com>
---
 tools/tpm2_pkcs11/db.py | 36 +++++++++++++++++++++---------------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/tools/tpm2_pkcs11/db.py b/tools/tpm2_pkcs11/db.py
index 1b18b8f..d0a526b 100644
--- a/tools/tpm2_pkcs11/db.py
+++ b/tools/tpm2_pkcs11/db.py
@@ -454,27 +454,33 @@ class Db(object):
                     REPLACE INTO schema (id, schema_version) VALUES (1, {version});
                 '''.format(version=new_version))
             dbbakcon.execute(sql)
-        finally:
-            # Close the connections
-            self._conn.commit()
-            self._conn.close()
-
+        except Exception as e:
+            # Close the connection to backup
             dbbakcon.commit()
             dbbakcon.close()
 
-            # move old db to ".old" suffix
-            olddbpath = self._path + ".old"
-            os.rename(self._path, olddbpath)
+            # unlink the backup
+            os.unlink(dbbakpath)
+
+            raise e
+
+        # Close the connections
+        self._conn.commit()
+        self._conn.close()
 
-            # move the backup to the normal dbpath
-            os.rename(dbbakpath, self._path)
+        dbbakcon.commit()
+        dbbakcon.close()
 
-            # unlink the old
-            os.unlink(olddbpath)
+        # move old db to ".old" suffix
+        olddbpath = self._path + ".old"
+        os.rename(self._path, olddbpath)
 
-            # re-establish a connection
-            self._conn = sqlite3.connect(self._path)
-            self._conn.row_factory = sqlite3.Row
+        # move the backup to the normal dbpath
+        os.rename(dbbakpath, self._path)
+
+        # re-establish a connection
+        self._conn = sqlite3.connect(self._path)
+        self._conn.row_factory = sqlite3.Row
 
     def _get_version(self):
         c = self._conn.cursor()
-- 
2.38.1