Skip to content

Commit

Permalink
arminmax
Browse files Browse the repository at this point in the history
  • Loading branch information
DrTimothyAldenDavis committed Jul 3, 2024
1 parent c9acd62 commit bac16ab
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 7 deletions.
26 changes: 21 additions & 5 deletions experimental/algorithm/LAGraph_argminmax.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,26 @@ int argminmax
// for dim=2: find the position of the min/max entry in each row:
// p = G*y, so that p(i) = j if x(i) = A(i,j) = min/max (A (i,:)).

// Use the SECONDI operator since built-in indexing is 0-based. The ANY
// monoid would be faster, but this uses MIN monoid so that the result for
// the user is repeatable.
GRB_TRY (GrB_mxm (*p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, desc)) ;
#if 0
printf ("argmin/max with 2ndi\n") ;
// Use the SECONDI operator since built-in indexing is 0-based. The
// ANY monoid would be faster, but this uses MIN monoid so that the
// result for the user is repeatable.
// p = G*y or G'*y using the MIN_SECONDI semiring
GRB_TRY (GrB_mxm (*p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, desc)) ;
#else
printf ("argmin/max without 2ndi\n") ;
// H = rowindex (G) if dim is 1, or colindex (G) if dim is 2.
GrB_Matrix H = NULL ;
GRB_TRY (GrB_Matrix_new (&H, GrB_INT64, nrows, ncols)) ;
GRB_TRY (GrB_apply (H, NULL, NULL,
(dim == 1) ? GrB_ROWINDEX_INT64 : GrB_COLINDEX_INT64, G,
(int64_t) 0, NULL)) ;
// p = H*y or H'*y using the MIN_FIRST semiring
GRB_TRY (GrB_mxm (*p, NULL, NULL, GrB_MIN_FIRST_SEMIRING_INT64, H, y,
desc)) ;
GRB_TRY (GrB_Matrix_free (&H)) ;
#endif

//--------------------------------------------------------------------------
// free workspace
Expand Down Expand Up @@ -312,7 +328,7 @@ int LAGraph_argminmax
GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [0]), *p, 0, 0)) ;
// I [1] = p [I [0]-1] (use -1 since I[0] is 1-based),
// which is the column index of the global argmin/max of A
GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [1]), p1, I [0] - 1, 0)) ;
GRB_TRY (GrB_Matrix_extractElement_INT64 (&(I [1]), p1, I [0], 0)) ;
}

// free workspace and create p = [row, col]
Expand Down
89 changes: 89 additions & 0 deletions experimental/algorithm/hpec24_notes/LAGraph_argminmax.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// A simplified algorithm for HPEC'24
// assume the matrix type is FP64
// assume argmax
// use mxv where appropriate
// don't use the ANY monoid.

//------------------------------------------------------------------------------
// argmax: compute argmax of each row of A
//------------------------------------------------------------------------------

int argmax
(
// output
GrB_Vector *x_handle, // max value in each row of A
GrB_Vector *p_handle, // index of max value in each row of A
// input
GrB_Matrix A // assumed to be GrB_FP64
)
{

//--------------------------------------------------------------------------
// create outputs x and p, and the iso full vector y
//--------------------------------------------------------------------------

GrB_Index nrows, ncols ;
GrB_Matrix_nrows (&nrows, A) ;
GrB_Matrix_ncols (&ncols, A) ;
GrB_Vector y = NULL, x = NULL, p = NULL ;
GrB_Matrix G = NULL, D = NULL ;
GrB_Vector_new (&x, GrB_FP64, nrows) ;
GrB_Vector_new (&y, GrB_FP64, ncols) ;
GrB_Vector_new (&p, GrB_INT64, nrows) ;

// y (:) = 1, an full vector with all entries equal to 1
GrB_Matrix_assign_INT64 (y, NULL, NULL, 1, GrB_ALL, ncols, NULL) ;

//--------------------------------------------------------------------------
// compute x = max(A)
//--------------------------------------------------------------------------

// x = max (A) where x(i) = max (A (i,:))
GrB_mxv (x, NULL, NULL, GrB_MAX_FIRST_SEMIRING_FP64, A, y, NULL) ;

//--------------------------------------------------------------------------
// compute G, where G(i,j)=1 if A(i,j) is the max in its row
//--------------------------------------------------------------------------

// D = diag (x)
GrB_Matrix_diag (&D, x, 0) ;
GrB_Matrix_new (&G, GrB_BOOL, nrows, ncols) ;
// G = D*A using the EQ_EQ_FP64 semiring
GrB_mxm (G, NULL, NULL, GxB_EQ_EQ_FP64, D, A, NULL) ;
// drop explicit zeros from G
GrB_Matrix_select_BOOL (G, NULL, NULL, GrB_VALUENE_BOOL, G, 0, NULL) ;

//--------------------------------------------------------------------------
// extract the positions of the entries in G
//--------------------------------------------------------------------------

// find the position of the max entry in each row:
// p = G*y, so that p(i) = j if x(i) = A(i,j) = max (A (i,:)).

if (no 2ndI op)
{
// H = rowindex (G)
GrB_Matrix H = NULL ;
GrB_Matrix_new (&H, nrows, ncols) ;
GrB_apply (H, NULL, NULL, GrB_ROWINDEX_INT64, G, NULL) ;
// p = H*y
GrB_mxv (p, NULL, NULL, GrB_MIN_FIRST_SEMIRING_INT64, H, y, NULL) ;
GrB_free (&H) ;
}
else
{
// using the SECONDI operator
GrB_mxm (p, NULL, NULL, GxB_MIN_SECONDI_INT64, G, y, NULL) ;
}

//--------------------------------------------------------------------------
// free workspace and return result
//--------------------------------------------------------------------------

GrB_Matrix_free (&D) ;
GrB_Matrix_free (&G) ;
GrB_Matrix_free (&y) ;
(*x_handle) = x ;
(*p_handle) = p ;
}

