Skip to content

Commit d983cb7

Browse files
pytorchbot and malfet authored
[MPS] Fix abs for complex types (#126096)
[MPS] Fix `abs` for complex types (#125662). The fix calls `realPartOfTensor:` when the input type is complex on Sonoma, and falls back to the `at::view_as_real` trick on Ventura. Split the `unary_op` template into `unary_op` and `unary_op_noresize`; the latter skips the resize and empty checks. Marked the `abs`, `isclose` and `nn.functional.softsign` OpInfo tests as supported for complex types. Fixes #125135. Pull Request resolved: #125662. Approved by: https://github.com/kulinseth (cherry picked from commit 0fd1fc1). Co-authored-by: Nikita Shulga <[email protected]>
1 parent 75e01e7 commit d983cb7

File tree

3 files changed

+63
-28
lines changed

3 files changed

+63
-28
lines changed

aten/src/ATen/native/mps/MPSGraphSonomaOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
2525
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
2626
name:(NSString * _Nullable) name;
2727

28+
// Declares the MPSGraph selector realPartOfTensor:name:.
// abs_out_mps (UnaryOps.mm) applies it to the result of absoluteWithTensor:
// when the input tensor is complex.
// NOTE(review): this header is named for Sonoma ops, so the selector is
// presumably only available at runtime on macOS Sonoma and newer -- confirm.
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
29+
name:(NSString * _Nullable) name;
30+
31+
2832
-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor
2933
axes:(NSArray<NSNumber *> * _Nonnull) axes
3034
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,23 +75,10 @@ static bool is_empty_tensor(const Tensor& self) {
7575
return self.numel() == 0;
7676
}
7777

78-
static void unary_op(const Tensor& self,
79-
const Tensor& output_,
80-
std::string op_name,
81-
UnaryOpBlock unaryBlock,
82-
is_noop_p is_noop = is_empty_tensor) {
78+
static void unary_op_noresize(const Tensor& self, const Tensor& output_, std::string op_name, UnaryOpBlock unaryBlock) {
8379
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
8480
"MPS support unary op with uint8 natively starting from macOS 13.0");
8581

86-
if (!output_.is_same_size(self)) {
87-
output_.resize_(self.sizes());
88-
}
89-
90-
if (is_noop(self)) {
91-
output_.copy_(self);
92-
return;
93-
}
94-
9582
auto output = output_;
9683
bool needsCopyToOutput = false;
9784
if (output.storage_offset() || !output.is_contiguous()) {
@@ -139,6 +126,23 @@ static void unary_op(const Tensor& self,
139126
}
140127
}
141128

129+
// Front end for element-wise unary MPS ops: resizes the output, handles the
// no-op case, then delegates the actual work to unary_op_noresize.
//
// self       - input tensor.
// output_    - destination tensor; resized in place to self's sizes if they differ.
// op_name    - forwarded unchanged to unary_op_noresize.
// unaryBlock - block that builds the MPSGraph op; forwarded unchanged.
// is_noop    - predicate deciding whether the op can be skipped; defaults to
//              is_empty_tensor (true when self.numel() == 0).
static void unary_op(const Tensor& self,
130+
const Tensor& output_,
131+
std::string op_name,
132+
UnaryOpBlock unaryBlock,
133+
is_noop_p is_noop = is_empty_tensor) {
134+
// Make the output shape match the input before anything is written to it.
if (!output_.is_same_size(self)) {
135+
output_.resize_(self.sizes());
136+
}
137+
138+
// When the predicate reports a no-op, the result is just a copy of the input
// and unary_op_noresize is never called.
if (is_noop(self)) {
139+
output_.copy_(self);
140+
return;
141+
}
142+
143+
unary_op_noresize(self, output_, op_name, unaryBlock);
144+
}
145+
142146
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
143147
// Rounding is a no-op for integral types, and also a reasonable workaround
144148
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
@@ -168,6 +172,12 @@ static void unary_op(const Tensor& self,
168172
return [mpsGraph logarithmWithTensor:addedTensor name:nil];
169173
}
170174

175+
// Builds the graph nodes for sqrt(sum(x^2)) reduced over the last axis of
// inputTensor and returns the resulting tensor.
// Callers in this file pass a tensor derived from at::view_as_real, so the last
// axis presumably holds the (real, imag) pair and the result is the complex
// magnitude -- confirm against the callers.
static MPSGraphTensor* lengthOfComplexAsReal(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
176+
auto squares = [mpsGraph squareWithTensor:inputTensor name:nil];
177+
// axis:-1 selects the trailing axis for the sum.
auto sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
178+
return [mpsGraph squareRootWithTensor:sumSquares name:nil];
179+
}
180+
171181
} // namespace mps
172182

173183
TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) {
@@ -226,14 +236,6 @@ static void unary_op(const Tensor& self,
226236
}); \
227237
}
228238

