Skip to content

Commit 80e753c

Browse files
committed
cpu: pooling: modify acl_pooling for stateless functions
Change-Id: I30a987c8c56e1b0a64e3b2268cc96ec30b2abce4
1 parent bbf8399 commit 80e753c

File tree

2 files changed

+112
-102
lines changed

2 files changed

+112
-102
lines changed

src/cpu/aarch64/acl_pooling.cpp

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022-2023 Arm Ltd. and affiliates
2+
* Copyright 2022-2023, 2025 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,35 +21,72 @@ namespace impl {
2121
namespace cpu {
2222
namespace aarch64 {
2323

24+
status_t acl_pooling_fwd_t::init(engine_t *engine) {
    // Create the stateless ACL pooling operator and configure it from the
    // Compute Library descriptors cached on the primitive descriptor.
    // Bind by const reference: asp_ holds several TensorInfo objects and
    // copying it here buys nothing.
    const auto &asp = pd()->asp_;

    pooling_op_ = std::make_unique<arm_compute::experimental::op::CpuPooling>();

    if (asp.use_ws) {
        // Workspace variant: the operator also emits the max-pool index
        // tensor (forward training).
        pooling_op_->configure(
                &asp.src_info, &asp.dst_info, asp.pool_info, &asp.ws_info);
    } else {
        // No workspace tensor is produced (inference / average pooling).
        pooling_op_->configure(
                &asp.src_info, &asp.dst_info, asp.pool_info, nullptr);
    }

    return status::success;
}
42+
2443
status_t acl_pooling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    // Run the stateless CpuPooling operator. Every ACL tensor below is a
    // zero-copy wrapper: import_memory() only acquires the pointer, it does
    // not allocate or take ownership.
    auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
    auto dst = CTX_OUT_MEM(void *, DNNL_ARG_DST);

    // Const reference: avoid copying the conf (multiple TensorInfo members)
    // on every execution.
    const auto &asp = pd()->asp_;

    arm_compute::Tensor src_tensor;
    arm_compute::Tensor dst_tensor;
    src_tensor.allocator()->init(asp.src_info);
    src_tensor.allocator()->import_memory(const_cast<void *>(src));
    dst_tensor.allocator()->init(asp.dst_info);
    dst_tensor.allocator()->import_memory(dst);

    // f32 temporary shaped like dst, backed by the primitive scratchpad;
    // handed to the operator as its internal (ACL_INT_0) buffer.
    arm_compute::Tensor scratch_tensor;
    void *scratchpad_base = ctx.get_scratchpad_grantor().get<void>(
            memory_tracking::names::key_pool_reduction);
    scratch_tensor.allocator()->init(arm_compute::TensorInfo(
            asp.dst_info.tensor_shape(), 1, arm_compute::DataType::F32));
    scratch_tensor.allocator()->import_memory(scratchpad_base);

    arm_compute::Tensor ws_tensor;
    if (asp.use_ws) {
        // Workspace (max-pool indices) exists only in forward training;
        // keep the raw pointer scoped to the branch that uses it.
        void *ws_base = CTX_OUT_MEM(void *, DNNL_ARG_WORKSPACE);
        ws_tensor.allocator()->init(asp.ws_info);
        ws_tensor.allocator()->import_memory(ws_base);
    }

    arm_compute::ITensorPack run_pack {
            {arm_compute::TensorType::ACL_SRC_0, &src_tensor},
            {arm_compute::TensorType::ACL_DST_0, &dst_tensor},
            {arm_compute::TensorType::ACL_INT_0, &scratch_tensor}};
    if (asp.use_ws)
        run_pack.add_tensor(arm_compute::TensorType::ACL_DST_1, &ws_tensor);

    pooling_op_->run(run_pack);

    return status::success;
}
5491

5592
} // namespace aarch64

src/cpu/aarch64/acl_pooling.hpp

Lines changed: 55 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2022-2023 Arm Ltd. and affiliates
2+
* Copyright 2022-2023, 2025 Arm Ltd. and affiliates
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,72 +17,39 @@
1717
#ifndef CPU_AARCH64_ACL_POOLING_HPP
1818
#define CPU_AARCH64_ACL_POOLING_HPP
1919

20-
#include "cpu/aarch64/acl_utils.hpp"
2120
#include "cpu/cpu_pooling_pd.hpp"
2221

22+
#include "cpu/aarch64/acl_utils.hpp"
23+
24+
#include "arm_compute/core/TensorInfo.h"
25+
#include "arm_compute/runtime/IOperator.h"
26+
#include "arm_compute/runtime/experimental/operators/CpuPooling.h"
27+
2328
namespace dnnl {
2429
namespace impl {
2530
namespace cpu {
2631
namespace aarch64 {
2732

28-
struct acl_pooling_obj_t {
29-
arm_compute::NEPoolingLayer pool;
30-
arm_compute::Tensor src_tensor;
31-
arm_compute::Tensor ws_tensor;
32-
arm_compute::Tensor dst_tensor;
33-
bool use_ws;
34-
};
35-
3633
struct acl_pooling_conf_t {
37-
arm_compute::PoolingLayerInfo pool_info;
3834
arm_compute::TensorInfo src_info;
39-
arm_compute::TensorInfo ws_info;
4035
arm_compute::TensorInfo dst_info;
36+
arm_compute::PoolingLayerInfo pool_info;
37+
arm_compute::TensorInfo ws_info;
4138
bool use_ws;
4239
};
4340

44-
struct acl_pooling_resource_t : public resource_t {
45-
acl_pooling_resource_t()
46-
: acl_pooling_obj_(utils::make_unique<acl_pooling_obj_t>()) {}
47-
48-
status_t configure(const acl_pooling_conf_t &app) {
49-
if (!acl_pooling_obj_) return status::out_of_memory;
50-
51-
// Init Compute Library tensors based on info from descriptor
52-
acl_pooling_obj_->src_tensor.allocator()->init(app.src_info);
53-
acl_pooling_obj_->dst_tensor.allocator()->init(app.dst_info);
54-
55-
if (app.use_ws) {
56-
acl_pooling_obj_->ws_tensor.allocator()->init(app.ws_info);
57-
acl_pooling_obj_->pool.configure(&acl_pooling_obj_->src_tensor,
58-
&acl_pooling_obj_->dst_tensor, app.pool_info,
59-
&acl_pooling_obj_->ws_tensor);
60-
acl_pooling_obj_->use_ws = true;
61-
} else {
62-
acl_pooling_obj_->pool.configure(&acl_pooling_obj_->src_tensor,
63-
&acl_pooling_obj_->dst_tensor, app.pool_info);
64-
}
65-
66-
return status::success;
67-
}
68-
69-
acl_pooling_obj_t &get_acl_obj() const { return *acl_pooling_obj_; }
70-
71-
DNNL_DISALLOW_COPY_AND_ASSIGN(acl_pooling_resource_t);
72-
73-
private:
74-
std::unique_ptr<acl_pooling_obj_t> acl_pooling_obj_;
75-
}; // acl_pooling_resource_t
76-
7741
struct acl_pooling_fwd_t : public primitive_t {
7842
struct pd_t : public cpu_pooling_fwd_pd_t {
7943
using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
80-
81-
DECLARE_COMMON_PD_T("acl", acl_pooling_fwd_t);
44+
DECLARE_COMMON_PD_T("acl", acl_pooling_fwd_t, USE_GLOBAL_SCRATCHPAD);
8245

8346
status_t init(engine_t *engine) {
47+
auto scratchpad = scratchpad_registry().registrar();
48+
CHECK(init_scratchpad(scratchpad));
49+
50+
// ACL supports forward propagation only
8451
bool ok = set_default_params() == status::success
85-
&& is_fwd() // ACL supports forward propagation only
52+
&& is_fwd()
8653
&& utils::everyone_is(
8754
src_md()->data_type, dst_md()->data_type)
8855
&& utils::one_of(
@@ -97,21 +64,22 @@ struct acl_pooling_fwd_t : public primitive_t {
9764
// Choose the pooling type
9865
const alg_kind_t alg = pod->alg_kind;
9966
const bool is_max_pool = (alg == alg_kind::pooling_max);
100-
app.pool_info.pool_type = is_max_pool
67+
asp_.pool_info.pool_type = is_max_pool
10168
? arm_compute::PoolingType::MAX
10269
: arm_compute::PoolingType::AVG;
10370

10471
// Check if workspace Tensor is needed
10572
const bool ws_init = (is_max_pool
10673
&& pod->prop_kind == prop_kind::forward_training);
107-
app.use_ws = ws_init;
74+
asp_.use_ws = ws_init;
10875

10976
ACL_CHECK_SUPPORT(ws_init && src_md()->data_type != data_type::f32,
11077
"ACL Max pooling forward training only supports f32");
11178

11279
if (ws_init)
80+
// ACL only supports U32/S32 no U8
11381
init_default_ws(
114-
data_type::s32); // ACL only supports U32/S32 no U8
82+
data_type::s32);
11583
auto src_tag = memory_desc_matches_one_of_tag(
11684
*src_md(), format_tag::nhwc, format_tag::nchw);
11785
auto dst_tag = memory_desc_matches_one_of_tag(
@@ -129,12 +97,12 @@ struct acl_pooling_fwd_t : public primitive_t {
12997
ACL_CHECK_SUPPORT(ndims != 4, "Tensor is not 4d");
13098

13199
// Pooling window
132-
app.pool_info.pool_size = arm_compute::Size2D(KW(), KH());
100+
asp_.pool_info.pool_size = arm_compute::Size2D(KW(), KH());
133101
// Choose the data layout
134102
bool is_nhwc = src_tag == format_tag::nhwc;
135103
const auto acl_layout = is_nhwc ? arm_compute::DataLayout::NHWC
136104
: arm_compute::DataLayout::NCHW;
137-
app.pool_info.data_layout = acl_layout;
105+
asp_.pool_info.data_layout = acl_layout;
138106
const auto acl_data_t
139107
= acl_utils::get_acl_data_t(src_d.data_type());
140108

@@ -158,41 +126,43 @@ struct acl_pooling_fwd_t : public primitive_t {
158126
"kernels are faster for this problem");
159127
}
160128

161-
app.pool_info.exclude_padding
129+
asp_.pool_info.exclude_padding
162130
= (alg == alg_kind::pooling_avg_exclude_padding);
163131

164-
app.pool_info.pad_stride_info = arm_compute::PadStrideInfo(KSW(),
132+
asp_.pool_info.pad_stride_info = arm_compute::PadStrideInfo(KSW(),
165133
KSH(), padL(), padR(), padT(), padB(),
166134
arm_compute::DimensionRoundingType::FLOOR);
167135

168-
app.src_info = arm_compute::TensorInfo(is_nhwc
136+
asp_.src_info = arm_compute::TensorInfo(is_nhwc
169137
? arm_compute::TensorShape(IC(), IW(), IH(), MB())
170138
: arm_compute::TensorShape(IW(), IH(), IC(), MB()),
171139
1, acl_data_t, acl_layout);
172-
app.dst_info = arm_compute::TensorInfo(is_nhwc
140+
asp_.dst_info = arm_compute::TensorInfo(is_nhwc
173141
? arm_compute::TensorShape(OC(), OW(), OH(), MB())
174142
: arm_compute::TensorShape(OW(), OH(), OC(), MB()),
175143
1, acl_data_t, acl_layout);
176144

177145
// Use datatype lowest property instead of using -INF
178-
app.pool_info.use_inf_as_limit = false;
146+
asp_.pool_info.use_inf_as_limit = false;
179147

180148
if (ws_init) {
181-
app.ws_info = arm_compute::TensorInfo(is_nhwc
149+
asp_.ws_info = arm_compute::TensorInfo(is_nhwc
182150
? arm_compute::TensorShape(
183151
OC(), OW(), OH(), MB())
184152
: arm_compute::TensorShape(
185153
OW(), OH(), OC(), MB()),
186154
1, arm_compute::DataType::U32, acl_layout);
187155

188156
// Return kernel indices instead of source indices.
189-
app.pool_info.use_kernel_indices = true;
157+
asp_.pool_info.use_kernel_indices = true;
190158
ACL_CHECK_VALID(
191-
arm_compute::NEPoolingLayer::validate(&app.src_info,
192-
&app.dst_info, app.pool_info, &app.ws_info));
159+
arm_compute::experimental::op::CpuPooling::validate(
160+
&asp_.src_info, &asp_.dst_info, asp_.pool_info, &asp_.ws_info));
193161
} else {
194-
ACL_CHECK_VALID(arm_compute::NEPoolingLayer::validate(
195-
&app.src_info, &app.dst_info, app.pool_info));
162+
asp_.pool_info.use_kernel_indices = false;
163+
ACL_CHECK_VALID(
164+
arm_compute::experimental::op::CpuPooling::validate(
165+
&asp_.src_info, &asp_.dst_info, asp_.pool_info));
196166
}
197167

198168
return status::success;
@@ -262,34 +232,37 @@ struct acl_pooling_fwd_t : public primitive_t {
262232
return problem_size > cutoff * thread_count;
263233
}
264234

265-
acl_pooling_conf_t app = utils::zero<decltype(app)>();
266-
};
235+
acl_pooling_conf_t asp_;
236+
237+
// Book per-execution scratchpad storage used by execute_forward().
status_t init_scratchpad(memory_tracking::registrar_t &scratchpad) {
    const memory_desc_wrapper dst_d(&dst_md_);

    // f32 reduction buffer shaped like dst, imported into the operator's
    // internal ACL tensor at execution time.
    scratchpad.book(memory_tracking::names::key_pool_reduction,
            dst_d.nelems(), sizeof(float));

    // NOTE(review): pd_t::init() calls init_scratchpad() before asp_.use_ws
    // is assigned, so gating this booking on asp_.use_ws read an unassigned
    // flag. Book the u32 index buffer unconditionally instead: over-booking
    // scratchpad is harmless, under-booking would be a buffer overrun.
    // (This buffer is not referenced in the visible execute_forward —
    // presumably kept for index-layout conversion; confirm before removing.)
    scratchpad.book(memory_tracking::names::key_pool_ind_plain2blocked_cvt,
            dst_d.nelems(), sizeof(uint32_t));

    return status::success;
}
251+
252+
}; // pd_t
267253

254+
// constructor
268255
acl_pooling_fwd_t(const pd_t *apd) : primitive_t(apd) {}
269256

270257
status_t execute(const exec_ctx_t &ctx) const override {
271258
return execute_forward(ctx);
272259
}
273260

274-
status_t create_resource(
275-
engine_t *engine, resource_mapper_t &mapper) const override {
276-
if (mapper.has_resource(this)) return status::success;
277-
278-
auto r = utils::make_unique<acl_pooling_resource_t>();
279-
if (!r) return status::out_of_memory;
280-
281-
// Configure the resource based on information from primitive descriptor
282-
auto st = r->configure(pd()->app);
283-
if (st == status::success) { mapper.add(this, std::move(r)); }
284-
285-
return st;
286-
}
287-
288261
private:
289-
// execute_forward has to be const thus mutability of mtx
290-
mutable std::mutex mtx;
262+
status_t init(engine_t *engine) override;
291263
status_t execute_forward(const exec_ctx_t &ctx) const;
292264
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
265+
std::unique_ptr<arm_compute::experimental::op::CpuPooling> pooling_op_;
293266
}; // acl_pooling_fwd_t
294267

295268
} // namespace aarch64

0 commit comments

Comments
 (0)