Skip to content

Commit 2008b58

Browse files
committed
修改解码器输出函数名
1 parent 2617deb commit 2008b58

File tree

4 files changed

+8
-17
lines changed

4 files changed

+8
-17
lines changed

export_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
add_arg = functools.partial(add_arguments, argparser=parser)
99
add_arg('configs', str, 'configs/conformer.yml', '配置文件')
1010
add_arg("use_gpu", bool, True, '是否使用GPU评估模型')
11-
add_arg("save_quant", bool, False, '是否保存量化模型')
1211
add_arg('save_model', str, 'models/', '模型保存的路径')
1312
add_arg('resume_model', str, 'models/ConformerModel_fbank/best_model/', '准备转换的模型路径')
1413
args = parser.parse_args()
@@ -20,5 +19,4 @@
2019

2120
# 导出预测模型
2221
trainer.export(save_model_path=args.save_model,
23-
resume_model=args.resume_model,
24-
save_quant=args.save_quant)
22+
resume_model=args.resume_model)

masr/decoders/attention_rescoring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def attention_rescoring(
6161

6262
# ctc score in ln domain
6363
# (beam_size, max_hyps_len, vocab_size)
64-
decoder_out, r_decoder_out = model.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out, reverse_weight)
64+
decoder_out, r_decoder_out = model.get_decoder_out(hyps_pad, hyps_lens, encoder_out, reverse_weight)
6565

6666
# Only use decoder score for rescoring
6767
best_score = -float('inf')

masr/model_utils/conformer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def get_encoder_out_chunk(self,
219219
return ctc_probs, att_cache, cnn_cache
220220

221221
@torch.jit.export
222-
def forward_attention_decoder(
222+
def get_decoder_out(
223223
self,
224224
hyps: torch.Tensor,
225225
hyps_lens: torch.Tensor,

masr/trainer.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -591,16 +591,13 @@ def evaluate(self, resume_model=None, display_result=False, max_text_duration=No
591591

592592
def export(self,
593593
save_model_path='models/',
594-
resume_model='models/ConformerModel_fbank/best_model/',
595-
save_quant=False):
594+
resume_model='models/ConformerModel_fbank/best_model/'):
596595
"""导出预测模型
597596
598597
:param save_model_path: 模型保存的路径
599598
:type save_model_path: str
600599
:param resume_model: 准备转换的模型路径
601600
:type resume_model: str
602-
:param save_quant: 是否保存量化模型
603-
:type save_quant: bool
604601
:return:
605602
"""
606603
# 获取训练数据
@@ -632,13 +629,6 @@ def export(self,
632629
os.makedirs(save_model_dir, exist_ok=True)
633630
torch.jit.save(infer_model, infer_model_path)
634631
logger.info("预测模型已保存:{}".format(infer_model_path))
635-
# 保存量化模型
636-
if save_quant:
637-
quant_model_path = os.path.join(os.path.dirname(infer_model_path), 'inference_quant.pth')
638-
quantized_model = torch.quantization.quantize_dynamic(self.model)
639-
script_quant_model = torch.jit.script(quantized_model)
640-
torch.jit.save(script_quant_model, quant_model_path)
641-
logger.info("量化模型已保存:{}".format(quant_model_path))
642632
# 复制词汇表模型
643633
shutil.copytree(tokenizer.vocab_model_dir, os.path.join(save_model_dir, 'vocab_model'))
644634
# 保存配置信息
@@ -648,6 +638,9 @@ def export(self,
648638
'model_name': self.configs.model_conf.model,
649639
'streaming': self.configs.model_conf.model_args.streaming,
650640
'sample_rate': self.configs.dataset_conf.dataset.sample_rate,
651-
'preprocess_conf': self.configs.preprocess_conf
641+
'preprocess_conf': self.configs.preprocess_conf,
652642
}
643+
if self.configs.model_conf.model != "DeepSpeech2Model":
644+
inference_config['symbol'] = {'sos': self.model.sos_symbol(), 'eos': self.model.eos_symbol(),
645+
'ignore_id': self.model.ignore_symbol()}
653646
json.dump(inference_config, f, indent=4, ensure_ascii=False)

0 commit comments

Comments
 (0)