Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
AllanZyne committed Sep 14, 2024
1 parent 8622fa0 commit 59422b2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 61 deletions.
60 changes: 0 additions & 60 deletions examples/transposed3d.py

This file was deleted.

17 changes: 17 additions & 0 deletions src/idtr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,13 +572,17 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,

namespace {

///
/// An util class of multi-dimensional index
///
class id {
public:
id(size_t dims) : _values(dims) {}
id(size_t dims, int64_t *value) : _values(value, value + dims) {}
id(const std::vector<int64_t> &values) : _values(values) {}
id(const std::vector<int64_t> &&values) : _values(std::move(values)) {}

/// Permute this id by axes and return a new id
id permute(std::vector<int64_t> axes) const {
std::vector<int64_t> new_values(_values.size());
for (size_t i = 0; i < _values.size(); i++) {
Expand All @@ -590,6 +594,7 @@ class id {
int64_t operator[](size_t i) const { return _values[i]; }
int64_t &operator[](size_t i) { return _values[i]; }

/// Subtract another id from this id and return a new id
id operator-(const id &rhs) const {
std::vector<int64_t> new_values(_values.size());
for (size_t i = 0; i < _values.size(); i++) {
Expand All @@ -598,6 +603,7 @@ class id {
return id(std::move(new_values));
}

/// Subtract another id from this id and return a new id
id operator-(const int64_t *rhs) const {
std::vector<int64_t> new_values(_values.size());
for (size_t i = 0; i < _values.size(); i++) {
Expand All @@ -606,6 +612,10 @@ class id {
return id(std::move(new_values));
}

/// Increase the last dimension value of this id which bounds by shape
///
/// Example:
/// In shape (2,2) : (0,0)->(0,1)->(1,0)->(1,1)->(0,0)
void next(const int64_t *shape) {
size_t i = _values.size();
while (i--) {
Expand All @@ -623,15 +633,20 @@ class id {
std::vector<int64_t> _values;
};

///
/// An wrapper template class for distribute multi-dimensional array
///
template <typename T> class ndarray {
public:
ndarray(int64_t nDims, int64_t *gShape, int64_t *gOffsets, void *lData,
int64_t *lShape, int64_t *lStrides)
: _nDims(nDims), _gShape(gShape), _gOffsets(gOffsets), _lData((T *)lData),
_lShape(lShape), _lStrides(lStrides) {}

/// Return the first global index of local data
id firstLocalIndex() const { return id(_nDims, _gOffsets); }

/// Interate all global indices in local data
void localIndices(const std::function<void(const id &)> &callback) const {
size_t size = lSize();
id idx = firstLocalIndex();
Expand All @@ -641,6 +656,7 @@ template <typename T> class ndarray {
}
}

/// Interate all global indices of the array
void globalIndices(const std::function<void(const id &)> &callback) const {
size_t size = gSize();
id idx(_nDims);
Expand All @@ -660,6 +676,7 @@ template <typename T> class ndarray {
return offset;
}

/// Using global index to access its data
T &operator[](const id &idx) { return _lData[getLocalDataOffset(idx)]; }
T operator[](const id &idx) const { return _lData[getLocalDataOffset(idx)]; }

Expand Down
38 changes: 37 additions & 1 deletion test/test_manip.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,45 @@ def test_todevice_host2gpu(self):
b = a.to_device(device="GPU")
assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])

def test_permute_dims(self):
def test_permute_dims1(self):
a = sp.arange(0, 10, 1, sp.int64)
b = sp.reshape(a, (2, 5))
c1 = sp.to_numpy(sp.permute_dims(b, [1, 0]))
c2 = sp.to_numpy(b).transpose(1, 0)
assert numpy.allclose(c1, c2)

def test_permute_dims2(self):
# === sharpy
sp_a = sp.arange(0, 2 * 3 * 4, 1)
sp_a = sp.reshape(sp_a, [2, 3, 4])

# b = a.swapaxes(1,0).swapaxes(1,2)
sp_b = sp.permute_dims(sp_a, (1, 0, 2)) # 2x4x4 -> 4x2x4 || 4x4x4
sp_b = sp.permute_dims(sp_b, (0, 2, 1)) # 4x2x4 -> 4x4x2 || 4x4x4

# c = b.swapaxes(1,2).swapaxes(1,0)
sp_c = sp.permute_dims(sp_b, (0, 2, 1))
sp_c = sp.permute_dims(sp_c, (1, 0, 2))

assert numpy.allclose(sp.to_numpy(sp_a), sp.to_numpy(sp_c))

# d = a.swapaxes(2,1).swapaxes(2,0)
sp_d = sp.permute_dims(sp_a, (0, 2, 1))
sp_d = sp.permute_dims(sp_d, (2, 1, 0))

# c = d.swapaxes(2,1).swapaxes(0,1)
sp_e = sp.permute_dims(sp_d, (0, 2, 1))
sp_e = sp.permute_dims(sp_e, (1, 0, 2))

# === numpy
np_a = numpy.arange(0, 2 * 3 * 4, 1)
np_a = numpy.reshape(np_a, [2, 3, 4])

np_b = np_a.swapaxes(1, 0).swapaxes(1, 2)
assert numpy.allclose(sp.to_numpy(sp_b), np_b)

np_d = np_a.swapaxes(2, 1).swapaxes(2, 0)
assert numpy.allclose(sp.to_numpy(sp_d), np_d)

np_e = np_d.swapaxes(2, 1).swapaxes(0, 1)
assert numpy.allclose(sp.to_numpy(sp_e), np_e)

0 comments on commit 59422b2

Please sign in to comment.