diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 3829bcabb..8e00703ca 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -385,11 +385,11 @@ def inference_with_vad(self, input, input_len=None, **cfg): result["text"] = punc_res[0]["text"] # speaker embedding cluster after resorted - if self.spk_model is not None: + if self.spk_model is not None and kwargs.get('return_spk_res', True): all_segments = sorted(all_segments, key=lambda x: x[0]) spk_embedding = result['spk_embedding'] - labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs['preset_spk_num']) - del result['spk_embedding'] + labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None)) + # del result['spk_embedding'] sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) if self.spk_mode == 'vad_segment': # recover sentence_list sentence_list = [] @@ -409,6 +409,7 @@ def inference_with_vad(self, input, input_len=None, **cfg): result['timestamp'], \ result['raw_text']) result['sentence_info'] = sentence_list + del result['spk_embedding'] result["key"] = key results_ret_list.append(result)