229-
#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
230-
Tensor& func_out(const Tensor& self, Tensor& output) { \
231-
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
232-
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
233-
}); \
234-
return output; \
235-
}
236-
237239
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
238240
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
239241
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
@@ -257,7 +259,35 @@ static void unary_op(const Tensor& self,
257259
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
258260
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)
259261

260-
CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)
262+
// Computes the absolute value of self into output and returns output.
//
// The resize and empty-input handling is done here, and unary_op_noresize is
// called directly instead of unary_op. In the fallback branch the tensor passed
// to the graph (realInput) is a different tensor from self, whose sizes the
// output was resized to.
//
// Two paths:
//  - supportsComplex() is true, or self is not complex: run absoluteWithTensor:
//    and, for complex input, take the real part of the result.
//  - otherwise (complex input without native complex support): view self as a
//    real tensor and compute the magnitude with lengthOfComplexAsReal.
// Per the commit description, the first path targets Sonoma and the fallback
// targets Ventura.
Tensor& abs_out_mps(const Tensor& self, Tensor& output) {
263+
using namespace mps;
264+
265+
if (!output.is_same_size(self)) {
266+
output.resize_(self.sizes());
267+
}
268+
269+
// Empty input: output already has the right shape, nothing to compute.
if (self.numel() == 0) {
270+
return output;
271+
}
272+
273+
if (supportsComplex() || !self.is_complex()) {
274+
unary_op_noresize(self, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
275+
auto rc = [mpsGraph absoluteWithTensor:inputTensor name:nil];
276+
// For complex input, only the real part of the absolute value is returned.
if (self.is_complex()) {
277+
rc = [mpsGraph realPartOfTensor:rc name:nil];
278+
}
279+
return rc;
280+
});
281+
} else {
282+
// Fallback: reinterpret the complex input as a real tensor for the graph.
Tensor realInput = at::view_as_real(self);
283+
unary_op_noresize(
284+
realInput, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
285+
auto rc = lengthOfComplexAsReal(mpsGraph, inputTensor);
286+
// Reshape the reduced result to the output's shape.
// NOTE(review): presumably needed because the reduction keeps a trailing axis -- confirm.
return [mpsGraph reshapeTensor:rc withShape:getMPSShape(output) name:nil];
287+
});
288+
}
289+
return output;
290+
}
261291

262292
Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
263293
auto bool_self = self.to(ScalarType::Bool);
@@ -484,9 +514,7 @@ static void cumulative_op_impl(const Tensor& self,
484514
Tensor realOutput = at::view_as_real(output);
485515

486516
auto complex_sgn_op = [&](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -> MPSGraphTensor* {
487-
MPSGraphTensor* squares = [mpsGraph squareWithTensor:inputTensor name:nil];
488-
MPSGraphTensor* sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
489-
MPSGraphTensor* norm = [mpsGraph squareRootWithTensor:sumSquares name:nil];
517+
MPSGraphTensor* norm = mps::lengthOfComplexAsReal(mpsGraph, inputTensor);
490518
MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:norm.dataType];
491519
MPSGraphTensor* isZero = [mpsGraph equalWithPrimaryTensor:norm secondaryTensor:zero name:nil];
492520
MPSGraphTensor* sgnTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor secondaryTensor:norm name:nil];

test/test_mps.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,7 @@ def mps_ops_modifier(ops):
239239
'__radd__',
240240
'__rmul__',
241241
'__getitem__',
242+
'abs',
242243
'add',
243244
'atleast_1d',
244245
'atleast_2d',
@@ -296,8 +297,8 @@ def mps_ops_modifier(ops):
296297
'narrow_copy',
297298
'nn.functional.conv1d',
298299
'nn.functional.conv_transpose1d',
299-
'nn.functional.padcircular',
300300
'nn.functional.feature_alpha_dropoutwithout_train',
301+
'nn.functional.padcircular',
301302
'nn.functional.unfold',
302303
'ones',
303304
'outer',
@@ -392,6 +393,7 @@ def mps_ops_modifier(ops):
392393
'half',
393394
'hstack',
394395
'int',
396+
'isclose',
395397
'isnan',
396398
'ldexp',
397399
'log10',
@@ -412,12 +414,13 @@ def mps_ops_modifier(ops):
412414
'mean',
413415
'ne',
414416
'neg',
415-
'nn.functional.rms_norm',
416417
'nn.functional.padconstant',
417418
'nn.functional.padreflect',
418419
'nn.functional.padreplicate',
419420
'nn.functional.pixel_shuffle',
420421
'nn.functional.pixel_unshuffle',
422+
'nn.functional.rms_norm',
423+
'nn.functional.softsign',
421424
'nn.functional.tanhshrink',
422425
'nonzero',
423426
'prod',

0 commit comments

Comments
 (0)