@@ -75,23 +75,10 @@ static bool is_empty_tensor(const Tensor& self) {
   return self.numel() == 0;
 }
 
-static void unary_op(const Tensor& self,
-                     const Tensor& output_,
-                     std::string op_name,
-                     UnaryOpBlock unaryBlock,
-                     is_noop_p is_noop = is_empty_tensor) {
+static void unary_op_noresize(const Tensor& self, const Tensor& output_, std::string op_name, UnaryOpBlock unaryBlock) {
   TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
               "MPS support unary op with uint8 natively starting from macOS 13.0");
 
-  if (!output_.is_same_size(self)) {
-    output_.resize_(self.sizes());
-  }
-
-  if (is_noop(self)) {
-    output_.copy_(self);
-    return;
-  }
-
   auto output = output_;
   bool needsCopyToOutput = false;
   if (output.storage_offset() || !output.is_contiguous()) {
@@ -139,6 +126,23 @@ static void unary_op(const Tensor& self,
   }
 }
 
+static void unary_op(const Tensor& self,
+                     const Tensor& output_,
+                     std::string op_name,
+                     UnaryOpBlock unaryBlock,
+                     is_noop_p is_noop = is_empty_tensor) {
+  if (!output_.is_same_size(self)) {
+    output_.resize_(self.sizes());
+  }
+
+  if (is_noop(self)) {
+    output_.copy_(self);
+    return;
+  }
+
+  unary_op_noresize(self, output_, op_name, unaryBlock);
+}
+
 MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
   // Rounding is a no-op for integral types, and also a reasonable workaround
   // For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
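Note: this hunk and the one above split the old `unary_op` into `unary_op_noresize`, which only builds and runs the graph, and a thin `unary_op` wrapper that keeps the resize and no-op short-circuit, so callers that manage the output shape themselves (like the new `abs_out_mps` further down) can skip the resize. A minimal C++ sketch of the same split, with hypothetical toy types standing in for `Tensor` and the graph block:

```cpp
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-ins for Tensor and the MPSGraph unary block.
struct Toy { std::vector<float> data; };
using OpFn = std::function<float(float)>;

// Core op: assumes input and output shapes already agree.
static void toy_op_noresize(const Toy& in, Toy& out, OpFn fn) {
  for (std::size_t i = 0; i < in.data.size(); ++i)
    out.data[i] = fn(in.data[i]);
}

// Thin wrapper: keeps the resize and empty-input short-circuit, then delegates.
static void toy_op(const Toy& in, Toy& out, OpFn fn) {
  if (out.data.size() != in.data.size())
    out.data.resize(in.data.size()); // mirrors output_.resize_(self.sizes())
  if (in.data.empty())
    return;                          // mirrors the is_noop() early return
  toy_op_noresize(in, out, fn);
}

int main() {
  Toy in{{1.f, 4.f, 9.f}}, out;
  toy_op(in, out, [](float x) { return std::sqrt(x); }); // out = {1, 2, 3}
}
```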
@@ -168,6 +172,12 @@ static void unary_op(const Tensor& self,
   return [mpsGraph logarithmWithTensor:addedTensor name:nil];
 }
 
+static MPSGraphTensor* lengthOfComplexAsReal(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+  auto squares = [mpsGraph squareWithTensor:inputTensor name:nil];
+  auto sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
+  return [mpsGraph squareRootWithTensor:sumSquares name:nil];
+}
+
 } // namespace mps
 
 TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) {
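The new `lengthOfComplexAsReal` helper assumes its input is a complex tensor viewed as real pairs (the `at::view_as_real` layout, last axis `[re, im]`) and reduces over that axis, computing |z| = sqrt(re^2 + im^2). A tiny CPU-side check of that identity, with made-up values:

```cpp
#include <cassert>
#include <cmath>
#include <complex>

int main() {
  // A complex value and its view_as_real-style [re, im] pair.
  std::complex<float> z{3.0f, -4.0f};
  float pair[2] = {z.real(), z.imag()};

  // sqrt of the sum of squares over the last axis: what the graph computes.
  float norm = std::sqrt(pair[0] * pair[0] + pair[1] * pair[1]);

  assert(std::abs(norm - std::abs(z)) < 1e-6f); // 5.0 == |3 - 4i|
}
```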
@@ -226,14 +236,6 @@ static void unary_op(const Tensor& self,
   });                                                                        \
 }
 
-#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub)                \
-Tensor& func_out(const Tensor& self, Tensor& output) {                       \
-  mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { \
-    return [mpsGraph func_stub##WithTensor:inputTensor name:nil];            \
-  });                                                                        \
-  return output;                                                             \
-}
-
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
@@ -257,7 +259,35 @@ static void unary_op(const Tensor& self,
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
 CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
 
-CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)
+Tensor& abs_out_mps(const Tensor& self, Tensor& output) {
+  using namespace mps;
+
+  if (!output.is_same_size(self)) {
+    output.resize_(self.sizes());
+  }
+
+  if (self.numel() == 0) {
+    return output;
+  }
+
+  if (supportsComplex() || !self.is_complex()) {
+    unary_op_noresize(self, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+      auto rc = [mpsGraph absoluteWithTensor:inputTensor name:nil];
+      if (self.is_complex()) {
+        rc = [mpsGraph realPartOfTensor:rc name:nil];
+      }
+      return rc;
+    });
+  } else {
+    Tensor realInput = at::view_as_real(self);
+    unary_op_noresize(
+        realInput, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+          auto rc = lengthOfComplexAsReal(mpsGraph, inputTensor);
+          return [mpsGraph reshapeTensor:rc withShape:getMPSShape(output) name:nil];
+        });
+  }
+  return output;
+}
 
 Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
   auto bool_self = self.to(ScalarType::Bool);
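`abs_out_mps` now resizes and handles empty inputs itself, then picks one of two graph builds: with native complex support it calls `absoluteWithTensor:` and takes the real part (the result dtype is otherwise still complex); without it, it views the input as real pairs and reduces with `lengthOfComplexAsReal`, reshaping back to the output shape. A rough sketch of the fallback path on plain interleaved floats, names hypothetical:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Fallback path in miniature: input is n complex numbers stored as
// 2n interleaved floats (the view_as_real layout); output is n floats.
std::vector<float> abs_via_real_view(const std::vector<float>& interleaved) {
  std::vector<float> out(interleaved.size() / 2);
  for (std::size_t i = 0; i < out.size(); ++i) {
    float re = interleaved[2 * i], im = interleaved[2 * i + 1];
    out[i] = std::sqrt(re * re + im * im); // lengthOfComplexAsReal per element
  }
  return out; // already the output's shape; no reshape needed on CPU
}

int main() {
  auto r = abs_via_real_view({3.f, 4.f, 0.f, -1.f}); // 3+4i, 0-1i
  std::printf("%f %f\n", r[0], r[1]);                // 5.0 1.0
}
```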
@@ -484,9 +514,7 @@ static void cumulative_op_impl(const Tensor& self,
   Tensor realOutput = at::view_as_real(output);
 
   auto complex_sgn_op = [&](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -> MPSGraphTensor* {
-    MPSGraphTensor* squares = [mpsGraph squareWithTensor:inputTensor name:nil];
-    MPSGraphTensor* sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
-    MPSGraphTensor* norm = [mpsGraph squareRootWithTensor:sumSquares name:nil];
+    MPSGraphTensor* norm = mps::lengthOfComplexAsReal(mpsGraph, inputTensor);
     MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:norm.dataType];
     MPSGraphTensor* isZero = [mpsGraph equalWithPrimaryTensor:norm secondaryTensor:zero name:nil];
     MPSGraphTensor* sgnTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor secondaryTensor:norm name:nil];
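The complex `sgn` kernel now reuses the shared norm helper; the surrounding graph computes sgn(z) = z / |z| and selects 0 where |z| is 0 (the `zero`/`isZero` tensors guard the division). A scalar C++ analogue of that logic, assuming the same zero convention:

```cpp
#include <complex>
#include <iostream>

// Scalar analogue of complex_sgn_op: z / |z|, and 0 at z == 0.
std::complex<float> sgn(std::complex<float> z) {
  float norm = std::abs(z);              // mps::lengthOfComplexAsReal
  if (norm == 0.0f) return {0.0f, 0.0f}; // the isZero select
  return z / norm;                       // divisionWithPrimaryTensor
}

int main() {
  std::cout << sgn({3.0f, 4.0f}) << '\n'; // (0.6, 0.8)
  std::cout << sgn({0.0f, 0.0f}) << '\n'; // (0, 0)
}
```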