Skip to content

Commit

Permalink
[fix] Fix ArrayIndexOut0fBoundsException caused by optimistic lock (a…
Browse files Browse the repository at this point in the history
…pache#4066)

(cherry picked from commit e4faf25)
  • Loading branch information
thetumbled authored and dlg99 committed Apr 29, 2024
1 parent 0ec6cd4 commit 9763323
Show file tree
Hide file tree
Showing 12 changed files with 488 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,18 @@ private static final class Section<V> extends StampedLock {
}

V get(long key, int keyHash) {
int bucket = keyHash;

long stamp = tryOptimisticRead();
boolean acquiredLock = false;

// add local variable here, so OutOfBound won't happen
long[] keys = this.keys;
V[] values = this.values;
// calculate table.length as capacity to avoid rehash changing capacity
int bucket = signSafeMod(keyHash, values.length);

try {
while (true) {
int capacity = this.capacity;
bucket = signSafeMod(bucket, capacity);

// First try optimistic locking
long storedKey = keys[bucket];
Expand All @@ -354,16 +357,15 @@ V get(long key, int keyHash) {
if (!acquiredLock) {
stamp = readLock();
acquiredLock = true;

// update local variable
keys = this.keys;
values = this.values;
bucket = signSafeMod(keyHash, values.length);
storedKey = keys[bucket];
storedValue = values[bucket];
}

if (capacity != this.capacity) {
// There has been a rehashing. We need to restart the search
bucket = keyHash;
continue;
}

if (storedKey == key) {
return storedValue != DeletedValue ? storedValue : null;
} else if (storedValue == EmptyValue) {
Expand All @@ -372,7 +374,7 @@ V get(long key, int keyHash) {
}
}

++bucket;
bucket = (bucket + 1) & (values.length - 1);
}
} finally {
if (acquiredLock) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,11 @@ private static final class Section extends StampedLock {
boolean contains(long item, int hash) {
long stamp = tryOptimisticRead();
boolean acquiredLock = false;
int bucket = signSafeMod(hash, capacity);

// add local variable here, so OutOfBound won't happen
long[] table = this.table;
// calculate table.length as capacity to avoid rehash changing capacity
int bucket = signSafeMod(hash, table.length);

try {
while (true) {
Expand All @@ -311,7 +315,9 @@ boolean contains(long item, int hash) {
stamp = readLock();
acquiredLock = true;

bucket = signSafeMod(hash, capacity);
// update local variable
table = this.table;
bucket = signSafeMod(hash, table.length);
storedItem = table[bucket];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,9 @@ public Map<Long, Long> asMap() {
// A section is a portion of the hash map that is covered by a single
@SuppressWarnings("serial")
private static final class Section extends StampedLock {
// Each item take up 2 continuous array space.
private static final int ITEM_SIZE = 2;

// Keys and values are stored interleaved in the table array
private volatile long[] table;

Expand All @@ -389,7 +392,7 @@ private static final class Section extends StampedLock {
float expandFactor, float shrinkFactor) {
this.capacity = alignToPowerOfTwo(capacity);
this.initCapacity = this.capacity;
this.table = new long[2 * this.capacity];
this.table = new long[ITEM_SIZE * this.capacity];
this.size = 0;
this.usedBuckets = 0;
this.autoShrink = autoShrink;
Expand All @@ -405,7 +408,10 @@ private static final class Section extends StampedLock {
long get(long key, int keyHash) {
long stamp = tryOptimisticRead();
boolean acquiredLock = false;
int bucket = signSafeMod(keyHash, capacity);
// add local variable here, so OutOfBound won't happen
long[] table = this.table;
// calculate table.length/2 as capacity to avoid rehash changing capacity
int bucket = signSafeMod(keyHash, table.length / ITEM_SIZE);

try {
while (true) {
Expand All @@ -427,7 +433,9 @@ long get(long key, int keyHash) {
stamp = readLock();
acquiredLock = true;

bucket = signSafeMod(keyHash, capacity);
// update local variable
table = this.table;
bucket = signSafeMod(keyHash, table.length / ITEM_SIZE);
storedKey = table[bucket];
storedValue = table[bucket + 1];
}
Expand All @@ -440,7 +448,7 @@ long get(long key, int keyHash) {
}
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}
} finally {
if (acquiredLock) {
Expand Down Expand Up @@ -493,7 +501,7 @@ long put(long key, long value, int keyHash, boolean onlyIfAbsent, LongLongFuncti
}
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThresholdUp) {
Expand Down Expand Up @@ -551,7 +559,7 @@ long addAndGet(long key, long delta, int keyHash) {
}
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThresholdUp) {
Expand Down Expand Up @@ -611,7 +619,7 @@ boolean compareAndSet(long key, long currentValue, long newValue, int keyHash) {
}
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}
} finally {
if (usedBuckets > resizeThresholdUp) {
Expand Down Expand Up @@ -650,7 +658,7 @@ private long remove(long key, long value, int keyHash) {
return ValueNotFound;
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}

} finally {
Expand Down Expand Up @@ -681,7 +689,7 @@ int removeIf(LongPredicate filter) {
int removedCount = 0;
try {
// Go through all the buckets for this section
for (int bucket = 0; size > 0 && bucket < table.length; bucket += 2) {
for (int bucket = 0; size > 0 && bucket < table.length; bucket += ITEM_SIZE) {
long storedKey = table[bucket];

if (storedKey != DeletedKey && storedKey != EmptyKey) {
Expand Down Expand Up @@ -719,7 +727,7 @@ int removeIf(LongLongPredicate filter) {
int removedCount = 0;
try {
// Go through all the buckets for this section
for (int bucket = 0; size > 0 && bucket < table.length; bucket += 2) {
for (int bucket = 0; size > 0 && bucket < table.length; bucket += ITEM_SIZE) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];

Expand Down Expand Up @@ -753,20 +761,20 @@ int removeIf(LongLongPredicate filter) {
}

private void cleanBucket(int bucket) {
int nextInArray = (bucket + 2) & (table.length - 1);
int nextInArray = (bucket + ITEM_SIZE) & (table.length - 1);
if (table[nextInArray] == EmptyKey) {
table[bucket] = EmptyKey;
table[bucket + 1] = ValueNotFound;
--usedBuckets;

// Cleanup all the buckets that were in `DeletedKey` state, so that we can reduce unnecessary expansions
bucket = (bucket - 2) & (table.length - 1);
bucket = (bucket - ITEM_SIZE) & (table.length - 1);
while (table[bucket] == DeletedKey) {
table[bucket] = EmptyKey;
table[bucket + 1] = ValueNotFound;
--usedBuckets;

bucket = (bucket - 2) & (table.length - 1);
bucket = (bucket - ITEM_SIZE) & (table.length - 1);
}
} else {
table[bucket] = DeletedKey;
Expand Down Expand Up @@ -807,7 +815,7 @@ public void forEach(BiConsumerLong processor) {
}

// Go through all the buckets for this section
for (int bucket = 0; bucket < table.length; bucket += 2) {
for (int bucket = 0; bucket < table.length; bucket += ITEM_SIZE) {
long storedKey = table[bucket];
long storedValue = table[bucket + 1];

Expand All @@ -833,11 +841,11 @@ public void forEach(BiConsumerLong processor) {

private void rehash(int newCapacity) {
// Expand the hashmap
long[] newTable = new long[2 * newCapacity];
long[] newTable = new long[ITEM_SIZE * newCapacity];
Arrays.fill(newTable, EmptyKey);

// Re-hash table
for (int i = 0; i < table.length; i += 2) {
for (int i = 0; i < table.length; i += ITEM_SIZE) {
long storedKey = table[i];
long storedValue = table[i + 1];
if (storedKey != EmptyKey && storedKey != DeletedKey) {
Expand All @@ -855,7 +863,7 @@ private void rehash(int newCapacity) {
}

private void shrinkToInitCapacity() {
long[] newTable = new long[2 * initCapacity];
long[] newTable = new long[ITEM_SIZE * initCapacity];
Arrays.fill(newTable, EmptyKey);

table = newTable;
Expand All @@ -881,7 +889,7 @@ private static void insertKeyValueNoLock(long[] table, int capacity, long key, l
return;
}

bucket = (bucket + 2) & (table.length - 1);
bucket = (bucket + ITEM_SIZE) & (table.length - 1);
}
}
}
Expand All @@ -897,6 +905,8 @@ static final long hash(long key) {
}

static final int signSafeMod(long n, int max) {
// as the ITEM_SIZE of Section is 2, so the index is the multiple of 2
// that is to left shift 1 bit
return (int) (n & (max - 1)) << 1;
}

Expand Down
Loading

0 comments on commit 9763323

Please sign in to comment.