Skip to content

Commit 3ae4b13

Browse files
authored
Update linspace for more of a pythonic syntax (#935)
1 parent 419a1c3 commit 3ae4b13

File tree

12 files changed

+198
-89
lines changed

12 files changed

+198
-89
lines changed

docs_input/api/creation/operators/linspace.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ linspace
44
========
55

66
Return a range of linearly-spaced numbers using first and last value. The step size is
7-
determined by the shape.
7+
determined by the `count` parameter. `axis` (either 0 or 1) can be used to make the increasing
8+
sequence along the specified axis.
89

9-
.. doxygenfunction:: matx::linspace(ShapeType &&s, T first, T last)
10-
.. doxygenfunction:: matx::linspace(const index_t (&s)[RANK], T first, T last)
10+
.. doxygenfunction:: matx::linspace(T first, T last, index_t count, int axis = 0)
11+
.. doxygenfunction:: matx::linspace(const T (&firsts)[NUM_RC], const T (&lasts)[NUM_RC], index_t count, int axis = 0)
1112

1213
Examples
1314
~~~~~~~~

examples/spectrogram.cu

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7272
constexpr uint32_t num_iterations = 100;
7373
float time_ms;
7474

75-
cuda::std::array<index_t, 1> num_samps{N};
76-
cuda::std::array<index_t, 1> half_win{nfft / 2 + 1};
77-
cuda::std::array<index_t, 1> s_time_shape{(N - noverlap) / nstep};
7875

7976
auto time = make_tensor<float>({N});
8077
auto modulation = make_tensor<float>({N});
@@ -88,7 +85,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
8885

8986
// Set up all static buffers
9087
// time = np.arange(N) / float(fs)
91-
(time = linspace<0>(num_samps, 0.0f, static_cast<float>(N) - 1.0f) / fs)
88+
(time = linspace(0.0f, static_cast<float>(N) - 1.0f, N) / fs)
9289
.run(exec);
9390
// mod = 500 * np.cos(2*np.pi*0.25*time)
9491
(modulation = 500.f * cos(2.f * static_cast<typename complex::value_type>(M_PI) * 0.25f * time)).run(exec);
@@ -108,7 +105,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
108105

109106
// DFT Sample Frequencies (rfftfreq)
110107
(freqs = (1.0f / (static_cast<float>(nfft) * 1.f / fs)) *
111-
linspace<0>(half_win, 0.0f, static_cast<float>(nfft) / 2.0f))
108+
linspace(0.0f, static_cast<float>(nfft) / 2.0f, nfft / 2 + 1))
112109
.run(exec);
113110

114111
// Create overlapping matrix of segments.
@@ -122,8 +119,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
122119
auto Sxx = fftStackedMatrix.RealView().Permute({1, 0});
123120

124121
// Spectral time axis
125-
(s_time = linspace<0>(s_time_shape, static_cast<float>(nperseg) / 2.0f,
126-
static_cast<float>(N - nperseg) / 2.0f + 1) /
122+
(s_time = linspace(static_cast<float>(nperseg) / 2.0f,
123+
static_cast<float>(N - nperseg) / 2.0f + 1, (N - noverlap) / nstep) /
127124
fs)
128125
.run(exec);
129126

examples/spectrogram_graph.cu

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,23 +78,19 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
7878
constexpr uint32_t num_iterations = 20;
7979
float time_ms;
8080

81-
cuda::std::array<index_t, 1> num_samps{N};
82-
cuda::std::array<index_t, 1> half_win{nfft / 2 + 1};
83-
cuda::std::array<index_t, 1> s_time_shape{(N - noverlap) / nstep};
84-
8581
tensor_t<float, 1> time({N});
8682
tensor_t<float, 1> modulation({N});
8783
tensor_t<float, 1> carrier({N});
8884
tensor_t<float, 1> noise({N});
8985
tensor_t<float, 1> x({N});
90-
auto freqs = make_tensor<float>(half_win);
86+
auto freqs = make_tensor<float>({nfft / 2 + 1});
9187
tensor_t<complex, 2> fftStackedMatrix(
9288
{(N - noverlap) / nstep, nfft / 2 + 1});
9389
tensor_t<float, 1> s_time({(N - noverlap) / nstep});
9490

9591
// Set up all static buffers
9692
// time = np.arange(N) / float(fs)
97-
(time = linspace<0>(num_samps, 0.0f, static_cast<float>(N) - 1.0f) / fs)
93+
(time = linspace(0.0f, static_cast<float>(N) - 1.0f, N) / fs)
9894
.run(exec);
9995
// mod = 500 * np.cos(2*np.pi*0.25*time)
10096
(modulation = 500 * cos(2 * M_PI * 0.25 * time)).run(exec);
@@ -115,7 +111,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
115111

116112
// DFT Sample Frequencies (rfftfreq)
117113
(freqs = (1.0 / (static_cast<float>(nfft) * 1 / fs)) *
118-
linspace<0>(half_win, 0.0f, static_cast<float>(nfft) / 2.0f))
114+
linspace(0.0f, static_cast<float>(nfft) / 2.0f, nfft / 2 + 1))
119115
.run(exec);
120116

