Skip to content
This repository was archived by the owner on Oct 23, 2023. It is now read-only.

Commit 6f499ca

Browse files
chrisklaiberfacebook-github-bot
authored andcommitted
error when BigInt required for Tensor.item, Tensor.data
Summary: Now that int64 / long are supported by the Dtype type, give users a helpful message when attempting to use them in an unsupported way Reviewed By: raedle Differential Revision: D37387583 fbshipit-source-id: f2e51a96a3724dcb3e605aa3c9439ee85e3944e7
1 parent a7ee796 commit 6f499ca

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

react-native-pytorch-core/cxx/src/torchlive/torch/TensorHostObject.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,10 @@ jsi::Value dataImpl(
227227
// BigIntArray
228228
if (type == torch_::kInt64) {
229229
throw jsi::JSError(
230-
runtime, "the property 'data' of BigInt Tensor is not supported.");
230+
runtime,
231+
"the property 'data' for a tensor of dtype torch.int64 is not"
232+
" supported. Work around this with .to({dtype: torch.int32})"
233+
" This might alter the tensor values.");
231234
}
232235

233236
std::string typedArrayName;
@@ -296,6 +299,16 @@ jsi::Value itemImpl(
296299
size_t count) {
297300
auto thiz =
298301
thisValue.asObject(runtime).asHostObject<TensorHostObject>(runtime);
302+
303+
// TODO(T113480543): enable BigInt once Hermes supports it
304+
if (thiz->tensor.dtype() == torch_::kInt64) {
305+
throw jsi::JSError(
306+
runtime,
307+
"the property 'item' for a tensor of dtype torch.int64 is not"
308+
" supported. Work around this with .to({dtype: torch.int32})"
309+
" This might alter the tensor values.");
310+
}
311+
299312
auto scalar = thiz->tensor.item();
300313
if (scalar.isIntegral(/*includeBool=*/false)) {
301314
return jsi::Value(scalar.toInt());

react-native-pytorch-core/cxx/test/TensorTests.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,20 @@ TEST_F(TorchliveTensorRuntimeTest, TensorDataTest) {
245245
const tensor = torch.tensor([128, 255], {dtype: torch.long});
246246
tensor.data();
247247
)";
248-
EXPECT_THROW(eval(tensorWithDtypeAsInt64), facebook::jsi::JSError);
248+
EXPECT_THROW(
249+
{
250+
try {
251+
eval(tensorWithDtypeAsInt64);
252+
} catch (const facebook::jsi::JSError& e) {
253+
EXPECT_TRUE(
254+
std::string(e.what()).find(
255+
"property 'data' for a tensor of dtype torch.int64 is not supported.") !=
256+
std::string::npos)
257+
<< e.what();
258+
throw;
259+
}
260+
},
261+
facebook::jsi::JSError);
249262
}
250263

251264
TEST_F(TorchliveTensorRuntimeTest, TensorIndexing) {
@@ -746,6 +759,25 @@ TEST_F(TorchliveTensorRuntimeTest, TensorItemTest) {
746759
tensor.item();
747760
)";
748761
EXPECT_THROW(eval(tensorItemForMultiElementTensor), facebook::jsi::JSError);
762+
763+
std::string tensorItemInt64 = R"(
764+
const tensor = torch.tensor(1, {dtype: torch.int64});
765+
tensor.item();
766+
)";
767+
EXPECT_THROW(
768+
{
769+
try {
770+
eval(tensorItemInt64);
771+
} catch (const facebook::jsi::JSError& e) {
772+
EXPECT_TRUE(
773+
std::string(e.what()).find(
774+
"property 'item' for a tensor of dtype torch.int64 is not supported.") !=
775+
std::string::npos)
776+
<< e.what();
777+
throw;
778+
}
779+
},
780+
facebook::jsi::JSError);
749781
}
750782

751783
TEST_F(TorchliveTensorRuntimeTest, TensorSqrtTest) {

0 commit comments

Comments
 (0)