Skip to content

Commit

Permalink
Closes #2472, #2800: Update IndexingMsg to use & instead of mod f…
Browse files Browse the repository at this point in the history
…or setting bigint pdarrays (#2793)

* convert from mod to %, except bool

* swtich back to the correct bitop, &, and optimizie using bit shift masking, and add testing

* Updating if block, add indexing to test, add testing to old implementation, add remove value binops forall loop

* making requested changes: fixing tests, moving variable definitions, and adding forall loops

* update tests

---------

Co-authored-by: jaketrookman <[email protected]>
  • Loading branch information
jaketrookman and jaketrookman committed Oct 6, 2023
1 parent e3ecc43 commit 84fbe90
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 12 deletions.
5 changes: 5 additions & 0 deletions PROTO_tests/tests/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,8 @@ def test_bigint_indexing_preserves_max_bits(self):
a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=max_bits)
assert max_bits == a[ak.arange(10)].max_bits
assert max_bits == a[:].max_bits

def test_handling_bigint_max_bits(self):
a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=3)
a[:] = ak.arange(2**200 - 1, 2**200 + 11)
assert [7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2] == a.to_list()
68 changes: 56 additions & 12 deletions src/IndexingMsg.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -608,31 +608,43 @@ module IndexingMsg
var e = toSymEntry(gEnt,bigint);
var val = valueArg.getBigIntValue();
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[idx] = val;
}
when (DType.BigInt, DType.Int64) {
var e = toSymEntry(gEnt,bigint);
var val = valueArg.getIntValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[idx] = val;
}
when (DType.BigInt, DType.UInt64) {
var e = toSymEntry(gEnt,bigint);
var val = valueArg.getUIntValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[idx] = val;
}
when (DType.BigInt, DType.Bool) {
var e = toSymEntry(gEnt,bigint);
var val = valueArg.getBoolValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[idx] = val;
}
Expand Down Expand Up @@ -1150,31 +1162,43 @@ module IndexingMsg
var e = toSymEntry(gEnt,bigint);
var val = value.getBigIntValue();
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[slice] = val;
}
when (DType.BigInt, DType.Int64) {
var e = toSymEntry(gEnt,bigint);
var val = value.getIntValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[slice] = val;
}
when (DType.BigInt, DType.UInt64) {
var e = toSymEntry(gEnt,bigint);
var val = value.getUIntValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[slice] = val;
}
when (DType.BigInt, DType.Bool) {
var e = toSymEntry(gEnt,bigint);
var val = value.getBoolValue():bigint;
if e.max_bits != -1 {
mod(val, val, e.max_bits);
var max_size = 1:bigint;
max_size <<= e.max_bits;
max_size -= 1;
val &= max_size;
}
e.a[slice] = val;
}
Expand Down Expand Up @@ -1309,7 +1333,12 @@ module IndexingMsg
var x = toSymEntry(gX,bigint);
var y = toSymEntry(gY,bigint);
if x.max_bits != -1 {
mod(y.a, y.a, x.max_bits);
var max_size = 1:bigint;
max_size <<= x.max_bits;
max_size -= 1;
forall y in y.a with (var local_max_size = max_size) {
y &= local_max_size;
}
}
x.a[slice] = y.a;
}
Expand All @@ -1318,7 +1347,12 @@ module IndexingMsg
var y = toSymEntry(gY,int);
var ya = y.a:bigint;
if x.max_bits != -1 {
mod(ya, ya, x.max_bits);
var max_size = 1:bigint;
max_size <<= x.max_bits;
max_size -= 1;
forall y in ya with (var local_max_size = max_size) {
y &= local_max_size;
}
}
x.a[slice] = ya;
}
Expand All @@ -1327,7 +1361,12 @@ module IndexingMsg
var y = toSymEntry(gY,uint);
var ya = y.a:bigint;
if x.max_bits != -1 {
mod(ya, ya, x.max_bits);
var max_size = 1:bigint;
max_size <<= x.max_bits;
max_size -= 1;
forall y in ya with (var local_max_size = max_size) {
y &= local_max_size;
}
}
x.a[slice] = ya;
}
Expand All @@ -1337,7 +1376,12 @@ module IndexingMsg
// TODO change once we can cast directly from bool to bigint
var ya = y.a:int:bigint;
if x.max_bits != -1 {
mod(ya, ya, x.max_bits);
var max_size = 1:bigint;
max_size <<= x.max_bits;
max_size -= 1;
forall y in ya with (var local_max_size = max_size) {
y &= local_max_size;
}
}
x.a[slice] = ya;
}
Expand Down
5 changes: 5 additions & 0 deletions tests/indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,8 @@ def test_bigint_indexing_preserves_max_bits(self):
a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=max_bits)
self.assertEqual(max_bits, a[ak.arange(10)].max_bits)
self.assertEqual(max_bits, a[:].max_bits)

def test_handling_bigint_max_bits(self):
a = ak.arange(2**200 - 1, 2**200 + 11, max_bits=3)
a[:] = ak.arange(2**200 - 1, 2**200 + 11)
self.assertListEqual([7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2], a.to_list())

0 comments on commit 84fbe90

Please sign in to comment.