diff --git a/src/main/java/com/jd/jdbc/engine/table/TableInsertEngine.java b/src/main/java/com/jd/jdbc/engine/table/TableInsertEngine.java index ad675eb..ee92c8f 100644 --- a/src/main/java/com/jd/jdbc/engine/table/TableInsertEngine.java +++ b/src/main/java/com/jd/jdbc/engine/table/TableInsertEngine.java @@ -108,8 +108,6 @@ public class TableInsertEngine implements PrimitiveEngine, TableShardQuery { */ private Boolean multiShardAutocommit; - private long insertId; - private InsertEngine insertEngine; public TableInsertEngine(final Engine.InsertOpcode insertOpcode, final VKeyspace keyspace, final LogicTable table) { @@ -138,14 +136,17 @@ public IExecute.ExecuteMultiShardResponse execute(final IContext ctx, final Vcur if (RoleUtils.notMaster(ctx)) { throw new SQLException("insert is not allowed for read only connection"); } + long insertId; PrimitiveEngine primitiveEngine; switch (this.insertOpcode) { case InsertByDestination: case InsertUnsharded: + insertId = Generate.processGenerate(vcursor, generate, bindVariableMap); primitiveEngine = getInsertUnshardedEngine(ctx, vcursor, bindVariableMap); break; case InsertSharded: case InsertShardedIgnore: + insertId = Generate.processGenerate(vcursor, generate, bindVariableMap); primitiveEngine = getInsertShardedEngine(ctx, vcursor, bindVariableMap); break; default: @@ -163,7 +164,6 @@ public IExecute.ExecuteMultiShardResponse execute(final IContext ctx, final Vcur } private PrimitiveEngine getInsertUnshardedEngine(final IContext ctx, final Vcursor vcursor, final Map bindVariableMap) throws SQLException { - insertId = Generate.processGenerate(vcursor, generate, bindVariableMap); List actualTables = new ArrayList<>(); List> indexesPerTable = new ArrayList<>(); buildActualTables(bindVariableMap, actualTables, indexesPerTable); @@ -173,7 +173,6 @@ private PrimitiveEngine getInsertUnshardedEngine(final IContext ctx, final Vcurs } private PrimitiveEngine getInsertShardedEngine(final IContext ctx, final Vcursor vcursor, final Map bindVariableMap) throws SQLException { - insertId = Generate.processGenerate(vcursor, generate, bindVariableMap); List actualTables = new ArrayList<>(); List> indexesPerTable = new ArrayList<>(); buildActualTables(bindVariableMap, actualTables, indexesPerTable); diff --git a/src/test/java/com/jd/jdbc/table/engine/InsertTest.java b/src/test/java/com/jd/jdbc/table/engine/InsertTest.java index ec92291..1ba7db4 100644 --- a/src/test/java/com/jd/jdbc/table/engine/InsertTest.java +++ b/src/test/java/com/jd/jdbc/table/engine/InsertTest.java @@ -17,6 +17,8 @@ package com.jd.jdbc.table.engine; import com.jd.jdbc.table.TableTestUtil; +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; @@ -30,6 +32,8 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -43,6 +47,7 @@ import testsuite.internal.testcase.TestSuiteCase; public class InsertTest extends TestSuite { + protected Connection conn; protected List testCaseList; @@ -56,6 +61,14 @@ protected String getUrl() { return getConnectionUrl(Driver.of(TestSuiteShardSpec.TWO_SHARDS)) + "&useAffectedRows=false"; } + protected String getUser() { + return getUser(Driver.of(TestSuiteShardSpec.TWO_SHARDS)); + } + + protected String getPassword() { + return getPassword(Driver.of(TestSuiteShardSpec.TWO_SHARDS)); + } + private void initCase() throws IOException, SQLException { testCaseList = initCase("src/test/resources/engine/tableengine/insert_case.json", TestCase.class, conn.getCatalog()); testCaseList.addAll(initCase("src/test/resources/engine/tableengine/insert_case_upperCase.json", TestCase.class, conn.getCatalog())); @@ -102,6 +115,84 @@ public void testSameKeySequence() throws Exception { insert(); } + @Test + @Ignore + public void testConcurrentSequence() throws Exception { + TableTestUtil.setSplitTableConfig("engine/tableengine/split-table-seq.yml"); + ExecutorService executorService = getThreadPool(10, 10); + + try (Statement stmt = conn.createStatement()) { + stmt.execute("delete from table_engine_test"); + } + + HikariConfig config = new HikariConfig(); + config.setDriverClassName("com.jd.jdbc.vitess.VitessDriver"); + config.setJdbcUrl(getUrl()); + config.setMinimumIdle(10); + config.setMaximumPoolSize(10); + config.setUsername(getUser()); + config.setPassword(getPassword()); + + HikariDataSource hikariDataSource = new HikariDataSource(config); + + AtomicBoolean atomicBoolean = new AtomicBoolean(true); + for (int i = 0; i < 20; i++) { + int finalI = i; + executorService.execute(() -> { + try (Connection connection = hikariDataSource.getConnection()) { + // insert split table sequence return generatedKey + PreparedStatement prepareStatement = connection.prepareStatement("insert into table_engine_test(f_key) values (?)", Statement.RETURN_GENERATED_KEYS); + prepareStatement.setInt(1, finalI); + Assert.assertFalse(prepareStatement.execute()); + Assert.assertEquals(1, prepareStatement.getUpdateCount()); + + // check getGeneratedKeys + ResultSet generatedKeys = prepareStatement.getGeneratedKeys(); + Assert.assertTrue(generatedKeys.next()); + long id = generatedKeys.getLong(1); + Assert.assertFalse(generatedKeys.next()); + + // check last_insert_id + final ResultSet lastInsertIdResultSet = connection.createStatement().executeQuery("select last_insert_id()"); + Assert.assertTrue(lastInsertIdResultSet.next()); + Assert.assertEquals(id, lastInsertIdResultSet.getLong(1)); + Assert.assertFalse(lastInsertIdResultSet.next()); + + // check last_insert_id + final ResultSet identityResultSet = connection.createStatement().executeQuery("select @@identity"); + Assert.assertTrue(identityResultSet.next()); + Assert.assertEquals(id, identityResultSet.getLong(1)); + Assert.assertFalse(identityResultSet.next()); + + // select by id + PreparedStatement selectPrepareStatement = connection.prepareStatement("select f_key,id from table_engine_test where f_key = ? and id = ?"); + selectPrepareStatement.setInt(1, finalI); + selectPrepareStatement.setLong(2, id); + Assert.assertTrue(selectPrepareStatement.execute()); + ResultSet resultSet = selectPrepareStatement.getResultSet(); + Assert.assertTrue(resultSet.next()); + Assert.assertEquals(finalI, resultSet.getInt(1)); + Assert.assertEquals(id, resultSet.getLong(2)); + } catch (Exception e) { + e.printStackTrace(); + atomicBoolean.set(false); + } + } + ); + } + final long start = System.currentTimeMillis(); + while (true) { + if (!atomicBoolean.get()) { + Assert.fail(); + break; + } + if (System.currentTimeMillis() - start > 15 * 1000) { + break; + } + } + executorService.shutdownNow(); + } + protected void insert() throws SQLException, NoSuchFieldException, IllegalAccessException { insert(false, true); } diff --git a/src/test/java/com/jd/jdbc/table/engine/unshard/InsertUnShardTest.java b/src/test/java/com/jd/jdbc/table/engine/unshard/InsertUnShardTest.java index df9e025..a7c5410 100644 --- a/src/test/java/com/jd/jdbc/table/engine/unshard/InsertUnShardTest.java +++ b/src/test/java/com/jd/jdbc/table/engine/unshard/InsertUnShardTest.java @@ -27,6 +27,16 @@ protected String getUrl() { return getConnectionUrl(Driver.of(TestSuiteShardSpec.NO_SHARDS)) + "&useAffectedRows=false"; } + @Override + protected String getUser() { + return getUser(Driver.of(TestSuiteShardSpec.NO_SHARDS)); + } + + @Override + protected String getPassword() { + return getPassword(Driver.of(TestSuiteShardSpec.NO_SHARDS)); + } + @Override protected void insert() throws SQLException, NoSuchFieldException, IllegalAccessException { insert(false, false); diff --git a/src/test/java/testsuite/TestSuite.java b/src/test/java/testsuite/TestSuite.java index def0ec8..c218c10 100644 --- a/src/test/java/testsuite/TestSuite.java +++ b/src/test/java/testsuite/TestSuite.java @@ -82,7 +82,7 @@ public static void closeConnection(Connection conn) { protected static ExecutorService getThreadPool(int num, int max) { ExecutorService pool = new ThreadPoolExecutor(num, max, - 0L, TimeUnit.MILLISECONDS, + 60, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), new ThreadFactory() { private final AtomicInteger threadNumber = new AtomicInteger(1);