
快速训练

hnluo edited this page Apr 21, 2023 · 5 revisions

代码(finetune.py):

import os
import json
import shutil

from modelscope.pipelines import pipeline
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.compute_wer import compute_wer


def modelscope_finetune(params):
    """Fine-tune a ModelScope ASR model and stage its output dir for decoding.

    Args:
        params: dict with the keys
            "modelscope_model_name": ModelScope hub id of the model to fine-tune.
            "dataset_name": ModelScope dataset name, loaded from the
                "speech_asr" namespace.
            "model_dir": output directory for the trained checkpoints; it is
                created if missing.
            "max_epoch" (optional, default 1): number of training epochs.
            "pretrained_model_path" (optional): local directory holding the
                downloaded pretrained model. Defaults to
                ~/.cache/modelscope/hub/<modelscope_model_name>; pass it
                explicitly if the model was downloaded somewhere else.

    Raises:
        FileNotFoundError: if one of the required files is missing from the
            pretrained model directory.
    """
    model_name = params["modelscope_model_name"]
    model_dir = params["model_dir"]
    # exist_ok=True already makes this a no-op for an existing directory.
    os.makedirs(model_dir, exist_ok=True)

    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params["dataset_name"], namespace='speech_asr')
    kwargs = dict(
        model=model_name,
        data_dir=ds_dict,
        work_dir=model_dir,
        max_epoch=params.get("max_epoch", 1))
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()

    # Copy these files from the pretrained model next to the new checkpoints.
    # modelscope_infer() edits model_dir/configuration.json and then loads
    # model_dir itself as the pipeline model.
    # expanduser("~") also works where $HOME is unset (e.g. Windows), unlike
    # os.environ["HOME"].
    pretrained_model_path = params.get("pretrained_model_path") or os.path.join(
        os.path.expanduser("~"), ".cache", "modelscope", "hub", model_name)
    required_files = ["am.mvn", "decoding.yaml", "configuration.json"]
    for file_name in required_files:
        shutil.copy(os.path.join(pretrained_model_path, file_name),
                    os.path.join(model_dir, file_name))
    

def modelscope_infer(params):
    """Decode a test set with a fine-tuned model and report CER when possible.

    Args:
        params: dict with the keys
            "model_dir": directory holding the fine-tuned checkpoint(s) and
                configuration.json (as staged by modelscope_finetune).
            "decoding_model_name": checkpoint file name to decode with; it is
                written into configuration.json as model.am_model_name.
            "test_data_dir": directory containing "wav.scp" and, optionally,
                the reference transcript file "text".
            "batch_size" (optional, default 64): decoding batch size.

    Side effects:
        Rewrites model_dir/configuration.json. Deletes and recreates
        model_dir/decode_results. When test_data_dir/text exists, computes
        CER into decode_results/text.cer and prints its last three lines.
    """
    model_dir = params["model_dir"]
    config_path = os.path.join(model_dir, "configuration.json")

    # prepare for decoding: point the pipeline config at the chosen checkpoint
    with open(config_path, encoding="utf-8") as f:
        config_dict = json.load(f)
    config_dict["model"]["am_model_name"] = params["decoding_model_name"]
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=4, separators=(',', ': '))

    # Start from an empty output directory so stale results are never mixed in.
    decoding_path = os.path.join(model_dir, "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)

    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=model_dir,
        output_dir=decoding_path,
        batch_size=params.get("batch_size", 64),
    )
    audio_in = os.path.join(params["test_data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)

    # compute CER if the ground-truth transcript is available
    text_in = os.path.join(params["test_data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog", "token")
        cer_file = os.path.join(decoding_path, "text.cer")
        compute_wer(text_in, text_proc_file, cer_file)
        # Print the summary (last three lines) in-process instead of shelling
        # out to `tail`. The shell command was non-portable and broke on paths
        # containing spaces or shell metacharacters.
        with open(cer_file, encoding="utf-8", errors="replace") as f:
            print("".join(f.readlines()[-3:]), end="")

if __name__ == '__main__':
    # Both stages share the same working directory: training writes
    # checkpoints there and decoding reads them back.
    checkpoint_dir = "./checkpoint"

    # Stage 1: fine-tune the pretrained Paraformer model on the
    # speech_asr_aishell1_subset ModelScope dataset.
    modelscope_finetune({
        "modelscope_model_name": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
        "dataset_name": "speech_asr_aishell1_subset",
        "model_dir": checkpoint_dir,
    })

    # Stage 2: decode with the 1-epoch checkpoint and score the result.
    # NOTE(review): assumes the dataset's validation split is available under
    # ./checkpoint/data/validation after training — confirm.
    modelscope_infer({
        "model_dir": checkpoint_dir,
        "decoding_model_name": "1epoch.pb",
        "test_data_dir": "./checkpoint/data/validation",
    })

流程:

modelscope模型资源下载->modelscope dataset下载->模型训练->模型测试并计算CER

运行命令

python finetune.py

训练输入参数介绍

  • modelscope_model_name:需要finetune的modelscope模型名字
  • dataset_name:modelscope dataset名字
  • model_dir:训练模型保存目录

训练输出目录结构

tree ./checkpoint/
./checkpoint/
├── 1epoch.pb
└── tensorboard
    ├── train
    └── valid
  • 1epoch.pb:训练1epoch模型文件
  • tensorboard:训练tensorboard保存目录。启动方式:tensorboard --logdir checkpoint/tensorboard/train/ (如需从其他机器访问,请加上 --bind_all 参数,新版tensorboard默认只监听本机);查看方式:在浏览器中打开 http://<训练服务器IP>:6006

解码输入参数介绍

  • model_dir:解码模型目录
  • test_data_dir:测试数据目录
  • decoding_model_name:解码模型名字

解码输出目录结构

tree ./checkpoint/decode_results/
./checkpoint/decode_results/
├── text.cer
└── 1best_recog
    ├── rtf
    ├── text
    └── score

  • text.cer:CER统计文件
BAC009S0724W0495.wav(nwords=13,cor=13,ins=0,del=0,sub=0) corr:100.00%,cer:0.00%
ref:    筹备了一系列新展并同时亮相
hyp:    筹备了一系列新展并同时亮相

%WER 3.51 [ 177 / 5037, 4 ins, 0 del, 173 sub ]
%SER 28.81 [ 102 / 354 ]
  • rtf:每句话的解码耗时(forward_time)与实时率(rtf),最后一行为全部句子的汇总耗时与平均实时率
BAC009S0724W0495.wav decoding, feature length: 3075, forward_time: 0.4731, rtf: 0.0026
rtf_avf decoding, feature length total: 29815.0, forward_time total: 4.7479, rtf avg: 0.0027
  • score:解码每句话的得分
BAC009S0724W0495.wav tensor(-1.2249, device='cuda:0')
  • text:解码结果
BAC009S0724W0495.wav 筹备了一系列新展并同时亮相