37
37
namespace matx
38
38
{
39
39
namespace detail {
40
- template <class T > class LinspaceOp {
40
+ template <class T , int NUM_RC > class LinspaceOp : public BaseOp <LinspaceOp<T, NUM_RC>> {
41
41
private:
42
- Range<T> range_;
43
-
42
+ cuda::std::array<T, NUM_RC> steps_;
43
+ cuda::std::array<T, NUM_RC> firsts_;
44
+ int axis_;
45
+ index_t count_;
44
46
public:
45
47
using value_type = T;
46
48
using matxop = bool ;
47
49
48
50
__MATX_INLINE__ std::string str () const { return " linspace" ; }
49
51
52
+ static inline constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank () { return NUM_RC; }
50
53
51
- inline LinspaceOp (T first, T last , index_t count)
54
+ inline LinspaceOp (const T (&firsts)[NUM_RC], const T (&lasts)[NUM_RC] , index_t count, int axis)
52
55
{
53
- #ifdef __CUDA_ARCH__
54
- range_ = Range<T>{first, (last - first) / static_cast <T>(count - 1 )};
55
- #else
56
- // Host has no support for most half precision operators/intrinsics
57
- if constexpr (is_matx_half_v<T>) {
58
- range_ = Range<T>{static_cast <float >(first),
59
- (static_cast <float >(last) - static_cast <float >(first)) /
60
- static_cast <float >(count - 1 )};
56
+ axis_ = axis;
57
+ count_ = count;
58
+ for (int i = 0 ; i < NUM_RC; ++i) {
59
+ firsts_[i] = firsts[i];
60
+ steps_[i] = (lasts[i] - firsts[i]) / static_cast <T>(count - 1 );
61
61
}
62
- else {
63
- range_ = Range<T>{first, (last - first) / static_cast <T>(count - 1 )};
62
+ }
63
+
64
+ template <typename ... Is>
65
+ __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator ()(Is... indices) const {
66
+ static_assert (sizeof ...(indices) == NUM_RC, " Number of indices incorrect in linspace" );
67
+ cuda::std::array idx{indices...};
68
+ if constexpr (sizeof ...(indices) == 1 ) {
69
+ return firsts_[0 ] + steps_[0 ] * static_cast <T>(idx[0 ]);
70
+ } else {
71
+ if (axis_ == 0 ) {
72
+ return firsts_[idx[1 ]] + steps_[idx[1 ]] * static_cast <T>(idx[0 ]);
73
+ } else {
74
+ return firsts_[idx[0 ]] + steps_[idx[0 ]] * static_cast <T>(idx[1 ]);
75
+ }
64
76
}
65
- #endif
66
77
}
67
78
68
- __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator ()(index_t idx) const { return range_ (idx); }
79
+ constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size (int dim) const
80
+ {
81
+ if constexpr (NUM_RC == 1 ) {
82
+ return count_;
83
+ } else {
84
+ if (dim != axis_) {
85
+ return NUM_RC;
86
+ } else {
87
+ return count_;
88
+ }
89
+ }
90
+ }
69
91
};
70
92
}
71
93
72
94
95
+ /* *
96
+ * @brief Create a matrix linearly-spaced range of values
97
+ *
98
+ * Creates a set of values using starts and stops that are linearly-
99
+ * spaced apart over the set of values. Distance is determined
100
+ * by the count parameter
101
+ *
102
+ * @tparam NUM_RC Number of rows or columns, depending on the axis
103
+ * @tparam T Type of the values
104
+ * @param firsts First values
105
+ * @param lasts Last values
106
+ * @param count Number of values in a row or column, depending on the axis
107
+ * @param axis Axis to operate over
108
+ * @return Operator with linearly-spaced values
109
+ */
110
+ template <int NUM_RC, typename T = float >
111
+ inline auto linspace (const T (&firsts)[NUM_RC], const T (&lasts)[NUM_RC], index_t count, int axis = 0)
112
+ {
113
+ return detail::LinspaceOp<T, NUM_RC>(firsts, lasts, count, axis);
114
+ }
115
+
116
+ /* *
117
+ * @brief Create a linearly-spaced vector of values
118
+ *
119
+ * Creates a set of values using startsand stop that are linearly-
120
+ * spaced apart over the set of values. Distance is determined
121
+ * by the count parameter
122
+ *
123
+ * @tparam T Type of the values
124
+ * @param first First value
125
+ * @param last Last value
126
+ * @param count Number of values in a row or column, depending on the axis
127
+ * @param axis Axis to operate over
128
+ * @return Operator with linearly-spaced values
129
+ */
130
+ template <typename T = float >
131
+ inline auto linspace (T first, T last, index_t count, int axis = 0 )
132
+ {
133
+ const T firsts[] = {first};
134
+ const T lasts[] = {last};
135
+ return linspace (firsts, lasts, count, axis);
136
+ }
137
+
73
138
/* *
74
139
* @brief Create a linearly-spaced range of values
75
140
*
76
141
* Creates a set of values using a start and end that are linearly-
77
142
* spaced apart over the set of values. Distance is determined
78
143
* by the shape and selected dimension.
79
144
*
80
- * @tparam T Operator type
81
145
* @tparam Dim Dimension to operate over
82
- * @tparam ShapeType Shape type
83
- * @param s Shape object
146
+ * @tparam NUM_RC Rank of shape
147
+ * @tparam T Operator type
148
+ * @param s Array of sizes
84
149
* @param first First value
85
150
* @param last Last value
86
151
* @return Operator with linearly-spaced values
87
152
*/
88
- template <int Dim, typename ShapeType, typename T,
89
- std::enable_if_t <!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool > = true >
90
- inline auto linspace (ShapeType &&s, T first, T last)
91
- {
92
- constexpr int RANK = cuda::std::tuple_size<std::decay_t <ShapeType>>::value;
93
- static_assert (RANK > Dim);
94
- auto count = *(s.begin () + Dim);
95
- detail::LinspaceOp<T> l (first, last, count);
96
- return detail::matxGenerator1D_t<detail::LinspaceOp<T>, Dim, ShapeType>(std::forward<ShapeType>(s), l);
97
- }
153
+ template <int Dim, int NUM_RC, typename T>
154
+ [[deprecated(" Use matx::linspace(T first, T last, index_t count, int axis = 0) instead." )]]
155
+ inline auto linspace([[maybe_unused]]const index_t (&s)[NUM_RC], T first, T last)
156
+ {
157
+ const T firsts[] = {first};
158
+ const T lasts[] = {last};
159
+ return linspace (firsts, lasts, NUM_RC, 0 );
160
+ }
98
161
99
162
/* *
100
163
* @brief Create a linearly-spaced range of values
@@ -103,17 +166,24 @@ namespace matx
103
166
* spaced apart over the set of values. Distance is determined
104
167
* by the shape and selected dimension.
105
168
*
106
- * @tparam Dim Dimension to operate over
107
- * @tparam RANK Rank of shape
108
169
* @tparam T Operator type
109
- * @param s Array of sizes
170
+ * @tparam Dim Dimension to operate over
171
+ * @tparam ShapeType Shape type
172
+ * @param s Shape object
110
173
* @param first First value
111
174
* @param last Last value
112
175
* @return Operator with linearly-spaced values
113
176
*/
114
- template <int Dim, int RANK, typename T>
115
- inline auto linspace (const index_t (&s)[RANK], T first, T last)
116
- {
117
- return linspace<Dim>(detail::to_array (s), first, last);
118
- }
177
+ template <int Dim, typename ShapeType, typename T,
178
+ std::enable_if_t <!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool > = true >
179
+ [[deprecated(" Use matx::linspace(T first, T last, index_t count, int axis = 0) instead." )]]
180
+ inline auto linspace(ShapeType &&s, T first, T last)
181
+ {
182
+ constexpr int NUM_RC = cuda::std::tuple_size<std::decay_t <ShapeType>>::value;
183
+ static_assert (NUM_RC > Dim);
184
+ auto count = *(s.begin () + Dim);
185
+ const T firsts[] = {first};
186
+ const T lasts[] = {last};
187
+ return linspace (firsts, lasts, count, 0 );
188
+ }
119
189
} // end namespace matx
0 commit comments