From 575cec40fd51da1d5ab8f2bff046bb8703764bbb Mon Sep 17 00:00:00 2001 From: Joseph Cosentino Date: Mon, 13 Nov 2023 22:45:50 -0800 Subject: [PATCH] refactor: make statement creation logic clearer --- .../greengrass/disk/spool/DiskSpoolDAO.java | 65 ++++++++++--------- 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java b/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java index 4b20b56..1881909 100644 --- a/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java +++ b/src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java @@ -30,6 +30,7 @@ import java.sql.Statement; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -56,13 +57,22 @@ public class DiskSpoolDAO { private final Path databasePath; private final String url; - private final List> onNewConnection = new ArrayList<>(); - @SuppressWarnings("PMD.UnusedPrivateField") // runs whenever a new database connection is made private final CreateSpoolerTable createSpoolerTable = new CreateSpoolerTable(); private final GetAllSpoolMessageIds getAllSpoolMessageIds = new GetAllSpoolMessageIds(); private final GetSpoolMessageById getSpoolMessageById = new GetSpoolMessageById(); private final InsertSpoolMessage insertSpoolMessage = new InsertSpoolMessage(); private final RemoveSpoolMessageById removeSpoolMessageById = new RemoveSpoolMessageById(); + + /** + * Statements that will automatically be recreated when a new database connection is made. + */ + private final List> statementsToRecreate = Arrays.asList( + getAllSpoolMessageIds, + getSpoolMessageById, + insertSpoolMessage, + removeSpoolMessageById + ); + private final ReentrantLock recoverDBLock = new ReentrantLock(); private final ReentrantReadWriteLock connectionLock = new ReentrantReadWriteLock(); private Connection connection; @@ -93,13 +103,15 @@ public void initialize() throws SQLException { try (LockScope ls = LockScope.lock(connectionLock.writeLock())) { close(); connection = createConnection(); - fireOnNewConnection(); - } - } - private void fireOnNewConnection() throws SQLException { - for (CrashableFunction handler : onNewConnection) { - handler.apply(connection); + // recreate the database table first + createSpoolerTable.createStatement(connection); + createSpoolerTable.execute(); + + // eagerly create remaining statements + for (CachedStatement statement : statementsToRecreate) { + statement.createStatement(connection); + } } } @@ -201,7 +213,7 @@ class GetAllSpoolMessageIds extends CachedStatement + "FROM spooler WHERE message_id = ?;"; @Override - protected PreparedStatement createStatement(Connection connection) throws SQLException { + protected PreparedStatement doCreateStatement(Connection connection) throws SQLException { return connection.prepareStatement(QUERY); } @@ -276,7 +288,7 @@ class InsertSpoolMessage extends CachedStatement { + "VALUES (?,?,?,?,?,?,?,?,?,?,?,?);"; @Override - protected PreparedStatement createStatement(Connection connection) throws SQLException { + protected PreparedStatement doCreateStatement(Connection connection) throws SQLException { return connection.prepareStatement(QUERY); } @@ -340,7 +352,7 @@ class RemoveSpoolMessageById extends CachedStatement private static final String QUERY = "DELETE FROM spooler WHERE message_id = ?;"; @Override - protected PreparedStatement createStatement(Connection connection) throws SQLException { + protected PreparedStatement doCreateStatement(Connection connection) throws SQLException { return connection.prepareStatement(QUERY); } @@ -374,7 +386,7 @@ class CreateSpoolerTable extends CachedStatement { + ");"; @Override - protected PreparedStatement createStatement(Connection connection) throws SQLException { + protected PreparedStatement doCreateStatement(Connection connection) throws SQLException { return connection.prepareStatement(QUERY); } @@ -384,10 +396,9 @@ protected Integer doExecute(PreparedStatement statement) throws SQLException { } @Override - public Void onNewConnection(Connection connection) throws SQLException { - super.onNewConnection(connection); + public void createStatement(Connection connection) throws SQLException { + super.createStatement(connection); execute(); - return null; } } @@ -402,25 +413,15 @@ public Void onNewConnection(Connection connection) throws SQLException { abstract class CachedStatement { private T statement; - protected CachedStatement() { - onNewConnection.add(this::onNewConnection); - } - /** - * Callback that's executed whenever a database connection is created, - * during startup, or after corruption recovery, for example. + * Create a new statement. * * @param connection connection - * @return nothing - * @throws SQLException if unable to close old statement, or unable to create a new statement + * @throws SQLException if unable to create statement */ - public Void onNewConnection(Connection connection) throws SQLException { - // TODO lock needed anymore? - try (LockScope ls = LockScope.lock(connectionLock.readLock())) { - close(); // clean up old resources - statement = createStatement(connection); - } - return null; + public void createStatement(Connection connection) throws SQLException { + close(); // clean up old resources + statement = doCreateStatement(connection); } /** @@ -430,7 +431,7 @@ public Void onNewConnection(Connection connection) throws SQLException { * @return statement * @throws SQLException if unable to create statement */ - protected abstract T createStatement(Connection connection) throws SQLException; + protected abstract T doCreateStatement(Connection connection) throws SQLException; public void close() throws SQLException { if (statement != null) {