fix #925: ht.nonzero() returns tuple of 1-D arrays instead of n-D arrays #937

Merged
29 commits
f7adcf2
Create ci.yaml
mtar Feb 25, 2022
2ab82b5
Update ci.yaml
mtar Feb 25, 2022
f261e8e
Update ci.yaml
mtar Feb 25, 2022
9b863a7
Create CITATION.cff
mtar Mar 4, 2022
2b2622a
Update CITATION.cff
mtar Mar 7, 2022
a15b299
Update ci.yaml
mtar Mar 8, 2022
8910bf7
Update ci.yaml
mtar Mar 9, 2022
f8dc8b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 9, 2022
767eabc
Delete pre-commit.yml
mtar Mar 9, 2022
3cd1d33
Merge branch 'main' into enhancement/203-ghactions-matrix
mtar Mar 9, 2022
61cef7f
Update ci.yaml
mtar Mar 9, 2022
74b1a30
Update CITATION.cff
mtar Mar 11, 2022
93cd831
Update tutorial.ipynb
mtar Mar 11, 2022
2a25d22
Merge pull request #931 from helmholtz-analytics/doc/901-tutorial_update
coquelin77 Mar 14, 2022
e154ab9
Merge branch 'main' into docs/927-citation
coquelin77 Mar 14, 2022
7c57942
Merge pull request #929 from helmholtz-analytics/docs/927-citation
coquelin77 Mar 14, 2022
114e74e
Merge branch 'main' into enhancement/203-ghactions-matrix
coquelin77 Mar 14, 2022
14aae08
Merge pull request #924 from helmholtz-analytics/enhancement/203-ghac…
coquelin77 Mar 14, 2022
dd1b83d
Delete logo_heAT.pdf
Markus-Goetz Mar 15, 2022
7e6ad4a
ht.nonzero() returns tuple of 1-D arrays instead of n-D arrays
Mystic-Slice Mar 23, 2022
aeb5b6e
Updated documentation and Unit-tests
Mystic-Slice Mar 24, 2022
03e1287
Merge branch '914_adv-indexing-outshape-outsplit' into NonzeroFunction
ClaudiaComito Mar 25, 2022
420f064
replace x.larray with local_x
ClaudiaComito Mar 25, 2022
a00ed61
Code fixes
Mystic-Slice Mar 28, 2022
d4a8813
Fix return type of nonzero function and gout value
Mystic-Slice Mar 30, 2022
67fcdc8
Made sure DNDarray meta-data is available to the tuple members
Mystic-Slice Mar 30, 2022
39103fa
Transpose before if-branching + adjustments to accomodate it
Mystic-Slice Apr 1, 2022
3ed205c
Fixed global shape assignment
Mystic-Slice Apr 5, 2022
70dded6
Updated changelog
Mystic-Slice Apr 8, 2022
44 changes: 44 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,44 @@
+name: ci
+
+on:
+  pull_request_review:
+    types: [submitted]
+
+jobs:
+  approved:
+    if: github.event.review.state == 'approved'
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        py-version:
+          - 3.7
+          - 3.8
+        mpi: [ 'openmpi' ]
+        install-options: [ '.', '.[hdf5,netcdf]' ]
+        pytorch-version:
+          - 'torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2'
+          - 'torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1'
+          - 'torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0'
+
+
+    name: Python ${{ matrix.py-version }} with ${{ matrix.pytorch-version }}; options ${{ matrix.install-options }}
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+      - name: Setup MPI
+        uses: mpi4py/setup-mpi@v1
+        with:
+          mpi: ${{ matrix.mpi }}
+      - name: Use Python ${{ matrix.py-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.py-version }}
+          architecture: x64
+      - name: Test
+        run: |
+          pip install pytest
+          pip install ${{ matrix.pytorch-version }} -f https://download.pytorch.org/whl/torch_stable.html
+          pip install ${{ matrix.install-options }}
+          mpirun -n 3 pytest heat/
+          mpirun -n 4 pytest heat/
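Note on the job count: the `matrix` in this workflow should expand to one job per combination of its axes. A minimal Python sketch of that expansion, assuming the usual Cartesian-product behavior of a GitHub Actions matrix; the `matrix` and `jobs` variables are illustrative only and are not part of the workflow.

```python
from itertools import product

# Axes copied from the matrix in ci.yaml.
matrix = {
    "py-version": ["3.7", "3.8"],
    "mpi": ["openmpi"],
    "install-options": [".", ".[hdf5,netcdf]"],
    "pytorch-version": [
        "torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2",
        "torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1",
        "torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0",
    ],
}

# One job per combination of axis values.
jobs = [dict(zip(matrix, combo)) for combo in product(*matrix.values())]

print(len(jobs))  # 2 * 1 * 2 * 3 = 12
```

Each of those jobs runs the test suite twice, once with 3 and once with 4 MPI processes.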
14 changes: 0 additions & 14 deletions .github/workflows/pre-commit.yml

This file was deleted.

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@
 - [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN)
 - [#894](https://github.com/helmholtz-analytics/heat/pull/894) Change inclusion of license file
 - [#884](https://github.com/helmholtz-analytics/heat/pull/884) Added capabilities for PyTorch 1.10.0, this is now the recommended version to use.
+- [#937](https://github.com/helmholtz-analytics/heat/pull/937) Modified `ht.nonzero()` to return a tuple of 1-D arrays containing the non-zero indices in each dimension.

 ## Bug Fixes
 - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
68 changes: 68 additions & 0 deletions CITATION.cff
@@ -0,0 +1,68 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it as below."
+authors:
+  - family-names: "Götz"
+    given-names: "Markus"
+  - family-names: "Debus"
+    given-names: "Charlotte"
+  - family-names: "Coquelin"
+    given-names: "Daniel"
+  - family-names: "Krajsek"
+    given-names: "Kai"
+  - family-names: "Comito"
+    given-names: "Claudia"
+  - family-names: "Knechtges"
+    given-names: "Philipp"
+  - family-names: "Hagemeier"
+    given-names: "Björn"
+  - family-names: "Tarnawa"
+    given-names: "Michael"
+  - family-names: "Hanselmann"
+    given-names: "Simon"
+  - family-names: "Siggel"
+    given-names: "Martin"
+  - family-names: "Basermann"
+    given-names: "Achim"
+  - family-names: "Streit"
+    given-names: "Achim"
+title: "Heat - Helmholtz Analytics Toolkit"
+version: 1.1.0
+date-released: 2021-09-21
+url: "https://github.com/helmholtz-analytics/heat"
+preferred-citation:
+  type: conference-paper
+  authors:
+    - family-names: "Götz"
+      given-names: "Markus"
+    - family-names: "Debus"
+      given-names: "Charlotte"
+    - family-names: "Coquelin"
+      given-names: "Daniel"
+    - family-names: "Krajsek"
+      given-names: "Kai"
+    - family-names: "Comito"
+      given-names: "Claudia"
+    - family-names: "Knechtges"
+      given-names: "Philipp"
+    - family-names: "Hagemeier"
+      given-names: "Björn"
+    - family-names: "Tarnawa"
+      given-names: "Michael"
+    - family-names: "Hanselmann"
+      given-names: "Simon"
+    - family-names: "Siggel"
+      given-names: "Martin"
+    - family-names: "Basermann"
+      given-names: "Achim"
+    - family-names: "Streit"
+      given-names: "Achim"
+  title: "HeAT -- a Distributed and GPU-accelerated Tensor Framework for Data Analytics"
+  year: 2020
+  collection-title: "2020 IEEE International Conference on Big Data (IEEE Big Data 2020)"
+  collection-doi: 10.1109/BigData50022.2020.9378050
+  conference:
+    name: 2020 IEEE International Conference on Big Data (IEEE Big Data 2020)
+    date-start: 2020-12-10
+    date-end: 2020-12-13
+  start: 276
+  end: 287
Binary file removed doc/images/logo_heAT.pdf
Binary file not shown.
4 changes: 2 additions & 2 deletions heat/core/dndarray.py
@@ -872,7 +872,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDar
         output_split = None

         # data are not distributed or split dimension is not affected by indexing
-        if not self.is_distributed or key[self.split] == slice(None):
+        if not self.is_distributed() or key[self.split] == slice(None):
             return DNDarray(
                 self.larray[key],
                 gshape=output_shape,
@@ -1654,7 +1654,7 @@ def __set(arr: DNDarray, value: DNDarray):
             raise Exception("Advanced indexing is not supported yet")

         split = self.split
-        if not self.is_distributed or key[split] == slice(None):
+        if not self.is_distributed() or key[split] == slice(None):
             return __set(self[key], value)

         if isinstance(key[split], slice):
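Note on the two `is_distributed` changes in this file: without the parentheses the condition tests the bound method object rather than its result, and a bound method is always truthy, so `not self.is_distributed` never fires. A minimal sketch of that behavior; the `Example` class is a hypothetical stand-in, not the real `DNDarray`, and its notion of "distributed" is a simplifying assumption.

```python
class Example:
    """Hypothetical stand-in for DNDarray; only method-vs-call matters here."""

    def __init__(self, split):
        self.split = split

    def is_distributed(self):
        # Simplified assumption: "distributed" just means a split axis is set.
        return self.split is not None


local = Example(split=None)

# A bound method object is always truthy, so negating it is always False.
without_call = not local.is_distributed

# Calling the method gives the intended answer for a non-distributed array.
with_call = not local.is_distributed()

print(without_call, with_call)  # False True
```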
57 changes: 30 additions & 27 deletions heat/core/indexing.py
@@ -13,12 +13,12 @@
 __all__ = ["nonzero", "where"]


-def nonzero(x: DNDarray) -> DNDarray:
+def nonzero(x: DNDarray) -> Tuple[DNDarray, ...]:
     """
-    Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``)
-    If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
+    Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
+    containing the indices of the non-zero elements in that dimension. If ``x`` is split then
+    the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
     can be UNBALANCED as it contains the indices of the non-zero elements on each node.
-    Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension.
     The values in ``x`` are always tested and returned in row-major, C-style order.
     The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.

@@ -32,10 +32,8 @@ def nonzero(x: DNDarray) -> DNDarray:
     >>> import heat as ht
     >>> x = ht.array([[3, 0, 0], [0, 4, 1], [0, 6, 0]], split=0)
     >>> ht.nonzero(x)
-    DNDarray([[0, 0],
-              [1, 1],
-              [1, 2],
-              [2, 1]], dtype=ht.int64, device=cpu:0, split=0)
+    (DNDarray([0, 1, 1, 2], dtype=ht.int64, device=cpu:0, split=None),
+     DNDarray([0, 1, 2, 1], dtype=ht.int64, device=cpu:0, split=None))
     >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0)
     >>> y > 3
     DNDarray([[False, False, False],
@@ -48,6 +46,8 @@ def nonzero(x: DNDarray) -> DNDarray:
               [2, 0],
               [2, 1],
               [2, 2]], dtype=ht.int64, device=cpu:0, split=0)
+    (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=None),
+     DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=None))
     >>> y[ht.nonzero(y > 3)]
     DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0)
     """
@@ -56,39 +56,42 @@ def nonzero(x: DNDarray) -> DNDarray:
     except AttributeError:
         raise TypeError("Input must be a DNDarray, is {}".format(type(x)))

+    lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False).transpose(0, 1)

     if x.split is None:
-        # if there is no split then just return the values from torch
-        lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False)
+        # if there is no split then just return the transpose of values from torch

         gout = list(lcl_nonzero.size())
         is_split = None
     else:
-        # a is split
-        lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False)
         # adjust local indices along split dimension
         _, displs = x.counts_displs()
-        lcl_nonzero[..., x.split] += displs[x.comm.rank]
+        lcl_nonzero[x.split] += displs[x.comm.rank]
         del displs

         # get global size of split dimension
         gout = list(lcl_nonzero.size())
-        gout[0] = x.comm.allreduce(gout[0], MPI.SUM)
+        gout[1] = x.comm.allreduce(gout[1], MPI.SUM)
         is_split = 0

-    if x.ndim == 1:
-        lcl_nonzero = lcl_nonzero.squeeze(dim=1)
-    for g in range(len(gout) - 1, -1, -1):
-        if gout[g] == 1:
-            del gout[g]

-    return DNDarray(
-        lcl_nonzero,
-        gshape=tuple(gout),
-        dtype=types.canonical_heat_type(lcl_nonzero.dtype),
-        split=is_split,
-        device=x.device,
-        comm=x.comm,
-        balanced=False,
+    non_zero_indices = list(
+        [
+            DNDarray(
+                dim_indices,
+                gshape=tuple(gout),
+                dtype=types.canonical_heat_type(lcl_nonzero.dtype),
+                split=is_split,
+                device=x.device,
+                comm=x.comm,
+                balanced=False,
+            )
+            for dim_indices in lcl_nonzero
+        ]
     )

+    return tuple(non_zero_indices)


 DNDarray.nonzero = lambda self: nonzero(self)
 DNDarray.nonzero.__doc__ = nonzero.__doc__
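Note on the new return shape: the hunk transposes the `(N, ndim)` index array so that each row holds the indices for one dimension, then shifts only the row belonging to the split dimension by the process's displacement. A rough pure-Python sketch of that idea for a 2-D array split along rows. The names `local_nonzero` and `row_offset` are hypothetical, plain lists stand in for torch tensors, and the two "processes" are simulated by slicing.

```python
def local_nonzero(chunk, row_offset):
    """Return (row_indices, col_indices) of the non-zero entries of a 2-D chunk.

    `chunk` is the list of rows held by one process; `row_offset` is the global
    index of its first row (the role played by `displs[rank]` in the diff).
    """
    # Row-major scan yields (row, col) pairs, like an (N, 2) index array.
    pairs = [
        (i, j)
        for i, row in enumerate(chunk)
        for j, value in enumerate(row)
        if value != 0
    ]
    # "Transpose" the pairs into one sequence per dimension.
    rows = [i + row_offset for i, _ in pairs]  # shift only the split dimension
    cols = [j for _, j in pairs]
    return rows, cols


x = [[3, 0, 0], [0, 4, 1], [0, 6, 0]]

# Pretend two processes hold rows [0:2] and [2:3] (split=0).
parts = [local_nonzero(x[0:2], 0), local_nonzero(x[2:3], 2)]

# Concatenating the per-process pieces gives the global index tuple.
rows = [i for r, _ in parts for i in r]
cols = [j for _, c in parts for j in c]
print((rows, cols))  # ([0, 1, 1, 2], [0, 1, 2, 1])

# Pairing the two sequences recovers the non-zero values in row-major order.
values = [x[i][j] for i, j in zip(rows, cols)]
print(values)  # [3, 4, 1, 6]
```

The printed index tuple matches the values in the updated docstring example for the same input.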
14 changes: 7 additions & 7 deletions heat/core/tests/test_indexing.py
@@ -9,18 +9,18 @@ def test_nonzero(self):
         a = ht.array([[1, 2, 3], [4, 5, 2], [7, 8, 9]], split=None)
         cond = a > 3
         nz = ht.nonzero(cond)
-        self.assertEqual(nz.gshape, (5, 2))
-        self.assertEqual(nz.dtype, ht.int64)
-        self.assertEqual(nz.split, None)
+        self.assertEqual(len(nz), 2)
+        self.assertEqual(len(nz[0]), 5)
+        self.assertEqual(nz[0].dtype, ht.int64)

         # split
         a = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=1)
         cond = a > 3
         nz = cond.nonzero()
-        self.assertEqual(nz.gshape, (6, 2))
-        self.assertEqual(nz.dtype, ht.int64)
-        self.assertEqual(nz.split, 0)
-        a[nz] = 10.0
+        self.assertEqual(len(nz), 2)
+        self.assertEqual(len(nz[0]), 6)
+        self.assertEqual(nz[0].dtype, ht.int64)
+        a[nz] = 10
         self.assertEqual(ht.all(a[nz] == 10), 1)

     def test_where(self):
32 changes: 0 additions & 32 deletions scripts/tutorial.ipynb
@@ -1044,38 +1044,6 @@
     "a + b"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "The example below will show that it is also possible to use operations on tensors with different split and the proper result calculated. However, this should be used seldomly and with small data amounts only, as it entails sending large amounts of data over the network."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 34,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(0/2) tensor([[9., 9., 9., 9., 9., 9.],\n",
-       "(0/2) [9., 9., 9., 9., 9., 9.]])\n",
-       "(1/2) tensor([[9., 9., 9., 9., 9., 9.],\n",
-       "(1/2) [9., 9., 9., 9., 9., 9.]])"
-      ]
-     },
-     "execution_count": 34,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "a = ht.full((4, 6,), 8, split=0)\n",
-    "b = ht.ones((4, 6,), split=1)\n",
-    "a + b"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},