Skip to content

Commit

Permalink
[MPS] Fix abs for complex types (#126096)
Browse files Browse the repository at this point in the history
[MPS] Fix `abs` for complex types (#125662)

By calling `realPartOfTensor:` if the input type is complex on Sonoma, and falling back to the `at::view_as_real` trick on Ventura.

Split the `unary_op` template into `unary_op` and `unary_op_noresize`; the latter skips the resize and empty-tensor checks.

Marked the `abs`, `isclose`, and `nn.functional.softsign` OpInfo tests as supporting complex types.

Fixes #125135

Pull Request resolved: #125662
Approved by: https://github.com/kulinseth

(cherry picked from commit 0fd1fc1)

Co-authored-by: Nikita Shulga <[email protected]>
  • Loading branch information
pytorchbot and malfet committed May 13, 2024
1 parent 75e01e7 commit d983cb7
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 28 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mps/MPSGraphSonomaOps.h
Expand Up @@ -25,6 +25,10 @@ typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;

// Returns the real component of a complex-typed graph tensor.
// NOTE(review): declared here because this selector is only available in the
// macOS Sonoma MPSGraph SDK — callers gate on OS version before using it.
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;


-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor
Expand Down
80 changes: 54 additions & 26 deletions aten/src/ATen/native/mps/operations/UnaryOps.mm
Expand Up @@ -75,23 +75,10 @@ static bool is_empty_tensor(const Tensor& self) {
return self.numel() == 0;
}

static void unary_op(const Tensor& self,
const Tensor& output_,
std::string op_name,
UnaryOpBlock unaryBlock,
is_noop_p is_noop = is_empty_tensor) {
static void unary_op_noresize(const Tensor& self, const Tensor& output_, std::string op_name, UnaryOpBlock unaryBlock) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support unary op with uint8 natively starting from macOS 13.0");

if (!output_.is_same_size(self)) {
output_.resize_(self.sizes());
}

if (is_noop(self)) {
output_.copy_(self);
return;
}

auto output = output_;
bool needsCopyToOutput = false;
if (output.storage_offset() || !output.is_contiguous()) {
Expand Down Expand Up @@ -139,6 +126,23 @@ static void unary_op(const Tensor& self,
}
}

// Resizing/no-op front-end for unary_op_noresize: grows the destination to
// match the input's shape, then either short-circuits via a plain copy when
// `is_noop` reports there is nothing to compute (by default, an empty input)
// or dispatches to the real graph implementation.
static void unary_op(const Tensor& self,
                     const Tensor& output_,
                     std::string op_name,
                     UnaryOpBlock unaryBlock,
                     is_noop_p is_noop = is_empty_tensor) {
  // Bring the destination in line with the source shape before dispatching.
  if (!output_.is_same_size(self)) {
    output_.resize_(self.sizes());
  }

  if (is_noop(self)) {
    // Degenerate case: copying preserves semantics without building a graph.
    output_.copy_(self);
  } else {
    unary_op_noresize(self, output_, op_name, unaryBlock);
  }
}

MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// Rounding is a no-op for integral types, and also a reasonable workaround
// For MPSGraph bug on Apple Silicon, that throws `Function floorOp_i64 was not found in the library`
Expand Down Expand Up @@ -168,6 +172,12 @@ static void unary_op(const Tensor& self,
return [mpsGraph logarithmWithTensor:addedTensor name:nil];
}

// Given a tensor laid out as view_as_real (trailing axis = [re, im]), builds
// graph nodes computing sqrt(re^2 + im^2) over that trailing axis — i.e. the
// complex magnitude expressed entirely in real arithmetic.
static MPSGraphTensor* lengthOfComplexAsReal(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
  MPSGraphTensor* componentSquares = [mpsGraph squareWithTensor:inputTensor name:nil];
  MPSGraphTensor* squaredNorm = [mpsGraph reductionSumWithTensor:componentSquares axis:-1 name:nil];
  return [mpsGraph squareRootWithTensor:squaredNorm name:nil];
}

} // namespace mps

TORCH_IMPL_FUNC(trunc_out_mps)(const Tensor& self, const Tensor& output) {
Expand Down Expand Up @@ -226,14 +236,6 @@ static void unary_op(const Tensor& self,
}); \
}

// Expands to a non-structured out-variant unary op (`Tensor& func_out(self, output)`)
// that forwards to the matching MPSGraph `func_stub##WithTensor:` kernel via
// mps::unary_op. (No comments inside the body: `//` would swallow the `\`
// line continuations.)
#define CREATE_MPS_UNARY_TORCH_IMPL_FUNC(func_out, func_stub) \
Tensor& func_out(const Tensor& self, Tensor& output) { \
mps::unary_op(self, output, #func_out, ^MPSGraphTensor*(MPSGraph * mpsGraph, MPSGraphTensor * inputTensor) { \
return [mpsGraph func_stub##WithTensor:inputTensor name:nil]; \
}); \
return output; \
}

CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp_out_mps, exponent)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal)
Expand All @@ -257,7 +259,35 @@ static void unary_op(const Tensor& self,
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh)
CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh)