121117
// Create overlapping matrix of segments.
@@ -129,8 +125,8 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
129125
[[maybe_unused]] auto Sxx = fftStackedMatrix.RealView().Permute({1, 0});
130126

131127
// Spectral time axis
132-
(s_time = linspace<0>(s_time_shape, static_cast<float>(nperseg) / 2.0f,
133-
static_cast<float>(N - nperseg) / 2.0f + 1) /
128+
(s_time = linspace(static_cast<float>(nperseg) / 2.0f,
129+
static_cast<float>(N - nperseg) / 2.0f + 1, (N - noverlap) / nstep) /
134130
fs)
135131
.run(exec);
136132

include/matx/generators/chirp.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,7 @@ namespace matx
226226
template <typename TimeType, typename FreqType>
227227
inline auto chirp(index_t num, TimeType last, FreqType f0, TimeType t1, FreqType f1, ChirpMethod method = ChirpMethod::CHIRP_METHOD_LINEAR)
228228
{
229-
cuda::std::array<index_t, 1> shape = {num};
230-
auto space = linspace<0>(std::move(shape), (TimeType)0, last);
229+
auto space = linspace((TimeType)0, last, num);
231230
return chirp(space, f0, t1, f1, method);
232231
}
233232

@@ -263,8 +262,7 @@ namespace matx
263262
template <typename TimeType, typename FreqType>
264263
inline auto cchirp(index_t num, TimeType last, FreqType f0, TimeType t1, FreqType f1, ChirpMethod method = ChirpMethod::CHIRP_METHOD_LINEAR)
265264
{
266-
cuda::std::array<index_t, 1> shape = {num};
267-
auto space = linspace<0>(std::move(shape), (TimeType)0, last);
265+
auto space = linspace((TimeType)0, last, num);
268266
return cchirp(space, f0, t1, f1, method);
269267
}
270268

include/matx/generators/linspace.h

Lines changed: 107 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -37,64 +37,127 @@
3737
namespace matx
3838
{
3939
namespace detail {
40-
template <class T> class LinspaceOp {
40+
template <class T, int NUM_RC> class LinspaceOp : public BaseOp<LinspaceOp<T, NUM_RC>> {
4141
private:
42-
Range<T> range_;
43-
42+
cuda::std::array<T, NUM_RC> steps_;
43+
cuda::std::array<T, NUM_RC> firsts_;
44+
int axis_;
45+
index_t count_;
4446
public:
4547
using value_type = T;
4648
using matxop = bool;
4749

4850
__MATX_INLINE__ std::string str() const { return "linspace"; }
4951

52+
static inline constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() { return NUM_RC; }
5053

51-
inline LinspaceOp(T first, T last, index_t count)
54+
inline LinspaceOp(const T (&firsts)[NUM_RC], const T (&lasts)[NUM_RC], index_t count, int axis)
5255
{
53-
#ifdef __CUDA_ARCH__
54-
range_ = Range<T>{first, (last - first) / static_cast<T>(count - 1)};
55-
#else
56-
// Host has no support for most half precision operators/intrinsics
57-
if constexpr (is_matx_half_v<T>) {
58-
range_ = Range<T>{static_cast<float>(first),
59-
(static_cast<float>(last) - static_cast<float>(first)) /
60-
static_cast<float>(count - 1)};
56+
axis_ = axis;
57+
count_ = count;
58+
for (int i = 0; i < NUM_RC; ++i) {
59+
firsts_[i] = firsts[i];
60+
steps_[i] = (lasts[i] - firsts[i]) / static_cast<T>(count - 1);
6161
}
62-
else {
63-
range_ = Range<T>{first, (last - first) / static_cast<T>(count - 1)};
62+
}
63+
64+
template <typename... Is>
65+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(Is... indices) const {
66+
static_assert(sizeof...(indices) == NUM_RC, "Number of indices incorrect in linspace");
67+
cuda::std::array idx{indices...};
68+
if constexpr (sizeof...(indices) == 1) {
69+
return firsts_[0] + steps_[0] * static_cast<T>(idx[0]);
70+
} else {
71+
if (axis_ == 0) {
72+
return firsts_[idx[1]] + steps_[idx[1]] * static_cast<T>(idx[0]);
73+
} else {
74+
return firsts_[idx[0]] + steps_[idx[0]] * static_cast<T>(idx[1]);
75+
}
6476
}
65-
#endif
6677
}
6778

68-
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(index_t idx) const { return range_(idx); }
79+
constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t Size(int dim) const
80+
{
81+
if constexpr (NUM_RC == 1) {
82+
return count_;
83+
} else {
84+
if (dim != axis_) {
85+
return NUM_RC;
86+
} else {
87+
return count_;
88+
}
89+
}
90+
}
6991
};
7092
}
7193

7294

95+
/**
96+
* @brief Create a matrix linearly-spaced range of values
97+
*
98+
* Creates a set of values using starts and stops that are linearly-
99+
* spaced apart over the set of values. Distance is determined
100+
* by the count parameter
101+
*
102+
* @tparam NUM_RC Number of rows or columns, depending on the axis
103+
* @tparam T Type of the values
104+
* @param firsts First values
105+
* @param lasts Last values
106+
* @param count Number of values in a row or column, depending on the axis
107+
* @param axis Axis to operate over
108+
* @return Operator with linearly-spaced values
109+
*/
110+
template <int NUM_RC, typename T = float>
111+
inline auto linspace(const T (&firsts)[NUM_RC], const T (&lasts)[NUM_RC], index_t count, int axis = 0)
112+
{
113+
return detail::LinspaceOp<T, NUM_RC>(firsts, lasts, count, axis);
114+
}
115+
116+
/**
117+
* @brief Create a linearly-spaced vector of values
118+
*
119+
* Creates a set of values using startsand stop that are linearly-
120+
* spaced apart over the set of values. Distance is determined
121+
* by the count parameter
122+
*
123+
* @tparam T Type of the values
124+
* @param first First value
125+
* @param last Last value
126+
* @param count Number of values in a row or column, depending on the axis
127+
* @param axis Axis to operate over
128+
* @return Operator with linearly-spaced values
129+
*/
130+
template <typename T = float>
131+
inline auto linspace(T first, T last, index_t count, int axis = 0)
132+
{
133+
const T firsts[] = {first};
134+
const T lasts[] = {last};
135+
return linspace(firsts, lasts, count, axis);
136+
}
137+
73138
/**
74139
* @brief Create a linearly-spaced range of values
75140
*
76141
* Creates a set of values using a start and end that are linearly-
77142
* spaced apart over the set of values. Distance is determined
78143
* by the shape and selected dimension.
79144
*
80-
* @tparam T Operator type
81145
* @tparam Dim Dimension to operate over
82-
* @tparam ShapeType Shape type
83-
* @param s Shape object
146+
* @tparam NUM_RC Rank of shape
147+
* @tparam T Operator type
148+
* @param s Array of sizes
84149
* @param first First value
85150
* @param last Last value
86151
* @return Operator with linearly-spaced values
87152
*/
88-
template <int Dim, typename ShapeType, typename T,
89-
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
90-
inline auto linspace(ShapeType &&s, T first, T last)
91-
{
92-
constexpr int RANK = cuda::std::tuple_size<std::decay_t<ShapeType>>::value;
93-
static_assert(RANK > Dim);
94-
auto count = *(s.begin() + Dim);
95-
detail::LinspaceOp<T> l(first, last, count);
96-
return detail::matxGenerator1D_t<detail::LinspaceOp<T>, Dim, ShapeType>(std::forward<ShapeType>(s), l);
97-
}
153+
template <int Dim, int NUM_RC, typename T>
154+
[[deprecated("Use matx::linspace(T first, T last, index_t count, int axis = 0) instead.")]]
155+
inline auto linspace([[maybe_unused]]const index_t (&s)[NUM_RC], T first, T last)
156+
{
157+
const T firsts[] = {first};
158+
const T lasts[] = {last};
159+
return linspace(firsts, lasts, NUM_RC, 0);
160+
}
98161

99162
/**
100163
* @brief Create a linearly-spaced range of values
@@ -103,17 +166,24 @@ namespace matx
103166
* spaced apart over the set of values. Distance is determined
104167
* by the shape and selected dimension.
105168
*
106-
* @tparam Dim Dimension to operate over
107-
* @tparam RANK Rank of shape
108169
* @tparam T Operator type
109-
* @param s Array of sizes
170+
* @tparam Dim Dimension to operate over
171+
* @tparam ShapeType Shape type
172+
* @param s Shape object
110173
* @param first First value
111174
* @param last Last value
112175
* @return Operator with linearly-spaced values
113176
*/
114-
template <int Dim, int RANK, typename T>
115-
inline auto linspace(const index_t (&s)[RANK], T first, T last)
116-
{
117-
return linspace<Dim>(detail::to_array(s), first, last);
118-
}
177+
template <int Dim, typename ShapeType, typename T,
178+
std::enable_if_t<!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
179+
[[deprecated("Use matx::linspace(T first, T last, index_t count, int axis = 0) instead.")]]
180+
inline auto linspace(ShapeType &&s, T first, T last)
181+
{
182+
constexpr int NUM_RC = cuda::std::tuple_size<std::decay_t<ShapeType>>::value;
183+
static_assert(NUM_RC > Dim);
184+
auto count = *(s.begin() + Dim);
185+
const T firsts[] = {first};
186+
const T lasts[] = {last};
187+
return linspace(firsts, lasts, count, 0);
188+
}
119189
} // end namespace matx

0 commit comments

Comments
 (0)