Skip to content

Commit 4d0f987

Browse files
authored
[WIP] Add libtorch backend, related documentation and related op tests (#1087)
* Add libtorch backend, related documentation and related op tests * update torch op dir Co-authored-by: BowShotDS <[email protected]>
1 parent 8de88da commit 4d0f987

20 files changed

+1407
-14
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ OPTION (TENGINE_ENABLE_OPENCL "With Khronos OpenCL support"
7878
OPTION (TENGINE_ENABLE_OPENDLA "With Khronos OpenDLA support" OFF)
7979
OPTION (TENGINE_ENABLE_TENSORRT "With nVIDIA TensorRT support" OFF)
8080
OPTION (TENGINE_ENABLE_TIM_VX "With VeriSilicon TIM-VX support" OFF)
81+
OPTION (TENGINE_ENABLE_TORCH "With libTorch support" OFF)
8182
OPTION (TENGINE_ENABLE_NNIE "With HiSilicon NNIE support" OFF)
8283
OPTION (TENGINE_ENABLE_VULKAN "With Khronos Vulkan GPU compute support" OFF)
8384

@@ -96,6 +97,7 @@ OPTION (TENGINE_ONLINE_REPORT "online report"
9697
# Do check list
9798
INCLUDE ("${CMAKE_CURRENT_SOURCE_DIR}/cmake/check.cmake")
9899
INCLUDE ("${CMAKE_CURRENT_SOURCE_DIR}/cmake/cuda.cmake")
100+
INCLUDE ("${CMAKE_CURRENT_SOURCE_DIR}/cmake/torch.cmake")
99101
INCLUDE ("${CMAKE_CURRENT_SOURCE_DIR}/cmake/registry.cmake")
100102
INCLUDE ("${CMAKE_CURRENT_SOURCE_DIR}/cmake/utility.cmake")
101103

cmake/check.cmake

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,11 @@ ENDIF()
222222

223223

224224
# C++11 is the base required standard
225-
SET (CMAKE_CXX_STANDARD 11)
225+
IF (TENGINE_ENABLE_TORCH)
226+
SET (CMAKE_CXX_STANDARD 17)
227+
else()
228+
SET (CMAKE_CXX_STANDARD 11)
229+
ENDIF()
226230
SET (CMAKE_CXX_STANDARD_REQUIRED TRUE)
227231
SET (CMAKE_CXX_EXTENSIONS OFF)
228232

cmake/torch.cmake

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# License); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# Copyright (c) 2021, OPEN AI LAB
19+
20+
#
21+
22+
IF (TENGINE_ENABLE_TORCH)
23+
find_package(Torch REQUIRED)
24+
25+
message(STATUS "Torch library status:")
26+
message(STATUS " version: ${TORCH_VERSION}")
27+
message(STATUS " libraries: ${TORCH_LIBS}")
28+
message(STATUS " include path: ${TORCH_INCLUDE_DIRS}")
29+
message(STATUS " torch lib: ${TORCH_LIBRARIES}")
30+
ENDIF()

source/device/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,20 @@ IF (TENGINE_ENABLE_TIM_VX)
148148
LIST (APPEND _REGISTER_DEVICE_LIST "${CMAKE_SOURCE_DIR}/source/device/tim-vx/timvx_device.cc")
149149
ENDIF()
150150

151+
# libTorch
152+
IF (TENGINE_ENABLE_TORCH)
153+
ADD_SUBDIRECTORY (torch)
154+
155+
LIST (APPEND _TENGINE_DEVICE_HEADER_PATH ${TENGINE_TORCH_HEADER_PATH})
156+
LIST (APPEND _TENGINE_DEVICE_LINK_PATH ${TENGINE_TORCH_LINK_PATH})
157+
LIST (APPEND _TENGINE_DEVICE_COMPILER_DEFINES ${TENGINE_TORCH_COMPILER_DEFINES})
158+
LIST (APPEND _TENGINE_DEVICE_COMPILER_OPTIONS ${TENGINE_TORCH_COMPILER_OPTIONS})
159+
LIST (APPEND _TENGINE_DEVICE_LINKER_OPTIONS ${TENGINE_TORCH_LINKER_OPTIONS})
160+
LIST (APPEND _TENGINE_DEVICE_LINK_LIBRARIES ${TENGINE_TORCH_LINK_LIBRARIES})
161+
LIST (APPEND _TENGINE_DEVICE_SOURCE ${TENGINE_TORCH_DEVICE_SOURCE})
162+
LIST (APPEND _REGISTER_DEVICE_LIST "${CMAKE_SOURCE_DIR}/source/device/torch/torch_device.cc")
163+
ENDIF()
164+
151165
# Khronos Vulkan
152166
IF (TENGINE_ENABLE_VULKAN)
153167
ADD_SUBDIRECTORY (vulkan)

source/device/torch/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
include
2+
src

source/device/torch/CMakeLists.txt

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# License); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# Copyright (c) 2021, OPEN AI LAB
19+
20+
#
21+
22+
# 0. clear var
23+
UNSET (_DEV_TORCH_HEADER_PATH)
24+
UNSET (_TORCH_BASE_SOURCE)
25+
UNSET (_TORCH_OPS_SOURCE)
26+
UNSET (_DEV_TORCH_DEVICE_SOURCE)
27+
UNSET (_DEV_TORCH_COMPILER_DEFINES)
28+
UNSET (_DEV_TORCH_COMPILER_OPTIONS)
29+
UNSET (_DEV_TORCH_LINKER_OPTIONS)
30+
UNSET (_DEV_TORCH_LINK_LIBRARIES)
31+
32+
33+
# 1. set source root path
34+
SET(_TORCH_ROOT ${CMAKE_SOURCE_DIR}/source/device/torch)
35+
36+
37+
# 2. add header file path
38+
LIST (APPEND _DEV_TORCH_HEADER_PATH ${_TORCH_ROOT})
39+
LIST (APPEND _DEV_TORCH_HEADER_PATH ${TORCH_INCLUDE_DIRS})
40+
41+
42+
# 3. add linking lib searching path
43+
LIST (APPEND _DEV_TORCH_LINK_PATH ${TORCH_LIBS})
44+
45+
46+
# 4. add source files
47+
AUX_SOURCE_DIRECTORY("${_TORCH_ROOT}" _TORCH_BASE_SOURCE)
48+
AUX_SOURCE_DIRECTORY("${_TORCH_ROOT}/op" _TORCH_OPS_SOURCE)
49+
LIST (APPEND _DEV_TORCH_DEVICE_SOURCE ${_TORCH_BASE_SOURCE})
50+
LIST (APPEND _DEV_TORCH_DEVICE_SOURCE ${_TORCH_OPS_SOURCE})
51+
52+
53+
# 5. add build options for cpu device
54+
# 5.1 is a gcc or clang like compiler
55+
IF (TENGINE_COMPILER_GCC OR TENGINE_COMPILER_CLANG)
56+
ENDIF()
57+
58+
59+
# 5.2 is Microsoft Visual C++
60+
IF (TENGINE_COMPILER_MSVC)
61+
ENDIF()
62+
63+
64+
# 6. add link options
65+
66+
67+
# 7. add link libs
68+
LIST (APPEND _DEV_TORCH_LINK_LIBRARIES ${TORCH_LIBRARIES})
69+
70+
71+
# 8. set all to cmake cache
72+
SET (TENGINE_TORCH_HEADER_PATH ${_DEV_TORCH_HEADER_PATH} CACHE INTERNAL "Tengine TensorRT device header files searching path" FORCE)
73+
SET (TENGINE_TORCH_LINK_PATH ${_DEV_TORCH_LINK_PATH} CACHE INTERNAL "Tengine TensorRT device link libraries searching path" FORCE)
74+
SET (TENGINE_TORCH_DEVICE_SOURCE ${_DEV_TORCH_DEVICE_SOURCE} CACHE INTERNAL "Tengine TensorRT device main source files" FORCE)
75+
SET (TENGINE_TORCH_COMPILER_DEFINES ${_DEV_TORCH_COMPILER_DEFINES} CACHE INTERNAL "Tengine TensorRT about compiler defines" FORCE)
76+
SET (TENGINE_TORCH_COMPILER_OPTIONS ${_DEV_TORCH_COMPILER_OPTIONS} CACHE INTERNAL "Tengine TensorRT about compiler options" FORCE)
77+
SET (TENGINE_TORCH_LINKER_OPTIONS ${_DEV_TORCH_LINKER_OPTIONS} CACHE INTERNAL "Tengine TensorRT about linker options" FORCE)
78+
SET (TENGINE_TORCH_LINK_LIBRARIES ${_DEV_TORCH_LINK_LIBRARIES} CACHE INTERNAL "Tengine TensorRT about link libraries" FORCE)
79+
80+
81+
# 9. install device option
82+
INSTALL (FILES ${_TORCH_ROOT}/torch_define.h DESTINATION include/tengine RENAME torch_device.h)
83+
84+
85+
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* License); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*
21+
* Copyright (c) 2021, Open AI Lab
22+
23+
*/
24+
25+
#include "torch_helper.hpp"
26+
27+
extern "C"
28+
{
29+
#include "operator/op.h"
30+
#include "convolution_param.h"
31+
}
32+
33+
bool Net::AddConvolutionNode(struct node* ir_node)
34+
{
35+
struct conv_param* param = (struct conv_param*)ir_node->op.param_mem;
36+
37+
struct tensor* input_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[0]);
38+
struct tensor* weight_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[1]);
39+
struct tensor* output_tensor = get_ir_graph_tensor(ir_graph, ir_node->output_tensors[0]);
40+
struct tensor* bias_tensor;
41+
42+
bool bias = false;
43+
if (ir_node->input_num > 2)
44+
{
45+
bias = true;
46+
bias_tensor = get_ir_graph_tensor(ir_graph, ir_node->input_tensors[2]);
47+
}
48+
49+
torch::nn::Conv2d layer
50+
= torch::nn::Conv2d{create_conv_options(
51+
/*in_planes = */ input_tensor->dims[1], /*out_planes = */ output_tensor->dims[1],
52+
/*kerner_size = */ param->kernel_h, /*stride = */ param->stride_h, /*padding = */ param->pad_h0,
53+
/*groups = */ param->group, /*dilation = */ param->dilation_h, /*bias = */ bias)};
54+
register_module(std::to_string(ir_node->index), layer);
55+
torch_node_map[ir_node->index] = layer;
56+
57+
{
58+
torch::Tensor t = torch::rand({weight_tensor->dims[0], weight_tensor->dims[1], weight_tensor->dims[2], weight_tensor->dims[3]});
59+
void* date_mem = t.data_ptr();
60+
memcpy(date_mem, weight_tensor->data, weight_tensor->elem_num * weight_tensor->elem_size);
61+
layer->weight = register_parameter(std::to_string(ir_node->index) + "_weight", t);
62+
}
63+
64+
if (bias)
65+
{
66+
torch::Tensor t = torch::rand({output_tensor->dims[1]});
67+
void* date_mem = t.data_ptr();
68+
memcpy(date_mem, bias_tensor->data, bias_tensor->elem_num * bias_tensor->elem_size);
69+
layer->bias = register_parameter(std::to_string(ir_node->index) + "_bias", t);
70+
}
71+
72+
73+
return true;
74+
}

source/device/torch/torch_define.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* License); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*
21+
* Copyright (c) 2021, OPEN AI LAB
22+
23+
*/
24+
25+
#pragma once
26+
27+
#define TORCH_DEV_NAME "TORCH"

0 commit comments

Comments
 (0)