CREATE_MPS_UNARY_TORCH_IMPL_FUNC(abs_out_mps, absolute)
// out-variant of `abs` on MPS. Real dtypes map directly onto
// absoluteWithTensor:. Complex dtypes take one of two paths:
//  - when supportsComplex() (Sonoma+): absoluteWithTensor: on the complex
//    tensor, then realPartOfTensor: — presumably the magnitude comes back
//    complex-typed with zero imaginary part; TODO confirm against MPSGraph docs.
//  - otherwise (e.g. Ventura): view the input as real pairs and compute
//    sqrt(re^2 + im^2) by hand via lengthOfComplexAsReal.
Tensor& abs_out_mps(const Tensor& self, Tensor& output) {
  using namespace mps;

  // This op skips the structured-kernel machinery, so resize the output here.
  if (!output.is_same_size(self)) {
    output.resize_(self.sizes());
  }

  // Nothing to compute for empty tensors.
  if (self.numel() == 0) {
    return output;
  }

  if (supportsComplex() || !self.is_complex()) {
    unary_op_noresize(self, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
      auto rc = [mpsGraph absoluteWithTensor:inputTensor name:nil];
      if (self.is_complex()) {
        // Drop the imaginary component so the result is real-typed.
        rc = [mpsGraph realPartOfTensor:rc name:nil];
      }
      return rc;
    });
  } else {
    // Fallback for OS versions without native complex support: reinterpret
    // the complex tensor as a (..., 2) real tensor and reduce the last axis.
    Tensor realInput = at::view_as_real(self);
    unary_op_noresize(
        realInput, output, "abs_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
          auto rc = lengthOfComplexAsReal(mpsGraph, inputTensor);
          // The reduction keeps a trailing size-1 axis; reshape back to the
          // expected output shape.
          return [mpsGraph reshapeTensor:rc withShape:getMPSShape(output) name:nil];
        });
  }
  return output;
}

Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
auto bool_self = self.to(ScalarType::Bool);
Expand Down Expand Up @@ -484,9 +514,7 @@ static void cumulative_op_impl(const Tensor& self,
Tensor realOutput = at::view_as_real(output);

auto complex_sgn_op = [&](MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) -> MPSGraphTensor* {
MPSGraphTensor* squares = [mpsGraph squareWithTensor:inputTensor name:nil];
MPSGraphTensor* sumSquares = [mpsGraph reductionSumWithTensor:squares axis:-1 name:nil];
MPSGraphTensor* norm = [mpsGraph squareRootWithTensor:sumSquares name:nil];
MPSGraphTensor* norm = mps::lengthOfComplexAsReal(mpsGraph, inputTensor);
MPSGraphTensor* zero = [mpsGraph constantWithScalar:0.0 dataType:norm.dataType];
MPSGraphTensor* isZero = [mpsGraph equalWithPrimaryTensor:norm secondaryTensor:zero name:nil];
MPSGraphTensor* sgnTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor secondaryTensor:norm name:nil];
Expand Down
7 changes: 5 additions & 2 deletions test/test_mps.py
Expand Up @@ -239,6 +239,7 @@ def mps_ops_modifier(ops):
'__radd__',
'__rmul__',
'__getitem__',
'abs',
'add',
'atleast_1d',
'atleast_2d',
Expand Down Expand Up @@ -296,8 +297,8 @@ def mps_ops_modifier(ops):
'narrow_copy',
'nn.functional.conv1d',
'nn.functional.conv_transpose1d',
'nn.functional.padcircular',
'nn.functional.feature_alpha_dropoutwithout_train',
'nn.functional.padcircular',
'nn.functional.unfold',
'ones',
'outer',
Expand Down Expand Up @@ -392,6 +393,7 @@ def mps_ops_modifier(ops):
'half',
'hstack',
'int',
'isclose',
'isnan',
'ldexp',
'log10',
Expand All @@ -412,12 +414,13 @@ def mps_ops_modifier(ops):
'mean',
'ne',
'neg',
'nn.functional.rms_norm',
'nn.functional.padconstant',
'nn.functional.padreflect',
'nn.functional.padreplicate',
'nn.functional.pixel_shuffle',
'nn.functional.pixel_unshuffle',
'nn.functional.rms_norm',
'nn.functional.softsign',
'nn.functional.tanhshrink',
'nonzero',
'prod',
Expand Down

0 comments on commit d983cb7

Please sign in to comment.