Skip to content

Commit

Permalink
refactor: make statement creation logic clearer
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosentino11 committed Nov 14, 2023
1 parent 962eb8c commit 575cec4
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions src/main/java/com/aws/greengrass/disk/spool/DiskSpoolDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
Expand All @@ -56,13 +57,22 @@ public class DiskSpoolDAO {

private final Path databasePath;
private final String url;
private final List<CrashableFunction<Connection, Void, SQLException>> onNewConnection = new ArrayList<>();
@SuppressWarnings("PMD.UnusedPrivateField") // runs whenever a new database connection is made
private final CreateSpoolerTable createSpoolerTable = new CreateSpoolerTable();
private final GetAllSpoolMessageIds getAllSpoolMessageIds = new GetAllSpoolMessageIds();
private final GetSpoolMessageById getSpoolMessageById = new GetSpoolMessageById();
private final InsertSpoolMessage insertSpoolMessage = new InsertSpoolMessage();
private final RemoveSpoolMessageById removeSpoolMessageById = new RemoveSpoolMessageById();

/**
* Statements that will automatically be recreated when a new database connection is made.
*/
private final List<CachedStatement<?,?>> statementsToRecreate = Arrays.asList(
getAllSpoolMessageIds,
getSpoolMessageById,
insertSpoolMessage,
removeSpoolMessageById
);

private final ReentrantLock recoverDBLock = new ReentrantLock();
private final ReentrantReadWriteLock connectionLock = new ReentrantReadWriteLock();
private Connection connection;
Expand Down Expand Up @@ -93,13 +103,15 @@ public void initialize() throws SQLException {
try (LockScope ls = LockScope.lock(connectionLock.writeLock())) {
close();
connection = createConnection();
fireOnNewConnection();
}
}

private void fireOnNewConnection() throws SQLException {
for (CrashableFunction<Connection, Void, SQLException> handler : onNewConnection) {
handler.apply(connection);
// recreate the database table first
createSpoolerTable.createStatement(connection);
createSpoolerTable.execute();

// eagerly create remaining statements
for (CachedStatement<?, ?> statement : statementsToRecreate) {
statement.createStatement(connection);
}
}
}

Expand Down Expand Up @@ -201,7 +213,7 @@ class GetAllSpoolMessageIds extends CachedStatement<PreparedStatement, ResultSet
private static final String QUERY = "SELECT message_id FROM spooler;";

@Override
protected PreparedStatement createStatement(Connection connection) throws SQLException {
protected PreparedStatement doCreateStatement(Connection connection) throws SQLException {
return connection.prepareStatement(QUERY);
}

Expand All @@ -226,7 +238,7 @@ class GetSpoolMessageById extends CachedStatement<PreparedStatement, ResultSet>
+ "FROM spooler WHERE message_id = ?;";

@Override
protected PreparedStatement createStatement(Connection connection) throws SQLException {
protected PreparedStatement doCreateStatement(Connection connection) throws SQLException {
return connection.prepareStatement(QUERY);
}

Expand Down Expand Up @@ -276,7 +288,7 @@ class InsertSpoolMessage extends CachedStatement<PreparedStatement, Integer> {
+ "VALUES (?,?,?,?,?,?,?,?,?,?,?,?);";

@Override
protected PreparedStatement createStatement(Connection connection) throws SQLException {
protected PreparedStatement doCreateStatement(Connection connection) throws SQLException {
return connection.prepareStatement(QUERY);
}

Expand Down Expand Up @@ -340,7 +352,7 @@ class RemoveSpoolMessageById extends CachedStatement<PreparedStatement, Integer>
private static final String QUERY = "DELETE FROM spooler WHERE message_id = ?;";

@Override
protected PreparedStatement createStatement(Connection connection) throws SQLException {
protected PreparedStatement doCreateStatement(Connection connection) throws SQLException {
return connection.prepareStatement(QUERY);
}

Expand Down Expand Up @@ -374,7 +386,7 @@ class CreateSpoolerTable extends CachedStatement<PreparedStatement, Integer> {
+ ");";

@Override
protected PreparedStatement createStatement(Connection connection) throws SQLException {
protected PreparedStatement doCreateStatement(Connection connection) throws SQLException {
return connection.prepareStatement(QUERY);
}

Expand All @@ -384,10 +396,9 @@ protected Integer doExecute(PreparedStatement statement) throws SQLException {
}

@Override
public Void onNewConnection(Connection connection) throws SQLException {
super.onNewConnection(connection);
public void createStatement(Connection connection) throws SQLException {
super.createStatement(connection);
execute();
return null;
}
}

Expand All @@ -402,25 +413,15 @@ public Void onNewConnection(Connection connection) throws SQLException {
abstract class CachedStatement<T extends Statement, R> {
private T statement;

protected CachedStatement() {
onNewConnection.add(this::onNewConnection);
}

/**
* Callback that's executed whenever a database connection is created,
* during startup, or after corruption recovery, for example.
* Create a new statement.
*
* @param connection connection
* @return nothing
* @throws SQLException if unable to close old statement, or unable to create a new statement
* @throws SQLException if unable to create statement
*/
public Void onNewConnection(Connection connection) throws SQLException {
// TODO lock needed anymore?
try (LockScope ls = LockScope.lock(connectionLock.readLock())) {
close(); // clean up old resources
statement = createStatement(connection);
}
return null;
public void createStatement(Connection connection) throws SQLException {
close(); // clean up old resources
statement = doCreateStatement(connection);
}

/**
Expand All @@ -430,7 +431,7 @@ public Void onNewConnection(Connection connection) throws SQLException {
* @return statement
* @throws SQLException if unable to create statement
*/
protected abstract T createStatement(Connection connection) throws SQLException;
protected abstract T doCreateStatement(Connection connection) throws SQLException;

public void close() throws SQLException {
if (statement != null) {
Expand Down

0 comments on commit 575cec4

Please sign in to comment.