-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_diarization.py
153 lines (123 loc) · 4.45 KB
/
eval_diarization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# SPDX-FileCopyrightText: Copyright © <2022> Idiap Research Institute <[email protected]>
#
# SPDX-FileContributor: Juan Zuluaga-Gomez <[email protected]>
#
# SPDX-License-Identifier: MIT-License
"""
Script to evaluate a text-based diarization model. The model is built by
fine-tuning a pretrained BERT-base-uncased model* fetched from HuggingFace.
* Other models can be used as well, e.g., bert-base-cased
BERT paper (ours): https://arxiv.org/abs/1810.04805
HuggingFace repository: https://huggingface.co/bert-base-uncased
"""
import argparse
import os
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
)
from diarization_utils import (
ATCDataset_diarization,
compute_metrics,
encode_tags,
read_atc_diarization_data,
)
def parse_args():
    """Parse command-line arguments for the diarization evaluation script.

    Returns:
        argparse.Namespace with fields: print_metrics (bool), batch_size (int),
        input_model (str), input_files (str, space-separated paths),
        test_names (str, space-separated names), output_folder (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-pm",
        "--print-metrics",
        action="store_true",
        help="Flag whether to print an output file, if set you need to pass an utt2spkid file",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=1,
        help="Batch size you want to use for decoding",
    )
    parser.add_argument(
        "-m",
        "--input-model",
        required=True,
        help="Folder where the final model is stored",
    )
    parser.add_argument(
        "-i",
        "--input-files",
        required=True,
        # fixed typo: "variales" -> "variables"
        help="String with paths to text or utt2spk_id files to be evaluated, it needs to match the 'test_names' variables",
    )
    parser.add_argument(
        "-n",
        "--test-names",
        required=True,
        help="Name of the test sets to be evaluated",
    )
    parser.add_argument(
        "-o",
        "--output-folder",
        required=True,
        # fixed copy-paste from --input-model: this is the output location
        help="Folder where the evaluation outputs (metrics) will be stored",
    )
    return parser.parse_args()
def main(args):
    """Evaluate a fine-tuned token-classification (text diarization) model.

    Loads the model/tokenizer from ``args.input_model``, then runs inference on
    each (file, name) pair from ``args.input_files`` / ``args.test_names`` and
    computes metrics, logging them under ``args.output_folder``.
    """
    model_dir = args.input_model
    eval_files = args.input_files.rstrip().split(" ")
    set_names = args.test_names.rstrip().split(" ")
    out_dir = args.output_folder

    # exactly one test-set name is required per input file
    assert len(eval_files) == len(
        set_names
    ), "number of test files and their names differ"

    # create the output directory, in 'evaluations folder'
    os.makedirs(os.path.dirname(out_dir) + "/evaluations", exist_ok=True)

    print("\nLoading the TOKEN classification recognition model (TEXT-DIARIZATION)\n")

    # fetch the fine-tuned model and its matching (lower-cased, fast) tokenizer
    diarizer = AutoModelForTokenClassification.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, use_fast=True, do_lower_case=True
    )

    # label <-> id mappings stored in the model config
    tag2id, id2tag = diarizer.config.label2id, diarizer.config.id2label

    # Trainer is instantiated only to reuse its predict() loop
    trainer = Trainer(
        model=diarizer, data_collator=DataCollatorWithPadding(tokenizer)
    )

    for eval_file, set_name in zip(eval_files, set_names):
        print("****** TEXT-BASED DIARIZATION ******")
        print(f"---- Evaluating dataset: --> {set_name} -----")

        # load utterances and per-word speaker tags, then tokenize/pad them
        texts, tags = read_atc_diarization_data(eval_file)
        encodings = tokenizer(
            texts,
            is_split_into_words=True,
            return_offsets_mapping=True,
            padding=True,
            truncation=True,
        )
        labels = encode_tags(tag2id, tags, encodings)
        encodings.pop("offset_mapping")  # offsets are not a model input

        # forward pass, then evaluate; compute_metrics also logs to disk
        predictions, gold_labels, _ = trainer.predict(
            ATCDataset_diarization(encodings, labels)
        )
        compute_metrics(
            predictions,
            gold_labels,
            label_list=id2tag,
            log_folder=f"{out_dir}/{set_name}_metrics",
        )
if __name__ == "__main__":
    # script entry point: parse CLI args and run the evaluation
    main(parse_args())