diff --git a/test/unit/org/apache/cassandra/index/sai/plan/SingleRestrictionEstimatedRowCountTest.java b/test/unit/org/apache/cassandra/index/sai/plan/SingleRestrictionEstimatedRowCountTest.java index 757b9ae2939a..d29c62f19b15 100644 --- a/test/unit/org/apache/cassandra/index/sai/plan/SingleRestrictionEstimatedRowCountTest.java +++ b/test/unit/org/apache/cassandra/index/sai/plan/SingleRestrictionEstimatedRowCountTest.java @@ -20,9 +20,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.Map; -import org.junit.After; -import org.junit.Before; import org.junit.Test; import org.apache.cassandra.Util; @@ -44,7 +45,9 @@ public class SingleRestrictionEstimatedRowCountTest extends SAITester { - private int queryOptLevel; + static protected Map, ColumnFamilyStore> tables = new HashMap<>(); + static Version[] versions = new Version[]{ Version.DB, Version.EB }; + static CQL3Type.Native[] types = new CQL3Type.Native[]{ INT, DECIMAL, VARINT }; static protected Object getFilterValue(CQL3Type.Native type, int value) { @@ -61,69 +64,75 @@ static protected Object getFilterValue(CQL3Type.Native type, int value) return null; } - @Before - public void setup() + static Map.Entry tablesEntryKey(Version version, CQL3Type.Native type) { - queryOptLevel = QueryController.QUERY_OPT_LEVEL; - QueryController.QUERY_OPT_LEVEL = 0; - } - - @After - public void teardown() - { - QueryController.QUERY_OPT_LEVEL = queryOptLevel; + return new AbstractMap.SimpleEntry<>(version, type); } @Test - public void testInequality() + public void testMemtablesSAI() { - var test = new RowCountTest(Operator.NEQ, 25); + createTables(); + + RowCountTest test = new RowCountTest(Operator.NEQ, 25); test.doTest(Version.DB, INT, 97.0); test.doTest(Version.EB, INT, 97.0); // Truncated numeric types planned differently test.doTest(Version.DB, DECIMAL, 97.0); test.doTest(Version.EB, DECIMAL, 97.0); test.doTest(Version.EB, VARINT, 97.0); - } - @Test - public void testHalfRangeMiddle() - { - var test = new RowCountTest(Operator.LT, 50); + test = new RowCountTest(Operator.LT, 50); test.doTest(Version.DB, INT, 48); test.doTest(Version.EB, INT, 48); test.doTest(Version.DB, DECIMAL, 48); test.doTest(Version.EB, DECIMAL, 48); - } - @Test - public void testHalfRangeEverything() - { - var test = new RowCountTest(Operator.LT, 150); + test = new RowCountTest(Operator.LT, 150); test.doTest(Version.DB, INT, 97); test.doTest(Version.EB, INT, 97); test.doTest(Version.DB, DECIMAL, 97); test.doTest(Version.EB, DECIMAL, 97); - } - @Test - public void testEquality() - { - var test = new RowCountTest(Operator.EQ, 31); + test = new RowCountTest(Operator.EQ, 31); test.doTest(Version.DB, INT, 15); test.doTest(Version.EB, INT, 0); test.doTest(Version.DB, DECIMAL, 15); test.doTest(Version.EB, DECIMAL, 0); } - protected ColumnFamilyStore prepareTable(CQL3Type.Native type) + + void createTables() + { + for (Version version : versions) + { + SAIUtil.setLatestVersion(version); + for (CQL3Type.Native type : types) + { + createTable("CREATE TABLE %s (pk text PRIMARY KEY, age " + type + ')'); + createIndex("CREATE CUSTOM INDEX ON %s(age) USING 'StorageAttachedIndex'"); + tables.put(tablesEntryKey(version, type), getCurrentColumnFamilyStore()); + } + } + flush(); + for (ColumnFamilyStore cfs : tables.values()) + populateTable(cfs); + } + + void populateTable(ColumnFamilyStore cfs) { - createTable("CREATE TABLE %s (pk text PRIMARY KEY, age " + type + ')'); - createIndex("CREATE CUSTOM INDEX ON %s(age) USING 'StorageAttachedIndex'"); - return getCurrentColumnFamilyStore(); + // Avoid race condition of starting before flushing completed + cfs.unsafeRunWithoutFlushing(() -> { + for (int i = 0; i < 100; i++) + { + String query = String.format("INSERT INTO %s (pk, age) VALUES (?, " + i + ')', + cfs.keyspace.getName() + '.' + cfs.name); + executeFormattedQuery(query, "key" + i); + } + }); } - class RowCountTest + static class RowCountTest { final Operator op; final int filterValue; @@ -136,20 +145,8 @@ class RowCountTest void doTest(Version version, CQL3Type.Native type, double expectedRows) { - Version latest = Version.latest(); - SAIUtil.setLatestVersion(version); - - ColumnFamilyStore cfs = prepareTable(type); - // Avoid race condition of flushing after the index creation - cfs.unsafeRunWithoutFlushing(() -> { - for (int i = 0; i < 100; i++) - { - execute("INSERT INTO %s (pk, age) VALUES (?," + i + ')', "key" + i); - } - }); - + ColumnFamilyStore cfs = tables.get(new AbstractMap.SimpleEntry<>(version, type)); Object filter = getFilterValue(type, filterValue); - ReadCommand rc = Util.cmd(cfs) .columns("age") .filterOn("age", op, filter) @@ -159,6 +156,7 @@ void doTest(Version version, CQL3Type.Native type, double expectedRows) version.onDiskFormat().indexFeatureSet(), new QueryContext(), null); + long totalRows = controller.planFactory.tableMetrics.rows; assertEquals(0, cfs.metrics().liveSSTableCount.getValue().intValue()); assertEquals(97, totalRows); @@ -172,8 +170,6 @@ void doTest(Version version, CQL3Type.Native type, double expectedRows) assertEquals(expectedRows, root.expectedRows(), 0.1); assertEquals(expectedRows, planNode.expectedKeys(), 0.1); assertEquals(expectedRows / totalRows, planNode.selectivity(), 0.001); - - SAIUtil.setLatestVersion(latest); } } }