@@ -90,11 +90,18 @@ static void prepare_weights(ProgramBuilder& p, const std::shared_ptr<ov::op::int
90
90
91
91
cldnn::layout ptr_layout (ov::PartialShape{static_cast <int >(op->get_config ().expert_num )}, cldnn::data_types::u64 , cldnn::format::byfx);
92
92
cldnn::layout gate_up_ptr_layout (ov::PartialShape{static_cast <int >(op->get_config ().expert_num * 64 / sizeof (uint64_t ))}, cldnn::data_types::u64 , cldnn::format::byfx);
93
- // [64bytes]->gate_addrs,up_addrs, gate_scales_addrs, up_scales_addrs,gate_zp_addrs,up_zp_addrs, padding1, padding2
94
- scale_zp.gate_up_addrs = p.get_engine ().allocate_memory (gate_up_ptr_layout, cldnn::allocation_type::usm_device, false );
95
- scale_zp.down_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
96
- scale_zp.down_scales_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
97
- scale_zp.down_zp_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
93
+ int cm_mask = 1 ;
94
+ auto env = std::getenv (" CM_MASK" );
95
+ if (env) {
96
+ cm_mask = std::atoi (env);
97
+ }
98
+ if (cm_mask) {
99
+ // [64bytes]->gate_addrs,up_addrs, gate_scales_addrs, up_scales_addrs,gate_zp_addrs,up_zp_addrs, padding1, padding2
100
+ scale_zp.gate_up_addrs = p.get_engine ().allocate_memory (gate_up_ptr_layout, cldnn::allocation_type::usm_device, false );
101
+ scale_zp.down_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
102
+ scale_zp.down_scales_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
103
+ scale_zp.down_zp_addrs = p.get_engine ().allocate_memory (ptr_layout, cldnn::allocation_type::usm_device, false );
104
+ }
98
105
std::array<std::vector<uint64_t >, 3 > buf_down;
99
106
struct addrs {
100
107
uint64_t gate_addrs;
@@ -131,45 +138,51 @@ static void prepare_weights(ProgramBuilder& p, const std::shared_ptr<ov::op::int
131
138
auto idx = rt[" __scale_const__" ].as <int >();
132
139
OPENVINO_ASSERT (idx >= 0 && idx < 3 );
133
140
params[i].param [idx].scale = alloc (node, true );
134
- params[i].param [idx].scale_ba = alloc (node, false );
135
- auto p = reinterpret_cast <uint64_t >(params[i].param [idx].scale_ba ->buffer_ptr ());
136
- switch (idx) {
137
- case 0 :
138
- buf_gate_up[i].gate_scales_addrs = p;
139
- break ;
140
- case 1 :
141
- buf_gate_up[i].up_scales_addrs = p;
142
- break ;
143
- default :
144
- buf_down[1 ].push_back (p);
145
- break ;
141
+ if (cm_mask) {
142
+ params[i].param [idx].scale_ba = alloc (node, false );
143
+ auto p = reinterpret_cast <uint64_t >(params[i].param [idx].scale_ba ->buffer_ptr ());
144
+ switch (idx) {
145
+ case 0 :
146
+ buf_gate_up[i].gate_scales_addrs = p;
147
+ break ;
148
+ case 1 :
149
+ buf_gate_up[i].up_scales_addrs = p;
150
+ break ;
151
+ default :
152
+ buf_down[1 ].push_back (p);
153
+ break ;
154
+ }
146
155
}
147
156
}
148
157
if (rt.count (" __zp_const__" )) {
149
158
auto idx = rt[" __zp_const__" ].as <int >();
150
159
OPENVINO_ASSERT (idx >= 0 && idx < 3 );
151
160
params[i].param [idx].zp = alloc (node, true );
152
- params[i].param [idx].zp_ba = alloc (node, false );
153
- auto p = reinterpret_cast <uint64_t >(params[i].param [idx].zp_ba ->buffer_ptr ());
154
- switch (idx) {
155
- case 0 :
156
- buf_gate_up[i].gate_zp_addrs = p;
157
- break ;
158
- case 1 :
159
- buf_gate_up[i].up_zp_addrs = p;
160
- break ;
161
- default :
162
- buf_down[2 ].push_back (p);
163
- break ;
161
+ if (cm_mask) {
162
+ params[i].param [idx].zp_ba = alloc (node, false );
163
+ auto p = reinterpret_cast <uint64_t >(params[i].param [idx].zp_ba ->buffer_ptr ());
164
+ switch (idx) {
165
+ case 0 :
166
+ buf_gate_up[i].gate_zp_addrs = p;
167
+ break ;
168
+ case 1 :
169
+ buf_gate_up[i].up_zp_addrs = p;
170
+ break ;
171
+ default :
172
+ buf_down[2 ].push_back (p);
173
+ break ;
174
+ }
164
175
}
165
176
}
166
177
}
167
178
}
168
- auto & stream = p.get_engine ().get_service_stream ();
169
- scale_zp.gate_up_addrs ->copy_from (stream, buf_gate_up.data (), 0 , 0 , gate_up_ptr_layout.bytes_count (), true );
170
- scale_zp.down_addrs ->copy_from (stream, buf_down[0 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
171
- scale_zp.down_scales_addrs ->copy_from (stream, buf_down[1 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
172
- scale_zp.down_zp_addrs ->copy_from (stream, buf_down[2 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
179
+ if (cm_mask) {
180
+ auto & stream = p.get_engine ().get_service_stream ();
181
+ scale_zp.gate_up_addrs ->copy_from (stream, buf_gate_up.data (), 0 , 0 , gate_up_ptr_layout.bytes_count (), true );
182
+ scale_zp.down_addrs ->copy_from (stream, buf_down[0 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
183
+ scale_zp.down_scales_addrs ->copy_from (stream, buf_down[1 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
184
+ scale_zp.down_zp_addrs ->copy_from (stream, buf_down[2 ].data (), 0 , 0 , ptr_layout.bytes_count (), true );
185
+ }
173
186
}
174
187
175
188
static void CreateMOEExpert2Op (ProgramBuilder& p, const std::shared_ptr<ov::op::internal::MOEExpert2>& op) {
0 commit comments