This repository was archived by the owner on Oct 23, 2023. It is now read-only.
File tree Expand file tree Collapse file tree 2 files changed +47
-2
lines changed
react-native-pytorch-core/cxx Expand file tree Collapse file tree 2 files changed +47
-2
lines changed Original file line number Diff line number Diff line change @@ -227,7 +227,10 @@ jsi::Value dataImpl(
227
227
// BigIntArray
228
228
if (type == torch_::kInt64 ) {
229
229
throw jsi::JSError (
230
- runtime, " the property 'data' of BigInt Tensor is not supported." );
230
+ runtime,
231
+ " the property 'data' for a tensor of dtype torch.int64 is not"
232
+ " supported. Work around this with .to({dtype: torch.int32})"
233
+ " This might alter the tensor values." );
231
234
}
232
235
233
236
std::string typedArrayName;
@@ -296,6 +299,16 @@ jsi::Value itemImpl(
296
299
size_t count) {
297
300
auto thiz =
298
301
thisValue.asObject (runtime).asHostObject <TensorHostObject>(runtime);
302
+
303
+ // TODO(T113480543): enable BigInt once Hermes supports it
304
+ if (thiz->tensor .dtype () == torch_::kInt64 ) {
305
+ throw jsi::JSError (
306
+ runtime,
307
+ " the property 'item' for a tensor of dtype torch.int64 is not"
308
+ " supported. Work around this with .to({dtype: torch.int32})"
309
+ " This might alter the tensor values." );
310
+ }
311
+
299
312
auto scalar = thiz->tensor .item ();
300
313
if (scalar.isIntegral (/* includeBool=*/ false )) {
301
314
return jsi::Value (scalar.toInt ());
Original file line number Diff line number Diff line change @@ -245,7 +245,20 @@ TEST_F(TorchliveTensorRuntimeTest, TensorDataTest) {
245
245
const tensor = torch.tensor([128, 255], {dtype: torch.long});
246
246
tensor.data();
247
247
)" ;
248
- EXPECT_THROW (eval (tensorWithDtypeAsInt64), facebook::jsi::JSError);
248
+ EXPECT_THROW (
249
+ {
250
+ try {
251
+ eval (tensorWithDtypeAsInt64);
252
+ } catch (const facebook::jsi::JSError& e) {
253
+ EXPECT_TRUE (
254
+ std::string (e.what ()).find (
255
+ " property 'data' for a tensor of dtype torch.int64 is not supported." ) !=
256
+ std::string::npos)
257
+ << e.what ();
258
+ throw ;
259
+ }
260
+ },
261
+ facebook::jsi::JSError);
249
262
}
250
263
251
264
TEST_F (TorchliveTensorRuntimeTest, TensorIndexing) {
@@ -746,6 +759,25 @@ TEST_F(TorchliveTensorRuntimeTest, TensorItemTest) {
746
759
tensor.item();
747
760
)" ;
748
761
EXPECT_THROW (eval (tensorItemForMultiElementTensor), facebook::jsi::JSError);
762
+
763
+ std::string tensorItemInt64 = R"(
764
+ const tensor = torch.tensor(1, {dtype: torch.int64});
765
+ tensor.item();
766
+ )" ;
767
+ EXPECT_THROW (
768
+ {
769
+ try {
770
+ eval (tensorItemInt64);
771
+ } catch (const facebook::jsi::JSError& e) {
772
+ EXPECT_TRUE (
773
+ std::string (e.what ()).find (
774
+ " property 'item' for a tensor of dtype torch.int64 is not supported." ) !=
775
+ std::string::npos)
776
+ << e.what ();
777
+ throw ;
778
+ }
779
+ },
780
+ facebook::jsi::JSError);
749
781
}
750
782
751
783
TEST_F (TorchliveTensorRuntimeTest, TensorSqrtTest) {
You can’t perform that action at this time.
0 commit comments