diff --git a/src/main/java/com/jd/jdbc/vitess/VitessStatement.java b/src/main/java/com/jd/jdbc/vitess/VitessStatement.java index 8ccb0c2..5859435 100755 --- a/src/main/java/com/jd/jdbc/vitess/VitessStatement.java +++ b/src/main/java/com/jd/jdbc/vitess/VitessStatement.java @@ -99,6 +99,8 @@ public class VitessStatement extends AbstractVitessStatement { private static final String FUNCATION_IDENTITY = "@@IDENTITY"; + private static Query.Field[] generatedKeyField; + @Getter protected final Executor executor; @@ -582,10 +584,8 @@ protected VtResultSet getGeneratedKeysInternal(long numKeys) throws SQLException throw new SQLException("does not support insert multiple rows in one sql statement"); } long generatedKey = lastInsertId; - VtResultSet vtStaticResultSet = null; - Query.Field field = Query.Field.newBuilder().setName("GENERATED_KEY").setJdbcClassName("java.math.BigInteger").setType(Query.Type.UINT64).setColumnLength(20).setPrecision(20).build(); List> rows = new ArrayList<>(); - + VtResultSet vtStaticResultSet = new VtResultSet(getGeneratedKeyField(), rows); if (!this.resultSets.isEmpty()) { if (generatedKey < 0) { throw new SQLException("generatedKey error"); @@ -593,7 +593,6 @@ protected VtResultSet getGeneratedKeysInternal(long numKeys) throws SQLException if (generatedKey != 0 && (numKeys > 0)) { List row = Collections.singletonList(VtResultValue.newVtResultValue(Query.Type.UINT64, BigInteger.valueOf(generatedKey))); rows.add(row); - vtStaticResultSet = new VtResultSet(new Query.Field[] {field}, rows); } } return vtStaticResultSet; @@ -1145,6 +1144,20 @@ private void errorCount(String sql, Map bindVariableMap, S StatementCollector.getStatementErrorCounter().labels(connection.getDefaultKeyspace(), VitessJdbcProperyUtil.getRole(connection.getProperties())).inc(); } + private Query.Field[] getGeneratedKeyField() { + if (generatedKeyField == null) { + synchronized (VitessStatement.class) { + if (generatedKeyField == null) { + Query.Field field = Query.Field.newBuilder().setName("GENERATED_KEY").setJdbcClassName("java.math.BigInteger") + .setType(Query.Type.UINT64).setColumnLength(20).setPrecision(20).build(); + generatedKeyField = new Query.Field[] {field}; + return generatedKeyField; + } + } + } + return generatedKeyField; + } + @Getter @AllArgsConstructor public static class ParseResult { diff --git a/src/test/java/com/jd/jdbc/table/TableAutoGeneratedKeysTest.java b/src/test/java/com/jd/jdbc/table/TableAutoGeneratedKeysTest.java index 917a3e7..605a035 100644 --- a/src/test/java/com/jd/jdbc/table/TableAutoGeneratedKeysTest.java +++ b/src/test/java/com/jd/jdbc/table/TableAutoGeneratedKeysTest.java @@ -25,20 +25,17 @@ import org.junit.After; import org.junit.Assert; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import testsuite.internal.TestSuiteShardSpec; public class TableAutoGeneratedKeysTest extends VitessAutoGeneratedKeysTest { - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before - public void init() throws Exception { + @Override + public void testLoadDriver() throws Exception { getConn(); TableTestUtil.setSplitTableConfig("table/autoGeneratedKeys.yml"); + clean(); sql1 = "insert into table_auto (id,ai,email) values(1,1,'x')"; sql100 = "insert into table_auto (id,ai,email) values(1,100,'x')"; @@ -46,6 +43,9 @@ public void init() throws Exception { sqld = "insert into table_auto (id,ai,email) values(%d,%d,'x')"; updateSql = "update table_auto set email = 'zz' where id = %d"; deleteSql = "delete from table_auto where id = %d"; + updateSql100 = "update table_auto set email = 'zz' where id = 1"; + sql200 = "insert into table_auto (id,ai,email) values(200,200,'x')"; + updatesql200 = "update table_auto set email = 'zz' where id = ?"; } protected void getConn() throws SQLException { @@ -55,9 +55,7 @@ protected void getConn() throws SQLException { @After public void close() throws Exception { - if (conn != null) { - conn.close(); - } + closeConnection(conn); TableTestUtil.setDefaultTableConfig(); } diff --git a/src/test/java/com/jd/jdbc/table/TableRewriteBatchedStatementsTest.java b/src/test/java/com/jd/jdbc/table/TableRewriteBatchedStatementsTest.java index 193eba1..ef56da8 100644 --- a/src/test/java/com/jd/jdbc/table/TableRewriteBatchedStatementsTest.java +++ b/src/test/java/com/jd/jdbc/table/TableRewriteBatchedStatementsTest.java @@ -45,6 +45,14 @@ public void test12StatementExecuteBatch() throws Exception { super.test12StatementExecuteBatch(); } + @Test + @Override + public void test12StatementUpdateExecuteBatch() throws Exception { + thrown.expect(java.sql.SQLFeatureNotSupportedException.class); + thrown.expectMessage("unsupported multiquery"); + super.test12StatementUpdateExecuteBatch(); + } + @Test @Override public void test20PreparedStatementExecuteBatch() throws Exception { @@ -70,8 +78,9 @@ public void test22PreparedStatementExecuteBatch() throws Exception { } @Test - @Override - public void test23SetNull() throws Exception { - super.test23SetNull(); + public void test22PreparedStatementUpdateExecuteBatch() throws Exception { + thrown.expect(java.sql.SQLFeatureNotSupportedException.class); + thrown.expectMessage("unsupported multiquery"); + super.test22PreparedStatementUpdateExecuteBatch(); } } diff --git a/src/test/java/com/jd/jdbc/vitess/VitessAutoGeneratedKeysTest.java b/src/test/java/com/jd/jdbc/vitess/VitessAutoGeneratedKeysTest.java index 60a597f..0d2ff40 100755 --- a/src/test/java/com/jd/jdbc/vitess/VitessAutoGeneratedKeysTest.java +++ b/src/test/java/com/jd/jdbc/vitess/VitessAutoGeneratedKeysTest.java @@ -19,19 +19,19 @@ import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Types; import java.util.Arrays; import org.junit.After; import org.junit.Before; -import org.junit.FixMethodOrder; +import org.junit.Rule; import org.junit.Test; -import org.junit.runners.MethodSorters; +import org.junit.rules.ExpectedException; import testsuite.TestSuite; import testsuite.internal.TestSuiteShardSpec; -import static java.sql.Statement.NO_GENERATED_KEYS; -import static java.sql.Statement.RETURN_GENERATED_KEYS; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -54,39 +54,47 @@ PRIMARY KEY (`ai`) * * only auto-increment filed return auto-generated-keys * */ - -@FixMethodOrder(MethodSorters.NAME_ASCENDING) public class VitessAutoGeneratedKeysTest extends TestSuite { private static final int[] colIndex = new int[] {1}; private static final String[] colString = new String[] {"no matter what you input, it would be ignored."}; - private static final int[] int1x100Array = new int[100]; + private static final int[] int1x10Array = new int[10]; - private static final int[] int1x300Array = new int[300]; + private static final int[] int1x30Array = new int[30]; static { - Arrays.fill(int1x100Array, 1); - Arrays.fill(int1x300Array, 1); + Arrays.fill(int1x10Array, 1); + Arrays.fill(int1x30Array, 1); } - public String sql1 = "insert into auto (id,ai,email) values(1,1,'x')"; + @Rule + public ExpectedException thrown = ExpectedException.none(); + + protected String sql1 = "insert into auto (id,ai,email) values(1,1,'x')"; + + protected String sql100 = "insert into auto (id,ai,email) values(1,100,'x')"; + + protected String sqlx = "insert into auto (id,ai,email) values(?,?,'x')"; - public String sql100 = "insert into auto (id,ai,email) values(1,100,'x')"; + protected String sqld = "insert into auto (id,ai,email) values(%d,%d,'x')"; - public String sqlx = "insert into auto (id,ai,email) values(?,?,'x')"; + protected String updateSql = "update auto set email = 'zz' where id = %d"; - public String sqld = "insert into auto (id,ai,email) values(%d,%d,'x')"; + protected String updateSql100 = "update auto set email = 'zz' where id = 1"; - public String updateSql = "update auto set email = 'zz' where id = %d"; + protected String deleteSql = "delete from auto where id = %d"; - public String deleteSql = "delete from auto where id = %d"; + protected String sql200 = "insert into auto (id,ai,email) values(200,200,'x')"; + + protected String updatesql200 = "update auto set email = 'zz' where id = ?"; protected Connection conn; @Before public void testLoadDriver() throws Exception { getConn(); + clean(); } protected void getConn() throws SQLException { @@ -95,9 +103,7 @@ protected void getConn() throws SQLException { @After public void close() throws Exception { - if (conn != null) { - conn.close(); - } + closeConnection(conn); } protected void clean() throws SQLException { @@ -111,225 +117,174 @@ protected void clean() throws SQLException { * ***********************************************************************************************/ @Test public void test01StatementExecute() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); try (Statement stmt = conn.createStatement()) { boolean rc = stmt.execute(sql1); assertFalse(rc); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test02StatementExecute() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); try (Statement stmt = conn.createStatement()) { - boolean rc = stmt.execute(sql1, NO_GENERATED_KEYS); + boolean rc = stmt.execute(sql1, Statement.NO_GENERATED_KEYS); assertFalse(rc); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } - @Test public void test03StatementExecute() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { - boolean rc = stmt.execute(sql100, RETURN_GENERATED_KEYS); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + boolean rc = stmt.execute(sql100, Statement.RETURN_GENERATED_KEYS); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test03UpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql100); + boolean rc = stmt.execute(updateSql100, Statement.RETURN_GENERATED_KEYS); + checkUpdateGenerate(stmt, rc); } } @Test public void test04StatementExecute() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { boolean rc = stmt.execute(sql100, colIndex); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test04UpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql100); + boolean rc = stmt.execute(updateSql100, colIndex); + checkUpdateGenerate(stmt, rc); } } @Test public void test05StatementExecute() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { boolean rc = stmt.execute(sql100, colString); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test05UpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql100); + boolean rc = stmt.execute(updateSql100, colString); + checkUpdateGenerate(stmt, rc); } } @Test public void test06StatementExecuteUpdate() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); try (Statement stmt = conn.createStatement()) { int updateCount = stmt.executeUpdate(sql1); assertEquals(1, updateCount); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test07StatementExecuteUpdate() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); try (Statement stmt = conn.createStatement()) { - int updateCount = stmt.executeUpdate(sql1, NO_GENERATED_KEYS); + int updateCount = stmt.executeUpdate(sql1, Statement.NO_GENERATED_KEYS); assertEquals(1, updateCount); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test08StatementExecuteUpdate() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { - int updateCount = stmt.executeUpdate(sql100, RETURN_GENERATED_KEYS); - assertEquals(1, updateCount); - assertEquals(1, stmt.getUpdateCount()); - - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + int updateCount = stmt.executeUpdate(sql100, Statement.RETURN_GENERATED_KEYS); + checkStatementExecuteUpdate(stmt, updateCount, 100); } } - @Test public void test09StatementExecuteUpdate() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { int updateCount = stmt.executeUpdate(sql100, colIndex); - assertEquals(1, updateCount); - assertEquals(1, stmt.getUpdateCount()); - - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + checkStatementExecuteUpdate(stmt, updateCount, 100); } } @Test public void test10StatementExecuteUpdate() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { int updateCount = stmt.executeUpdate(sql100, colString); - assertEquals(1, updateCount); - assertEquals(1, stmt.getUpdateCount()); - - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + checkStatementExecuteUpdate(stmt, updateCount, 100); } } @Test public void test11StatementExecuteBatch() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { - for (int i = 100; i < 200; i++) { + for (int i = 10; i < 20; i++) { stmt.addBatch(String.format(sqld, i, i)); } int[] rc = stmt.executeBatch(); - assertArrayEquals(int1x100Array, rc); + assertArrayEquals(int1x10Array, rc); - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(count + 100, generatedKeys.getLong(1)); - count++; - } - assertEquals(100, count); + checkBatch(stmt); } } @Test public void test12StatementExecuteBatch() throws Exception { - clean(); try (Statement stmt = conn.createStatement()) { - for (int i = 100; i < 200; i++) { + for (int i = 10; i < 20; i++) { stmt.addBatch(String.format(sqld, i, i)); stmt.addBatch(String.format(updateSql, i)); stmt.addBatch(String.format(deleteSql, i)); } int[] rc = stmt.executeBatch(); - assertArrayEquals(int1x300Array, rc); + assertArrayEquals(int1x30Array, rc); - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(count + 100, generatedKeys.getLong(1)); - count++; + checkBatch(stmt); + } + } + + @Test + public void test12StatementUpdateExecuteBatch() throws Exception { + try (Statement stmt = conn.createStatement()) { + for (int i = 10; i < 20; i++) { + stmt.addBatch(String.format(sqld, i, i)); } - assertEquals(100, count); + int[] rc = stmt.executeBatch(); + assertArrayEquals(int1x10Array, rc); + } + + try (Statement stmt = conn.createStatement()) { + for (int i = 10; i < 20; i++) { + stmt.addBatch(String.format(updateSql, i)); + } + int[] rc = stmt.executeBatch(); + assertArrayEquals(int1x10Array, rc); + ResultSet generatedKeys = stmt.getGeneratedKeys(); + assertFalse(generatedKeys.next()); + checkGeneratedKeysMetaData(generatedKeys); } } @@ -338,7 +293,8 @@ public void test12StatementExecuteBatch() throws Exception { * ***********************************************************************************************/ @Test public void test13PreparedStatementExecute() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); String sql = sqlx; try (PreparedStatement stmt = conn.prepareStatement(sql)) { stmt.setInt(1, 1); @@ -347,110 +303,107 @@ public void test13PreparedStatementExecute() throws Exception { boolean rc = stmt.execute(); assertFalse(rc); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test14PreparedStatementExecute() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, NO_GENERATED_KEYS)) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS)) { stmt.setInt(1, 1); stmt.setInt(2, 1); boolean rc = stmt.execute(); assertFalse(rc); assertEquals(1, stmt.getUpdateCount()); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test15PreparedStatementExecute() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, RETURN_GENERATED_KEYS)) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { stmt.setInt(1, 100); stmt.setInt(2, 100); boolean rc = stmt.execute(); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test15PreparedUpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql200); + } + String sql = updatesql200; + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { + stmt.setInt(1, 200); + boolean rc = stmt.execute(); + checkUpdateGenerate(stmt, rc); } } @Test public void test16PreparedStatementExecute() throws Exception { - clean(); String sql = sqlx; try (PreparedStatement stmt = conn.prepareStatement(sql, colIndex)) { stmt.setInt(1, 100); stmt.setInt(2, 100); boolean rc = stmt.execute(); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test16PreparedUpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql200); + } + String sql = updatesql200; + try (PreparedStatement stmt = conn.prepareStatement(sql, colIndex)) { + stmt.setInt(1, 200); + boolean rc = stmt.execute(); + checkUpdateGenerate(stmt, rc); } } @Test public void test17PreparedStatementExecute() throws Exception { - clean(); String sql = sqlx; try (PreparedStatement stmt = conn.prepareStatement(sql, colString)) { stmt.setInt(1, 100); stmt.setInt(2, 100); boolean rc = stmt.execute(); - assertFalse(rc); - assertEquals(1, stmt.getUpdateCount()); + checkStatementExecute(stmt, rc); + } + } - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + @Test + public void test17PreparedUpdateStatementExecute() throws Exception { + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql200); + } + String sql = updatesql200; + try (PreparedStatement stmt = conn.prepareStatement(sql, colString)) { + stmt.setInt(1, 200); + boolean rc = stmt.execute(); + checkUpdateGenerate(stmt, rc); } } @Test public void test18PreparedStatementExecuteUpdate() throws Exception { - clean(); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, NO_GENERATED_KEYS)) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS)) { stmt.setInt(1, 100); stmt.setInt(2, 100); @@ -459,105 +412,90 @@ public void test18PreparedStatementExecuteUpdate() throws Exception { assertEquals(1, updateCount); assertEquals(1, stmt.getUpdateCount()); - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + stmt.getGeneratedKeys(); } } @Test public void test19PreparedStatementExecuteUpdate() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, RETURN_GENERATED_KEYS)) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { stmt.setInt(1, 100); stmt.setInt(2, 100); - int updateCount = stmt.executeUpdate(); - - assertEquals(1, updateCount); - assertEquals(1, stmt.getUpdateCount()); - - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(100, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + checkStatementExecuteUpdate(stmt, updateCount, 100); } } @Test public void test20PreparedStatementExecuteBatch() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, NO_GENERATED_KEYS)) { - for (int i = 100; i < 200; i++) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS)) { + for (int i = 10; i < 20; i++) { stmt.setInt(1, i); stmt.setInt(2, i); stmt.addBatch(); } int[] rc = stmt.executeBatch(); - assertArrayEquals(int1x100Array, rc); - - String exceptionMsg = null; - try { - stmt.getGeneratedKeys(); - } catch (SQLException e) { - exceptionMsg = e.getMessage(); - } - assertTrue(exceptionMsg.contains("Generated keys not requested")); + assertArrayEquals(int1x10Array, rc); + thrown.expect(SQLException.class); + thrown.expectMessage("Generated keys not requested"); + stmt.getGeneratedKeys(); } } @Test public void test21PreparedStatementExecuteBatch() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, RETURN_GENERATED_KEYS)) { - for (int i = 100; i < 200; i++) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { + for (int i = 10; i < 20; i++) { stmt.setInt(1, i); stmt.setInt(2, i); stmt.addBatch(); } int[] rc = stmt.executeBatch(); - assertArrayEquals(int1x100Array, rc); + assertArrayEquals(int1x10Array, rc); - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(count + 100, generatedKeys.getLong(1)); - count++; - } - assertEquals(100, count); + checkBatch(stmt); } } @Test public void test22PreparedStatementExecuteBatch() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, RETURN_GENERATED_KEYS)) { - for (int i = 100; i < 200; i++) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { + for (int i = 10; i < 20; i++) { stmt.addBatch(String.format(sqld, i, i)); stmt.addBatch(String.format(updateSql, i)); stmt.addBatch(String.format(deleteSql, i)); } int[] rc = stmt.executeBatch(); - assertArrayEquals(int1x300Array, rc); + assertArrayEquals(int1x30Array, rc); - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(count + 100, generatedKeys.getLong(1)); - count++; + checkBatch(stmt); + } + } + + @Test + public void test22PreparedStatementUpdateExecuteBatch() throws Exception { + try (Statement stmt = conn.createStatement()) { + for (int i = 10; i < 20; i++) { + stmt.addBatch(String.format(sqld, i, i)); } - assertEquals(100, count); + int[] rc = stmt.executeBatch(); + assertArrayEquals(int1x10Array, rc); + } + + try (PreparedStatement stmt = conn.prepareStatement(updatesql200, Statement.RETURN_GENERATED_KEYS)) { + for (int i = 10; i < 20; i++) { + stmt.setInt(1, i); + stmt.addBatch(); + } + int[] rc = stmt.executeBatch(); + assertArrayEquals(int1x10Array, rc); + ResultSet generatedKeys = stmt.getGeneratedKeys(); + assertFalse(generatedKeys.next()); + checkGeneratedKeysMetaData(generatedKeys); } } @@ -566,24 +504,16 @@ public void test22PreparedStatementExecuteBatch() throws Exception { * ***********************************************************************************************/ @Test public void test23SetNull() throws Exception { - clean(); String sql = sqlx; - try (PreparedStatement stmt = conn.prepareStatement(sql, RETURN_GENERATED_KEYS)) { + try (PreparedStatement stmt = conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) { stmt.setInt(1, 100); stmt.setInt(2, 1000); int updateCount = stmt.executeUpdate(); - assertEquals(1, updateCount); - assertEquals(1, stmt.getUpdateCount()); - - ResultSet generatedKeys = stmt.getGeneratedKeys(); - int count = 0; - while (generatedKeys.next()) { - assertEquals(1000, generatedKeys.getLong(1)); - count++; - } - assertEquals(1, count); + checkStatementExecuteUpdate(stmt, updateCount, 1000); + ResultSet generatedKeys; + int count; //set null stmt.setInt(1, 100); @@ -603,4 +533,62 @@ public void test23SetNull() throws Exception { assertEquals(1, count); } } + + private void checkStatementExecute(Statement stmt, boolean rc) throws SQLException { + assertFalse(rc); + assertEquals(1, stmt.getUpdateCount()); + + ResultSet generatedKeys = stmt.getGeneratedKeys(); + int count = 0; + while (generatedKeys.next()) { + assertEquals(100, generatedKeys.getLong(1)); + count++; + } + assertEquals(1, count); + checkGeneratedKeysMetaData(generatedKeys); + } + + private void checkStatementExecuteUpdate(Statement stmt, int updateCount, int i) throws SQLException { + assertEquals(1, updateCount); + assertEquals(1, stmt.getUpdateCount()); + + ResultSet generatedKeys = stmt.getGeneratedKeys(); + int count = 0; + while (generatedKeys.next()) { + assertEquals(i, generatedKeys.getLong(1)); + count++; + } + assertEquals(1, count); + checkGeneratedKeysMetaData(generatedKeys); + } + + private void checkUpdateGenerate(Statement stmt, boolean rc) throws SQLException { + assertFalse(rc); + assertEquals(1, stmt.getUpdateCount()); + ResultSet generatedKeys = stmt.getGeneratedKeys(); + assertFalse(generatedKeys.next()); + checkGeneratedKeysMetaData(generatedKeys); + } + + private void checkGeneratedKeysMetaData(ResultSet generatedKeys) throws SQLException { + ResultSetMetaData rsmd = generatedKeys.getMetaData(); + assertEquals(1, rsmd.getColumnCount()); + assertEquals(Types.BIGINT, rsmd.getColumnType(1)); + assertEquals("GENERATED_KEY", rsmd.getColumnLabel(1)); + assertEquals("java.math.BigInteger", rsmd.getColumnClassName(1)); + assertEquals(20, rsmd.getColumnDisplaySize(1)); + assertEquals(20, rsmd.getPrecision(1)); + } + + private void checkBatch(Statement stmt) throws SQLException { + ResultSet generatedKeys = stmt.getGeneratedKeys(); + int count = 0; + while (generatedKeys.next()) { + assertEquals(count + 10, generatedKeys.getLong(1)); + count++; + } + assertEquals(10, count); + checkGeneratedKeysMetaData(generatedKeys); + } + }