diff --git a/src/high-level/types/complex-single-float.lisp b/src/high-level/types/complex-single-float.lisp index e7a4c967..667b0163 100644 --- a/src/high-level/types/complex-single-float.lisp +++ b/src/high-level/types/complex-single-float.lisp @@ -29,6 +29,37 @@ (loop :for i :below (size vector1) :sum (* (tref vector1 i) (conjugate (tref vector2 i))))) +(defmethod =-lisp ((scalar1 complex) (scalar2 complex) &optional epsilon) + (declare (ignore epsilon)) + (and (=-lisp (realpart scalar1) (realpart scalar2)) + (=-lisp (imagpart scalar1) (imagpart scalar2)))) + +(defmethod =-lisp ((scalar1 complex) (scalar2 real) &optional epsilon) + (let ((imagpart (imagpart scalar1)) + (zero nil)) + (typecase imagpart ;; TODO: do we need to care about short-floats and long-floats? On which implementations does it matter? + (single-float (setf epsilon (or epsilon *float-comparison-threshold*) + zero 0.0s0)) + (double-float (setf epsilon (or epsilon *double-comparison-threshold*) + zero 0.0d0)) + ; If imagpart is rational it's also guaranteed nonzero per ANSI, which means we can just go ahead and exit now + (t (return-from =-lisp nil))) + (and (%scalar= zero imagpart epsilon) + (%scalar= (realpart scalar1) scalar2 epsilon)))) + +(defmethod =-lisp ((scalar1 real) (scalar2 complex) &optional epsilon) + (let ((imagpart (imagpart scalar2)) + (zero nil)) + (typecase imagpart + (single-float (setf epsilon (or epsilon *float-comparison-threshold*) + zero 0.0s0)) + (double-float (setf epsilon (or epsilon *double-comparison-threshold*) + zero 0.0d0)) + ; If imagpart is rational it's also guaranteed nonzero per ANSI, which means we can just go ahead and exit now + (t (return-from =-lisp nil))) + (and (%scalar= zero imagpart epsilon) + (%scalar= (realpart scalar2) scalar1 epsilon)))) + (defmethod =-lisp ((tensor1 tensor/complex-single-float) (tensor2 tensor/complex-single-float) &optional (epsilon *float-comparison-threshold*)) (unless (equal (shape tensor1) (shape tensor2)) (return-from =-lisp nil)) diff --git a/src/high-level/types/double-float.lisp b/src/high-level/types/double-float.lisp index 1eaf7424..a9ae2bee 100644 --- a/src/high-level/types/double-float.lisp +++ b/src/high-level/types/double-float.lisp @@ -22,6 +22,15 @@ matrix/double-float vector/double-float) +(defmethod =-lisp ((scalar1 double-float) (scalar2 double-float) &optional (epsilon *double-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 rational) (scalar2 double-float) &optional (epsilon *double-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 double-float) (scalar2 rational) &optional (epsilon *double-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + (defmethod =-lisp ((tensor1 tensor/double-float) (tensor2 tensor/double-float) &optional (epsilon *double-comparison-threshold*)) (unless (equal (shape tensor1) (shape tensor2)) (return-from =-lisp nil)) diff --git a/src/high-level/types/single-float.lisp b/src/high-level/types/single-float.lisp index 67f45530..86b1ed44 100644 --- a/src/high-level/types/single-float.lisp +++ b/src/high-level/types/single-float.lisp @@ -22,6 +22,37 @@ matrix/single-float vector/single-float) +;; Might want to inline this +(defun %scalar= (s1 s2 epsilon) + "For equality checks of inexact scalars." + (declare (type number s1 s2) ; you can use this on complex but it might be slower because of the sqrt and two multiplies. Maybe. + (type real epsilon)) + (<= (abs (- s1 + s2)) + epsilon)) + +(defmethod =-lisp ((scalar1 rational) (scalar2 rational) &optional epsilon) + (declare (ignore epsilon)) + "Rationals (integers and ratios) should be compared exactly." + (common-lisp:= scalar1 scalar2)) + +(defmethod =-lisp ((scalar1 single-float) (scalar2 single-float) &optional (epsilon *float-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 rational) (scalar2 single-float) &optional (epsilon *float-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 single-float) (scalar2 rational) &optional (epsilon *float-comparison-threshold*)) + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 float) (scalar2 single-float) &optional (epsilon *float-comparison-threshold*)) + "This covers comparing double-float to single-float. Use least precise epsilon." + (%scalar= scalar1 scalar2 epsilon)) + +(defmethod =-lisp ((scalar1 single-float) (scalar2 float) &optional (epsilon *float-comparison-threshold*)) + "This covers comparing single-float to double-float. Use least precise epsilon." + (%scalar= scalar1 scalar2 epsilon)) + (defmethod =-lisp ((tensor1 tensor/single-float) (tensor2 tensor/single-float) &optional (epsilon *float-comparison-threshold*)) (unless (equal (shape tensor1) (shape tensor2)) (return-from =-lisp nil)) diff --git a/tests/abstract-tensor-tests.lisp b/tests/abstract-tensor-tests.lisp index 9c881db6..e54a2dd5 100644 --- a/tests/abstract-tensor-tests.lisp +++ b/tests/abstract-tensor-tests.lisp @@ -4,6 +4,81 @@ (in-package #:magicl-tests) +(defmacro swapping-arguments-is ((predicate arg1 arg2)) + "Try both argument orders for a commutative predicate." + (let ((arg1sym (gensym)) + (arg2sym (gensym))) + `(let ((,arg1sym ,arg1) + (,arg2sym ,arg2)) + (is (,predicate ,arg1sym ,arg2sym)) + (is (,predicate ,arg2sym ,arg1sym))))) + +(defmacro swapping-arguments-not ((predicate arg1 arg2)) + "Try both argument orders for a negated commutative predicate." + (let ((arg1sym (gensym)) + (arg2sym (gensym))) + `(let ((,arg1sym ,arg1) + (,arg2sym ,arg2)) + (is (not (,predicate ,arg1sym ,arg2sym))) + (is (not (,predicate ,arg2sym ,arg1sym)))))) + +(deftest test-scalar-equality () + "Test the various scalar equality predicates." + (let ((exactvalues '((-1 0 1) ; integers + (-3/2 0 3/2) ; ratios + (-1.0s0 0.0s0 1.0s0) ; single-floats + (-1.0d0 0.0d0 1.0d0) ; double-floats + (#c(-1.0s0 1.0s0) #c(-1.0s0 0.0s0) #c(1.0s0 -1.0s0) #c(1.0s0 1.0s0)) ; complex-singles + (#c(-1.0d0 1.0d0) #c(-1.0d0 0.0d0) #c(1.0d0 -1.0d0) #c(1.0d0 1.0d0)) ; complex-doubles + )) + (inexactvalues '(-1.0s0 0.0s0 1.0s0 ; single-floats + -1.0d0 0.0d0 1.0d0 ; double-floats + #c(-1.0s0 1.0s0) #c(-1.0s0 0.0s0) #c(1.0s0 -1.0s0) #c(1.0s0 1.0s0) ; complex-singles + #c(-1.0d0 1.0d0) #c(-1.0d0 0.0d0) #c(1.0d0 -1.0d0) #c(1.0d0 1.0d0) ; complex-doubles + )) + (small-single-delta (/ magicl::*float-comparison-threshold* 2)) + (small-double-delta (/ magicl::*double-comparison-threshold* 2)) + (big-single-delta (* magicl::*float-comparison-threshold* 2)) + (big-double-delta (* magicl::*double-comparison-threshold* 2))) + + (flet ((test-exact (group1 group2) + "Verify that magicl:= matches common-lisp:= where appropriate" + (dolist (x1 group1) + (dolist (x2 group2) + (if (common-lisp:= x1 x2) + (swapping-arguments-is (magicl:= x1 x2)) + (swapping-arguments-not (magicl:= x1 x2)))))) + + (test-inexact (x) + "Verify that magicl:= works as expected on inexact values close to epsilon" + (let ((smalldelta (etypecase x ; a delta small enough that = should still be true + (single-float small-single-delta) + (double-float small-double-delta) + ((complex single-float) small-single-delta) + ((complex double-float) small-double-delta))) + (bigdelta (etypecase x ; a delta big enough that = should become false + (single-float big-single-delta) + (double-float big-double-delta) + ((complex single-float) big-single-delta) + ((complex double-float) big-double-delta)))) + + (swapping-arguments-is (magicl:= x (+ x smalldelta))) + (swapping-arguments-is (magicl:= x (- x smalldelta))) + (swapping-arguments-not (magicl:= x (+ x bigdelta))) + (swapping-arguments-not (magicl:= x (- x bigdelta))) + ; offset the imaginary parts. Bonus: This also causes real/complex comparisons when x is real. + (swapping-arguments-is (magicl:= x (+ x (complex 0.0 smalldelta)))) + (swapping-arguments-is (magicl:= x (- x (complex 0.0 smalldelta)))) + (swapping-arguments-not (magicl:= x (+ x (complex 0.0 bigdelta)))) + (swapping-arguments-not (magicl:= x (- x (complex 0.0 bigdelta))))))) + + (dolist (group1 exactvalues) + (dolist (group2 exactvalues) + (test-exact group1 group2))) + + (dolist (x inexactvalues) + (test-inexact x))))) + (deftest test-tensor-equality () "Test that tensor equality is sane for tensors of dimension 1 to 8" (loop :for dimensions :on '(8 7 6 5 4 3 2 1) :do diff --git a/tests/high-level-tests.lisp b/tests/high-level-tests.lisp index d1b14c1c..dfb59ee0 100644 --- a/tests/high-level-tests.lisp +++ b/tests/high-level-tests.lisp @@ -11,24 +11,24 @@ (let* ((x (magicl:from-list '(6 4 2 1 -2 8 1 5 7) '(3 3) :type 'double-float)) (d (magicl:det x))) - (is (= d -306d0)))) + (is (magicl:= d -306d0)))) (deftest test-p-norm () "Test that the p-norm of vectors returns sane values." ;; Basic 3-4-5 - (is (= 5 (magicl:norm (magicl:from-list '(3 4) '(2))))) + (is (magicl:= 5 (magicl:norm (magicl:from-list '(3 4) '(2))))) ;; One element vector should return element for all (loop :for val :in '(-3 0 10) :do (let ((x (magicl:from-list (list val) '(1)))) - (is (= (abs val) (magicl:norm x 1))) - (is (= (abs val) (magicl:norm x 2))) - (is (= (abs val) (magicl:norm x 3))) - (is (= (abs val) (magicl:norm x :infinity))))) + (is (magicl:= (abs val) (magicl:norm x 1))) + (is (magicl:= (abs val) (magicl:norm x 2))) + (is (magicl:= (abs val) (magicl:norm x 3))) + (is (magicl:= (abs val) (magicl:norm x :infinity))))) ;; Test known values (let ((x (magicl:from-list '(1 -2 3 4 5 -6) '(6)))) - (is (= 6 (magicl:norm x :infinity))) - (is (= 21 (magicl:norm x 1))) - (is (= 9.539392 (magicl:norm x 2))))) + (is (magicl:= 6 (magicl:norm x :infinity))) + (is (magicl:= 21 (magicl:norm x 1))) + (is (magicl:= 9.539392 (magicl:norm x 2))))) (deftest test-examples () "Run all of the examples. Does not check for their correctness."