Skip to content

Commit

Permalink
[fix] handle zero-dim arrays in variance.lisp
Browse files Browse the repository at this point in the history
  • Loading branch information
digikar99 committed Dec 30, 2024
1 parent 433cda1 commit a44c3d2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
26 changes: 14 additions & 12 deletions dense-numericals-src/statistics/variance.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@
(type (integer 0 #.array-dimension-limit) size))
(nu:multiply array array :broadcast nil :out square)
(locally (declare (compiler-macro-notes:muffle runtime-array-allocation))
(let ((array-sum (nu:sum array :axes nil :keep-dims nil))
(square-sum (nu:sum square :axes nil :keep-dims nil)))
(/ (- square-sum
(/ (* array-sum array-sum)
size))
(- size ddof))))))
(let ((array-sum (aref (nu:sum array :axes nil :keep-dims nil)))
(square-sum (aref (nu:sum square :axes nil :keep-dims nil))))
(ensure-array (/ (- square-sum
(/ (* array-sum array-sum)
size))
(- size ddof))
()
<type>)))))


;; FIXME: These tests are not exhaustive.
Expand All @@ -141,11 +143,11 @@
(loop :for *array-element-type* :in `(single-float
double-float)
:do
(5am:is (= (coerce 2/3 *array-element-type*)
(nu:variance (nu:asarray '(1 2 3)))))
(5am:is (= (coerce 5/3 *array-element-type*)
(nu:variance (nu:asarray '((2 3 4)
(4 5 6))))))
(5am:is (nu:array= (ensure-array 2/3)
(nu:variance (nu:asarray '(1 2 3)))))
(5am:is (nu:array= (ensure-array 5/3)
(nu:variance (nu:asarray '((2 3 4)
(4 5 6))))))
(5am:is (nu:array= (nu:asarray '(1 1 1))
(nu:variance (nu:asarray '((2 3 4)
(4 5 6)))
Expand Down Expand Up @@ -193,7 +195,7 @@
:keep-dims t)))

(let ((array (nu:rand 100)))
(5am:is (float-close-p (nu:variance array)
(5am:is (float-close-p (row-major-aref (nu:variance array) 0)
(let* ((sum 0)
(mean
(progn
Expand Down
26 changes: 14 additions & 12 deletions src/statistics/variance.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@
(type (integer 0 #.array-dimension-limit) size))
(nu:multiply array array :broadcast nil :out square)
(locally (declare (compiler-macro-notes:muffle runtime-array-allocation))
(let ((array-sum (nu:sum array :axes nil :keep-dims nil))
(square-sum (nu:sum square :axes nil :keep-dims nil)))
(/ (- square-sum
(/ (* array-sum array-sum)
size))
(- size ddof))))))
(let ((array-sum (aref (nu:sum array :axes nil :keep-dims nil)))
(square-sum (aref (nu:sum square :axes nil :keep-dims nil))))
(ensure-array (/ (- square-sum
(/ (* array-sum array-sum)
size))
(- size ddof))
()
<type>)))))


;; FIXME: These tests are not exhaustive.
Expand All @@ -141,11 +143,11 @@
(loop :for *array-element-type* :in `(single-float
double-float)
:do
(5am:is (= (coerce 2/3 *array-element-type*)
(nu:variance (nu:asarray '(1 2 3)))))
(5am:is (= (coerce 5/3 *array-element-type*)
(nu:variance (nu:asarray '((2 3 4)
(4 5 6))))))
(5am:is (nu:array= (ensure-array 2/3)
(nu:variance (nu:asarray '(1 2 3)))))
(5am:is (nu:array= (ensure-array 5/3)
(nu:variance (nu:asarray '((2 3 4)
(4 5 6))))))
(5am:is (nu:array= (nu:asarray '(1 1 1))
(nu:variance (nu:asarray '((2 3 4)
(4 5 6)))
Expand Down Expand Up @@ -193,7 +195,7 @@
:keep-dims t)))

(let ((array (nu:rand 100)))
(5am:is (float-close-p (nu:variance array)
(5am:is (float-close-p (row-major-aref (nu:variance array) 0)
(let* ((sum 0)
(mean
(progn
Expand Down

0 comments on commit a44c3d2

Please sign in to comment.