Skip to content

Commit 3bca3f3

Browse files
committed
Do not allocate memory for the CM address buffers when CM_MASK == 0
1 parent 7a3d3e4 commit 3bca3f3

File tree

1 file changed

+47
-34
lines changed

1 file changed

+47
-34
lines changed

src/plugins/intel_gpu/src/plugin/ops/moe_expert.cpp

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,18 @@ static void prepare_weights(ProgramBuilder& p, const std::shared_ptr<ov::op::int
9090

9191
cldnn::layout ptr_layout(ov::PartialShape{static_cast<int>(op->get_config().expert_num)}, cldnn::data_types::u64, cldnn::format::byfx);
9292
cldnn::layout gate_up_ptr_layout(ov::PartialShape{static_cast<int>(op->get_config().expert_num * 64 / sizeof(uint64_t))}, cldnn::data_types::u64, cldnn::format::byfx);
93-
// [64bytes]->gate_addrs,up_addrs, gate_scales_addrs, up_scales_addrs,gate_zp_addrs,up_zp_addrs, padding1, padding2
94-
scale_zp.gate_up_addrs = p.get_engine().allocate_memory(gate_up_ptr_layout, cldnn::allocation_type::usm_device, false);
95-
scale_zp.down_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
96-
scale_zp.down_scales_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
97-
scale_zp.down_zp_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
93+
int cm_mask = 1;
94+
auto env = std::getenv("CM_MASK");
95+
if (env) {
96+
cm_mask = std::atoi(env);
97+
}
98+
if (cm_mask) {
99+
// [64bytes]->gate_addrs,up_addrs, gate_scales_addrs, up_scales_addrs,gate_zp_addrs,up_zp_addrs, padding1, padding2
100+
scale_zp.gate_up_addrs = p.get_engine().allocate_memory(gate_up_ptr_layout, cldnn::allocation_type::usm_device, false);
101+
scale_zp.down_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
102+
scale_zp.down_scales_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
103+
scale_zp.down_zp_addrs = p.get_engine().allocate_memory(ptr_layout, cldnn::allocation_type::usm_device, false);
104+
}
98105
std::array<std::vector<uint64_t>, 3> buf_down;
99106
struct addrs {
100107
uint64_t gate_addrs;
@@ -131,45 +138,51 @@ static void prepare_weights(ProgramBuilder& p, const std::shared_ptr<ov::op::int
131138
auto idx = rt["__scale_const__"].as<int>();
132139
OPENVINO_ASSERT(idx >= 0 && idx < 3);
133140
params[i].param[idx].scale = alloc(node, true);
134-
params[i].param[idx].scale_ba = alloc(node, false);
135-
auto p = reinterpret_cast<uint64_t>(params[i].param[idx].scale_ba->buffer_ptr());
136-
switch (idx) {
137-
case 0:
138-
buf_gate_up[i].gate_scales_addrs = p;
139-
break;
140-
case 1:
141-
buf_gate_up[i].up_scales_addrs = p;
142-
break;
143-
default:
144-
buf_down[1].push_back(p);
145-
break;
141+
if (cm_mask) {
142+
params[i].param[idx].scale_ba = alloc(node, false);
143+
auto p = reinterpret_cast<uint64_t>(params[i].param[idx].scale_ba->buffer_ptr());
144+
switch (idx) {
145+
case 0:
146+
buf_gate_up[i].gate_scales_addrs = p;
147+
break;
148+
case 1:
149+
buf_gate_up[i].up_scales_addrs = p;
150+
break;
151+
default:
152+
buf_down[1].push_back(p);
153+
break;
154+
}
146155
}
147156
}
148157
if (rt.count("__zp_const__")) {
149158
auto idx = rt["__zp_const__"].as<int>();
150159
OPENVINO_ASSERT(idx >= 0 && idx < 3);
151160
params[i].param[idx].zp = alloc(node, true);
152-
params[i].param[idx].zp_ba = alloc(node, false);
153-
auto p = reinterpret_cast<uint64_t>(params[i].param[idx].zp_ba->buffer_ptr());
154-
switch (idx) {
155-
case 0:
156-
buf_gate_up[i].gate_zp_addrs = p;
157-
break;
158-
case 1:
159-
buf_gate_up[i].up_zp_addrs = p;
160-
break;
161-
default:
162-
buf_down[2].push_back(p);
163-
break;
161+
if (cm_mask) {
162+
params[i].param[idx].zp_ba = alloc(node, false);
163+
auto p = reinterpret_cast<uint64_t>(params[i].param[idx].zp_ba->buffer_ptr());
164+
switch (idx) {
165+
case 0:
166+
buf_gate_up[i].gate_zp_addrs = p;
167+
break;
168+
case 1:
169+
buf_gate_up[i].up_zp_addrs = p;
170+
break;
171+
default:
172+
buf_down[2].push_back(p);
173+
break;
174+
}
164175
}
165176
}
166177
}
167178
}
168-
auto& stream = p.get_engine().get_service_stream();
169-
scale_zp.gate_up_addrs->copy_from(stream, buf_gate_up.data(), 0, 0, gate_up_ptr_layout.bytes_count(), true);
170-
scale_zp.down_addrs->copy_from(stream, buf_down[0].data(), 0, 0, ptr_layout.bytes_count(), true);
171-
scale_zp.down_scales_addrs->copy_from(stream, buf_down[1].data(), 0, 0, ptr_layout.bytes_count(), true);
172-
scale_zp.down_zp_addrs->copy_from(stream, buf_down[2].data(), 0, 0, ptr_layout.bytes_count(), true);
179+
if (cm_mask) {
180+
auto& stream = p.get_engine().get_service_stream();
181+
scale_zp.gate_up_addrs->copy_from(stream, buf_gate_up.data(), 0, 0, gate_up_ptr_layout.bytes_count(), true);
182+
scale_zp.down_addrs->copy_from(stream, buf_down[0].data(), 0, 0, ptr_layout.bytes_count(), true);
183+
scale_zp.down_scales_addrs->copy_from(stream, buf_down[1].data(), 0, 0, ptr_layout.bytes_count(), true);
184+
scale_zp.down_zp_addrs->copy_from(stream, buf_down[2].data(), 0, 0, ptr_layout.bytes_count(), true);
185+
}
173186
}
174187

175188
static void CreateMOEExpert2Op(ProgramBuilder& p, const std::shared_ptr<ov::op::internal::MOEExpert2>& op) {

0 commit comments

Comments
 (0)