Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flake8 code style fixes #241

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions arrayfire/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#######################################################
# Copyright (c) 2019, ArrayFire
# Copyright (c) 2020, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
Expand All @@ -14,11 +14,13 @@
from .array import Array
from .library import backend, safe_call, BINARYOP, c_bool_t, c_double_t, c_int_t, c_pointer, c_uint_t


def _parallel_dim(a, dim, c_func):
out = Array()
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim)))
return out


def _reduce_all(a, c_func):
real = c_double_t(0)
imag = c_double_t(0)
Expand All @@ -29,11 +31,13 @@ def _reduce_all(a, c_func):
imag = imag.value
return real if imag == 0 else real + imag * 1j


def _nan_parallel_dim(a, dim, c_func, nan_val):
out = Array()
safe_call(c_func(c_pointer(out.arr), a.arr, c_int_t(dim), c_double_t(nan_val)))
return out


def _nan_reduce_all(a, c_func, nan_val):
real = c_double_t(0)
imag = c_double_t(0)
Expand All @@ -44,6 +48,7 @@ def _nan_reduce_all(a, c_func, nan_val):
imag = imag.value
return real if imag == 0 else real + imag * 1j


def _FNSD(dim, dims):
if dim >= 0:
return int(dim)
Expand All @@ -55,20 +60,26 @@ def _FNSD(dim, dims):
break
return int(fnsd)


def _rbk_dim(keys, vals, dim, c_func):
keys_out = Array()
vals_out = Array()
rdim = _FNSD(dim, vals.dims())
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim)))
return keys_out, vals_out


def _nan_rbk_dim(a, dim, c_func, nan_val):
keys_out = Array()
vals_out = Array()
# FIXME: vals is undefined
rdim = _FNSD(dim, vals.dims())
safe_call(c_func(c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
# FIXME: keys is undefined
safe_call(c_func(
c_pointer(keys_out.arr), c_pointer(vals_out.arr), keys.arr, vals.arr, c_int_t(rdim), c_double_t(nan_val)))
return keys_out, vals_out


def sum(a, dim=None, nan_val=None):
"""
Calculate the sum of all the elements along a specified dimension.
Expand Down Expand Up @@ -98,8 +109,6 @@ def sum(a, dim=None, nan_val=None):
return _reduce_all(a, backend.get().af_sum_all)




def sumByKey(keys, vals, dim=-1, nan_val=None):
"""
Calculate the sum of elements along a specified dimension according to a key.
Expand All @@ -122,10 +131,10 @@ def sumByKey(keys, vals, dim=-1, nan_val=None):
values: af.Array or scalar number
The sum of all elements in `vals` along dimension `dim` according to keys
"""
if (nan_val is not None):
if nan_val is not None:
return _nan_rbk_dim(keys, vals, dim, backend.get().af_sum_by_key_nan, nan_val)
else:
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)
return _rbk_dim(keys, vals, dim, backend.get().af_sum_by_key)


def product(a, dim=None, nan_val=None):
"""
Expand Down Expand Up @@ -178,10 +187,10 @@ def productByKey(keys, vals, dim=-1, nan_val=None):
values: af.Array or scalar number
The product of all elements in `vals` along dimension `dim` according to keys
"""
if (nan_val is not None):
if nan_val is not None:
return _nan_rbk_dim(keys, vals, dim, backend.get().af_product_by_key_nan, nan_val)
else:
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)
return _rbk_dim(keys, vals, dim, backend.get().af_product_by_key)


def min(a, dim=None):
"""
Expand Down Expand Up @@ -227,6 +236,7 @@ def minByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_min_by_key)


def max(a, dim=None):
"""
Find the maximum value of all the elements along a specified dimension.
Expand Down Expand Up @@ -271,6 +281,7 @@ def maxByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_max_by_key)


def all_true(a, dim=None):
"""
Check if all the elements along a specified dimension are true.
Expand Down Expand Up @@ -315,6 +326,7 @@ def allTrueByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_all_true_by_key)


def any_true(a, dim=None):
"""
Check if any the elements along a specified dimension are true.
Expand All @@ -334,8 +346,8 @@ def any_true(a, dim=None):
"""
if dim is not None:
return _parallel_dim(a, dim, backend.get().af_any_true)
else:
return _reduce_all(a, backend.get().af_any_true_all)
return _reduce_all(a, backend.get().af_any_true_all)


def anyTrueByKey(keys, vals, dim=-1):
"""
Expand All @@ -359,6 +371,7 @@ def anyTrueByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_any_true_by_key)


def count(a, dim=None):
"""
Count the number of non zero elements in an array along a specified dimension.
Expand All @@ -378,8 +391,7 @@ def count(a, dim=None):
"""
if dim is not None:
return _parallel_dim(a, dim, backend.get().af_count)
else:
return _reduce_all(a, backend.get().af_count_all)
return _reduce_all(a, backend.get().af_count_all)


def countByKey(keys, vals, dim=-1):
Expand All @@ -404,6 +416,7 @@ def countByKey(keys, vals, dim=-1):
"""
return _rbk_dim(keys, vals, dim, backend.get().af_count_by_key)


def imin(a, dim=None):
"""
Find the value and location of the minimum value along a specified dimension
Expand Down
14 changes: 4 additions & 10 deletions arrayfire/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def cast(a, dtype):
out : af.Array
array containing the values from `a` after converting to `dtype`.
"""
out=Array()
out = Array()
safe_call(backend.get().af_cast(c_pointer(out.arr), a.arr, dtype.value))
return out

Expand Down Expand Up @@ -156,15 +156,8 @@ def clamp(val, low, high):
vdims = dim4_to_tuple(val.dims())
vty = val.type()

if not is_low_array:
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
else:
low_arr = low.arr

if not is_high_array:
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
else:
high_arr = high.arr
low_arr = low.arr if is_low_array else constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still, I don't see a point to revert these changes. These variables are just temp variables and they are not changed later anyhow. But we drop extra 7 lines of code.

high_arr = high.arr if is_high_array else constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)

safe_call(backend.get().af_clamp(c_pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))

Expand Down Expand Up @@ -1003,6 +996,7 @@ def sqrt(a):
"""
return _arith_unary_func(a, backend.get().af_sqrt)


def rsqrt(a):
"""
Reciprocal or inverse square root of each element in the array.
Expand Down
Loading