Description
🐛 Describe the bug
Hello all,
I have built ExecuTorch on Linux following the tutorials, and I can run the export example:
```python
import torch
from torch.export import export
from executorch.exir import to_edge


class Add(torch.nn.Module):
    def __init__(self):
        super(Add, self).__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x + y


# Export to the ATen dialect, lower to Edge, then to an ExecuTorch program.
aten_dialect = export(Add(), (torch.ones(1), torch.ones(1)))
edge_program = to_edge(aten_dialect)
executorch_program = edge_program.to_executorch()

# Serialize the program to a .pte file.
with open("add.pte", "wb") as file:
    file.write(executorch_program.buffer)
```
I can also run the exported `.pte` file in Python:
```python
from executorch.runtime import Runtime
from torch import ones

runtime = Runtime.get()
# Names of the kernels registered in the Python runtime.
operator_names = runtime.operator_registry.operator_names
program = runtime.load_program("./add.pte")
method = program.load_method("forward")
output = method.execute([ones(1), ones(1)])
print(ones(1), output)
```
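The script above reads `runtime.operator_registry.operator_names` but never uses it. Since the C++ log reports `aten::add.out` as missing, the same name can be checked against the Python runtime's registry. This is a minimal sketch: the helper `missing_kernels` and the hard-coded requirement list are hypothetical, and it assumes `operator_names` yields plain strings such as `"aten::add.out"`.

```python
from typing import Iterable, List


def missing_kernels(required: Iterable[str], registered: Iterable[str]) -> List[str]:
    """Return the names in `required` that do not appear in `registered`.

    Hypothetical helper: `registered` would typically be
    runtime.operator_registry.operator_names from executorch.runtime.
    Duplicates in `required` are reported once, in first-seen order.
    """
    available = set(registered)
    missing: List[str] = []
    for name in required:
        if name not in available and name not in missing:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Assumption: add.pte needs only aten::add.out, the kernel named in the C++ log.
    required = ["aten::add.out"]
    try:
        from executorch.runtime import Runtime
    except ImportError:
        print("executorch is not installed; skipping the registry check")
    else:
        registered = Runtime.get().operator_registry.operator_names
        print("missing in the Python runtime:", missing_kernels(required, registered))
```

Since the Python example executes successfully, the kernel is presumably registered there, which narrows the difference down to how the C++ binary is built.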
However, when I try to run it from C++:
```cpp
#include <iostream>

#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>

using namespace ::executorch::extension;

int main() {
  // Create a Module.
  Module module("add.pte");

  // Wrap the input data with tensors matching the exported inputs:
  // float32, shape {1}, value 1.0 (like torch.ones(1)).
  auto tensorX = make_tensor_ptr({1}, {1.0f});
  auto tensorY = make_tensor_ptr({1}, {1.0f});

  // Perform an inference.
  auto x = module.set_input("forward", tensorX, 0);
  auto y = module.set_input("forward", tensorY, 1);
  const auto result = module.forward();

  // Check for success or failure.
  if (result.ok()) {
    // Retrieve the output data (the first output tensor's single element).
    std::cout << result->at(0).toTensor().const_data_ptr<float>()[0] << std::endl;
  }
  return 0;
}
```
I get these errors:
E 00:00:00.000321 executorch:operator_registry.cpp:185] kernel 'aten::add.out' not found.
E 00:00:00.000333 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000334 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000335 executorch:operator_registry.cpp:186] ]
E 00:00:00.000336 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000336 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000337 executorch:operator_registry.cpp:186] ]
E 00:00:00.000347 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000349 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000350 executorch:operator_registry.cpp:186] ]
E 00:00:00.000351 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000359 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000360 executorch:operator_registry.cpp:186] ]
E 00:00:00.000360 executorch:method.cpp:554] Missing operator: [0] aten::add.out
E 00:00:00.000362 executorch:method.cpp:763] There are 1 instructions don't have corresponding operator registered. See logs for details
E 00:00:00.000387 executorch:operator_registry.cpp:185] kernel 'aten::add.out' not found.
E 00:00:00.000395 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000396 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000397 executorch:operator_registry.cpp:186] ]
E 00:00:00.000398 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000399 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000400 executorch:operator_registry.cpp:186] ]
E 00:00:00.000401 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000402 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000402 executorch:operator_registry.cpp:186] ]
E 00:00:00.000404 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000405 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000406 executorch:operator_registry.cpp:186] ]
E 00:00:00.000407 executorch:method.cpp:554] Missing operator: [0] aten::add.out
E 00:00:00.000407 executorch:method.cpp:763] There are 1 instructions don't have corresponding operator registered. See logs for details
E 00:00:00.000414 executorch:operator_registry.cpp:185] kernel 'aten::add.out' not found.
E 00:00:00.000422 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000423 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000424 executorch:operator_registry.cpp:186] ]
E 00:00:00.000425 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000427 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000428 executorch:operator_registry.cpp:186] ]
E 00:00:00.000429 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000437 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000438 executorch:operator_registry.cpp:186] ]
E 00:00:00.000439 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000440 executorch:operator_registry.cpp:186] 0,
E 00:00:00.000441 executorch:operator_registry.cpp:186] ]
E 00:00:00.000441 executorch:method.cpp:554] Missing operator: [0] aten::add.out
E 00:00:00.000443 executorch:method.cpp:763] There are 1 instructions don't have corresponding operator registered. See logs for details
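The same block of errors repeats three times in the log. A short sketch for condensing a log like this: it counts the `Missing operator:` reports and decodes the `dtype:` codes. It assumes ExecuTorch's `ScalarType` numbering matches PyTorch's, so `6` means Float; each `dim order: [0]` entry describes a one-dimensional tensor. The function name and the partial dtype table are illustrative, not part of any ExecuTorch API.

```python
import re
from collections import Counter
from typing import Dict, List, Tuple

# Subset of the ScalarType numbering shared by PyTorch (c10) and ExecuTorch.
SCALAR_TYPES: Dict[int, str] = {
    0: "Byte", 1: "Char", 2: "Short", 3: "Int", 4: "Long",
    5: "Half", 6: "Float", 7: "Double", 11: "Bool",
}

MISSING_RE = re.compile(r"Missing operator: \[\d+\] (\S+)")
DTYPE_RE = re.compile(r"dtype: (\d+)")


def summarize_missing_kernels(log: str) -> Tuple[Counter, List[str]]:
    """Count 'Missing operator' reports and list the decoded dtype names seen."""
    missing = Counter(MISSING_RE.findall(log))
    dtypes = sorted({SCALAR_TYPES.get(int(code), f"unknown({code})")
                     for code in DTYPE_RE.findall(log)})
    return missing, dtypes


if __name__ == "__main__":
    sample = """\
E 00:00:00.000321 executorch:operator_registry.cpp:185] kernel 'aten::add.out' not found.
E 00:00:00.000333 executorch:operator_registry.cpp:186] dtype: 6 | dim order: [
E 00:00:00.000360 executorch:method.cpp:554] Missing operator: [0] aten::add.out
E 00:00:00.000407 executorch:method.cpp:554] Missing operator: [0] aten::add.out
"""
    missing, dtypes = summarize_missing_kernels(sample)
    print(dict(missing))  # {'aten::add.out': 2}
    print(dtypes)         # ['Float']
```

Applied to the full log, this reduces it to one missing kernel, `aten::add.out`, requested for Float tensors.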
Any suggestions as to what the problem might be?
Versions
PyTorch version: 2.6.0+cpu
Is debug build: False
CUDA used to build PyTorch: Could not collect
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.2 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.31.6
Libc version: glibc-2.39
Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 572.83
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 24
On-line CPU(s) list: 0-23
Vendor ID: GenuineIntel
Model name: 12th Gen Intel(R) Core(TM) i9-12900K
CPU family: 6
Model: 151
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
Stepping: 2
BogoMIPS: 6374.39
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 576 KiB (12 instances)
L1i cache: 384 KiB (12 instances)
L2 cache: 15 MiB (12 instances)
L3 cache: 30 MiB (1 instance)
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] executorch==0.5.0a0+1bc0699
[pip3] numpy==2.0.0
[pip3] torch==2.6.0+cpu
[pip3] torchao==0.8.0+gitebc43034
[pip3] torchaudio==2.6.0
[pip3] torchsr==1.0.4
[pip3] torchvision==0.21.0+cpu
[conda] Could not collect