70 changes: 70 additions & 0 deletions experimental/benchmark/argmax_tests.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

% for HPEC'24 paper

if (0)
clear all
Prob = ssget ('GAP/GAP-twitter')
end
A = Prob.A ;
nz = nnz (A)
n = size (A,1) ;
A = A + speye (n) ;
nz = nnz (A)
G = GrB (A) ;

% time the GraphBLAS max and argmax methods
for thr = [1 40]
GrB.threads (thr) ;
for trial = 1:3
fprintf ('\ntrial %d, threads %g\n', trial, thr) ;

t = tic ;
x = max (G, [ ], 1) ;
t1 = toc (t) ;
fprintf ('GrB colwise max: %g sec\n', t1) ;
t = tic ;
[x,p] = GrB.argmax (G, 1) ;
t2 = toc (t) ;
fprintf ('GrB colwise argmax: %g sec\n', t2) ;
fprintf ('GrB colwise argmax time / max time: %g\n', t2/t1) ;

t = tic ;
x = max (G, [ ], 2) ;
t1 = toc (t) ;
fprintf ('GrB rowwise max: %g sec\n', t1) ;
t = tic ;
[x,p] = GrB.argmax (G, 2) ;
t2 = toc (t) ;
fprintf ('GrB rowwise argmax: %g sec\n', t2) ;
fprintf ('GrB rowwise argmax time / max time: %g\n', t2/t1) ;

end
end


% time the MATLAB max and argmax methods
for trial = 1:3
fprintf ('\ntrial %d\n', trial) ;

t = tic ;
x = max (A, [ ], 1) ;
t1 = toc (t) ;
fprintf ('MATLAB colwise max: %g sec\n', t1) ;
t = tic ;
[x,p] = max (A, [ ], 1) ;
t2 = toc (t) ;
fprintf ('MATLAB colwise argmax: %g sec\n', t2) ;
fprintf ('MATLAB colwise argmax time / max time: %g\n', t2/t1) ;

t = tic ;
x = max (A, [ ], 2) ;
t1 = toc (t) ;
fprintf ('MATLAB rowwise max: %g sec\n', t1) ;
t = tic ;
[x,p] = max (A, [ ], 2) ;
t2 = toc (t) ;
fprintf ('MATLAB rowwise argmax: %g sec\n', t2) ;
fprintf ('MATLAB rowwise argmax time / max time: %g\n', t2/t1) ;

end

5 changes: 3 additions & 2 deletions experimental/test/test_argminmax.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ void test_argminmax (void)
printf ("\nInput of Matrix:\n") ;
GxB_print(A, 2);
// test the algorithm
OK (LAGraph_argminmax (&x,&p, A,dim,is_min, msg));
printf("\n") ;
int info = LAGraph_argminmax (&x,&p, A,dim,is_min, msg);
printf("%s\n", msg) ;
OK (info) ;
GxB_print(x,3);
GxB_print(p,3);
// print the result
Expand Down

0 comments on commit bac16ab

Please sign in to comment.