Skip to content

Commit 56a8b62

Browse files
authored
Fix possible loss of data warnings. (#70)
* Fix loss of precision warnings.
1 parent faa30f7 commit 56a8b62

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

numpy/include/numpy.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ class ndarray {
108108
// Properties
109109
_Dtype dtype() const { return dtype_traits<T>::name; }
110110

111-
int ndim() const { return _array.dimension(); }
111+
int ndim() const { return static_cast<int>(_array.dimension()); }
112112

113-
int size() const { return _array.size(); }
113+
int size() const { return static_cast<int>(_array.size()); }
114114

115115
_ShapeLike shape() const { return _ShapeLike(_array.shape().begin(), _array.shape().end()); }
116116

@@ -395,14 +395,14 @@ class ndarray {
395395
}
396396

397397
// Searching and Sorting Functions
398-
T argmin() const { return (xt::argmin(_array))[0]; }
398+
pkpy::int64 argmin() const { return (xt::argmin(_array))[0]; }
399399

400400
ndarray<T> argmin(int axis) const {
401401
xt::xarray<T> result = xt::argmin(_array, {axis});
402402
return ndarray<T>(result);
403403
}
404404

405-
T argmax() const { return (xt::argmax(_array))[0]; }
405+
pkpy::int64 argmax() const { return (xt::argmax(_array))[0]; }
406406

407407
ndarray<T> argmax(int axis) const {
408408
xt::xarray<T> result = xt::argmax(_array, {axis});
@@ -493,7 +493,7 @@ class random {
493493
public:
494494
random() {
495495
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
496-
xt::random::seed(seed);
496+
xt::random::seed(static_cast<xt::random::seed_type>(seed));
497497
}
498498

499499
template <typename T>
@@ -558,9 +558,9 @@ xt::xarray<std::common_type_t<T, U>> matrix_mul(const xt::xarray<T>& a, const xt
558558
b_copy = xt::reshape_view(b_copy, {3, 1});
559559
}
560560
if (a_copy.dimension() == 2 && b_copy.dimension() == 2) {
561-
int m = a_copy.shape()[0];
562-
int n = a_copy.shape()[1];
563-
int p = b_copy.shape()[1];
561+
int m = static_cast<int>(a_copy.shape()[0]);
562+
int n = static_cast<int>(a_copy.shape()[1]);
563+
int p = static_cast<int>(b_copy.shape()[1]);
564564

565565
Mat result = xt::zeros<result_type>({m, p});
566566

numpy/src/numpy.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,11 @@ class ndarray : public ndarray_base {
296296

297297
ndarray(const int32 value) : data(value) {}
298298

299-
ndarray(const int_ value) : data(value) {}
299+
ndarray(const int_ value) : data(static_cast<T>(value)) {}
300300

301301
ndarray(const float32 value) : data(value) {}
302302

303-
ndarray(const float64 value) : data(value) {}
303+
ndarray(const float64 value) : data(static_cast<T>(value)) {}
304304

305305
ndarray(const pkpy::numpy::ndarray<T>& _arr) : data(_arr) {}
306306

@@ -394,7 +394,9 @@ class ndarray : public ndarray_base {
394394
}
395395

396396
py::object min() const override {
397-
if constexpr (std::is_same_v<T, bool_> || std::is_same_v<T, int8> || std::is_same_v<T, int16> ||
397+
if constexpr (std::is_same_v<T, bool_>) {
398+
return py::bool_(data.min());
399+
} else if constexpr (std::is_same_v<T, int8> || std::is_same_v<T, int16> ||
398400
std::is_same_v<T, int32> || std::is_same_v<T, int64>) {
399401
return py::int_(data.min());
400402
} else if constexpr(std::is_same_v<T, float32> || std::is_same_v<T, float64>) {
@@ -426,8 +428,10 @@ class ndarray : public ndarray_base {
426428
}
427429

428430
py::object max() const override {
429-
if constexpr (std::is_same_v<T, bool_> || std::is_same_v<T, int8> || std::is_same_v<T, int16> ||
430-
std::is_same_v<T, int32> || std::is_same_v<T, int64>) {
431+
if constexpr (std::is_same_v<T, bool_>) {
432+
return py::bool_(data.max());
433+
} else if constexpr (std::is_same_v<T, int8> || std::is_same_v<T, int16> ||
434+
std::is_same_v<T, int32> || std::is_same_v<T, int64>) {
431435
return py::int_(data.max());
432436
} else if constexpr(std::is_same_v<T, float32> || std::is_same_v<T, float64>) {
433437
return py::float_(data.max());
@@ -548,7 +552,7 @@ class ndarray : public ndarray_base {
548552
std::is_same_v<T, int32> || std::is_same_v<T, int64>) {
549553
return py::int_(data.argmin());
550554
} else if constexpr(std::is_same_v<T, float32> || std::is_same_v<T, float64>) {
551-
return py::float_(data.argmin());
555+
return py::int_(data.argmin());
552556
} else {
553557
throw std::runtime_error("Unsupported type");
554558
}
@@ -561,7 +565,7 @@ class ndarray : public ndarray_base {
561565
std::is_same_v<T, int32> || std::is_same_v<T, int64>) {
562566
return py::int_(data.argmax());
563567
} else if constexpr(std::is_same_v<T, float32> || std::is_same_v<T, float64>) {
564-
return py::float_(data.argmax());
568+
return py::int_(data.argmax());
565569
} else {
566570
throw std::runtime_error("Unsupported type");
567571
}
@@ -1910,7 +1914,7 @@ class ndarray : public ndarray_base {
19101914
}
19111915
} else if constexpr(std::is_same_v<T, float64>) {
19121916
if (data.ndim() == 1) {
1913-
data.set_item(index, value);
1917+
data.set_item(index, static_cast<T>(value));
19141918
} else {
19151919
data.set_item(index, (pkpy::numpy::adapt<int_>(std::vector{value})).astype<float64>());
19161920
}
@@ -1958,7 +1962,7 @@ class ndarray : public ndarray_base {
19581962
}
19591963
} else if constexpr(std::is_same_v<T, int_>) {
19601964
if (data.ndim() == 1) {
1961-
data.set_item(index, value);
1965+
data.set_item(index, static_cast<T>(value));
19621966
} else {
19631967
data.set_item(index, (pkpy::numpy::adapt<float64>(std::vector{value})).astype<int_>());
19641968
}
@@ -2010,13 +2014,13 @@ class ndarray : public ndarray_base {
20102014
data.set_item(index, (pkpy::numpy::adapt<int_>(std::vector{value})).astype<float64>());
20112015
}
20122016
} else if(indices.size() == 2 && indices.size() <= data.ndim())
2013-
data.set_item_2d(indices[0], indices[1], value);
2017+
data.set_item_2d(indices[0], indices[1], static_cast<T>(value));
20142018
else if(indices.size() == 3 && indices.size() <= data.ndim())
2015-
data.set_item_3d(indices[0], indices[1], indices[2], value);
2019+
data.set_item_3d(indices[0], indices[1], indices[2], static_cast<T>(value));
20162020
else if(indices.size() == 4 && indices.size() <= data.ndim())
2017-
data.set_item_4d(indices[0], indices[1], indices[2], indices[3], value);
2021+
data.set_item_4d(indices[0], indices[1], indices[2], indices[3], static_cast<T>(value));
20182022
else if(indices.size() == 5 && indices.size() <= data.ndim())
2019-
data.set_item_5d(indices[0], indices[1], indices[2], indices[3], indices[4], value);
2023+
data.set_item_5d(indices[0], indices[1], indices[2], indices[3], indices[4], static_cast<T>(value));
20202024
}
20212025

20222026
void set_item_tuple_float(py::tuple args, float64 value) override {
@@ -2032,13 +2036,13 @@ class ndarray : public ndarray_base {
20322036
data.set_item(index, (pkpy::numpy::adapt<float64>(std::vector{value})).astype<int_>());
20332037
}
20342038
} else if(indices.size() == 2 && indices.size() <= data.ndim())
2035-
data.set_item_2d(indices[0], indices[1], value);
2039+
data.set_item_2d(indices[0], indices[1], static_cast<T>(value));
20362040
else if(indices.size() == 3 && indices.size() <= data.ndim())
2037-
data.set_item_3d(indices[0], indices[1], indices[2], value);
2041+
data.set_item_3d(indices[0], indices[1], indices[2], static_cast<T>(value));
20382042
else if(indices.size() == 4 && indices.size() <= data.ndim())
2039-
data.set_item_4d(indices[0], indices[1], indices[2], indices[3], value);
2043+
data.set_item_4d(indices[0], indices[1], indices[2], indices[3], static_cast<T>(value));
20402044
else if(indices.size() == 5 && indices.size() <= data.ndim())
2041-
data.set_item_5d(indices[0], indices[1], indices[2], indices[3], indices[4], value);
2045+
data.set_item_5d(indices[0], indices[1], indices[2], indices[3], indices[4], static_cast<T>(value));
20422046
}
20432047

20442048
void set_item_vector_int1(const std::vector<int>& indices, int_ value) override {

0 commit comments

Comments
 (0)