Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 说话人日志pipeline增加情绪识别 #1993

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion funasr/auto/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
return key_list, data_list


def distribute_emotion(sentence_list, ser_time_list):
    """Label every sentence with the emotion that overlaps it the longest.

    For each sentence, the overlap with every emotion segment is computed
    and summed per emotion label; the label with the largest total overlap
    is written to the sentence's ``'emotion'`` key.

    Args:
        sentence_list: List of dicts, each providing numeric ``'start'`` and
            ``'end'`` values. Segment bounds are multiplied by 1000 before
            being compared against these, so sentences are presumably in
            milliseconds and segments in seconds (confirm against caller).
        ser_time_list: Iterable of ``(start, end, emotion)`` triples.
            ``emotion`` must be hashable (labels are used as dict keys).

    Returns:
        The same ``sentence_list`` object; its dicts are updated in place.
        A sentence that overlaps no segment gets ``"EMO_UNKNOWN"``.
    """
    # Rescale segment bounds once so they share the sentences' time unit.
    scaled_segments = [
        (st * 1000, ed * 1000, emotion) for st, ed, emotion in ser_time_list
    ]

    for sentence in sentence_list:
        sentence_start = sentence["start"]
        sentence_end = sentence["end"]

        # Total overlap per emotion. Summing per label (instead of bumping a
        # single running maximum) avoids counting a segment twice and makes
        # the result independent of segment order.
        overlap_by_emotion = {}
        for st, ed, emotion in scaled_segments:
            overlap = min(sentence_end, ed) - max(sentence_start, st)
            if overlap > 0:
                overlap_by_emotion[emotion] = (
                    overlap_by_emotion.get(emotion, 0) + overlap
                )

        # Strict '>' keeps the first-seen emotion on ties (dicts preserve
        # insertion order) and leaves the default when nothing overlaps.
        best_emotion = "EMO_UNKNOWN"
        best_overlap = 0
        for emotion, total in overlap_by_emotion.items():
            if total > best_overlap:
                best_overlap = total
                best_emotion = emotion

        sentence["emotion"] = best_emotion
    return sentence_list


class AutoModel:

def __init__(self, **kwargs):
Expand Down Expand Up @@ -157,7 +175,11 @@ def __init__(self, **kwargs):
if spk_mode not in ["default", "vad_segment", "punc_segment"]:
logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
self.spk_mode = spk_mode

ser_model = kwargs.get("ser_model", None)
ser_kwargs = {} if kwargs.get("ser_kwargs", {}) is None else kwargs.get("ser_kwargs", {})
if ser_model is not None:
logging.info("Building SER model.")
ser_model, ser_kwargs = self.build_model(**ser_kwargs)
self.kwargs = kwargs
self.model = model
self.vad_model = vad_model
Expand All @@ -166,6 +188,8 @@ def __init__(self, **kwargs):
self.punc_kwargs = punc_kwargs
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.ser_model = ser_model
self.ser_kwargs = ser_kwargs
self.model_path = kwargs.get("model_path")

@staticmethod
Expand Down Expand Up @@ -434,6 +458,16 @@ def inference_with_vad(self, input, input_len=None, **cfg):
speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg
)
results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
if self.ser_model is not None:
ser_res = self.inference(speech_b, input_len=None, model=self.ser_model,
kwargs=self.ser_kwargs, **cfg)
if "SenseVoiceSmall" in kwargs.get("ser_model", None):
results[_b]["ser_type"] = [i['text'].split("|><|")[1] for i in ser_res]
elif "emotion2vec" in kwargs.get("ser_model", None):
results[_b]["ser_type"] = [i['labels'][i["scores"].index(max(i["scores"]))] for i in ser_res]



beg_idx = end_idx
end_idx += 1
max_len_in_batch = sample_length
Expand Down Expand Up @@ -526,6 +560,7 @@ def inference_with_vad(self, input, input_len=None, **cfg):
"end": vadsegment[1],
"sentence": rest["text"],
"timestamp": rest["timestamp"],
"emotion": rest["ser_type"],
}
)
elif self.spk_mode == "punc_segment":
Expand All @@ -549,6 +584,13 @@ def inference_with_vad(self, input, input_len=None, **cfg):
raw_text,
return_raw_text=return_raw_text,
)
if "ser_type" in result:
if len(sentence_list) == len(result["ser_type"]):
for i in range(len(sentence_list)):
sentence_list[i]["emotion"] = result["ser_type"][i]
else:
merged_list = [[x[0], x[1], y] for x, y in zip(all_segments, result["ser_type"])]
distribute_emotion(sentence_list, merged_list)
distribute_spk(sentence_list, sv_output)
result["sentence_info"] = sentence_list
elif kwargs.get("sentence_timestamp", False):
Expand All @@ -572,6 +614,8 @@ def inference_with_vad(self, input, input_len=None, **cfg):
result["sentence_info"] = sentence_list
if "spk_embedding" in result:
del result["spk_embedding"]
if "ser_type" in result:
del result["ser_type"]

result["key"] = key
results_ret_list.append(result)
Expand Down