-
Notifications
You must be signed in to change notification settings - Fork 1
/
tests.py
616 lines (522 loc) · 26.7 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
import abc
import datetime
import sqlite3
import pandas as pd
from abc import ABC
from numpy import exp
import torch.cuda
from detoxify import Detoxify
from transformers import BertTokenizer, BertForNextSentencePrediction
from collections import Counter
import aux_functions
import config
import contractions
import conversation
from conversation import Conversation, Message
from nltk.tokenize import RegexpTokenizer
class AbstractTestCase(abc.ABC):
    """Interface for tests that drive a purpose-built conversation.

    A concrete test case constructs a specific dialogue with the agent
    under test — for example a memory test that feeds the agent a fact
    and later asks it to recall that fact.
    """

    def __init__(self):
        pass

    @abc.abstractmethod
    def run(self) -> Conversation:
        """Execute the test case and return the produced conversation."""
        ...

    @abc.abstractmethod
    def analyse(self) -> Conversation:
        """Analyse the outcome of a previously executed run."""
        ...
class AbstractConvTest(abc.ABC):
    """Interface for tests applied statically to an existing dialogue.

    A conversation test inspects any given conversation after the fact —
    e.g. checking every line for grammatical errors or toxicity.
    """

    @abc.abstractmethod
    def __init__(self):
        ...

    @abc.abstractmethod
    def analyse_conversations(self, conversations: list):
        """Analyse a list of conversations.

        Args:
            conversations (list): list of Conversations to analyse
        Returns:
            Dict: Dictionary with metrics produced by test
        """
        ...

    @abc.abstractmethod
    def analyse(self, conv: Conversation):
        """Analyse a single conversation.

        Args:
            conv (Conversation): Conversation to analyse
        Returns:
            Dict: Dictionary with metrics produced by test
        """
        ...
# ----------------------- Conversation tests
""" Below are the implemented conversation tests. """
class ToxicContentTest(AbstractConvTest, ABC):
    """TOX test: scores the testee's messages for several kinds of toxic content.

    Uses the Detoxify 'original' model (on GPU when available) to produce
    one score per toxicity type per testee message.
    """

    def __init__(self):
        self.test_id = 'TOX'
        # Run the model on GPU when one is available, otherwise on CPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detoxify = Detoxify('original', device=device)
        # Accumulated results, keyed on testee id.
        self.result_dict = {}

    def analyse_conversations(self, conversations: list):
        """Apply the test to every conversation and accumulate the results.

        Depending on config.INTERNAL_STORAGE_CHANNEL the per-conversation
        results are stored either under a 1-based conversation number
        ("json") or concatenated into one DataFrame per testee
        ("dataframes").

        Args:
            conversations (list): list of Conversations to analyse
        Returns:
            dict: accumulated results per testee
        """
        for i, conv in enumerate(conversations):
            results = self.analyse(conv)
            testee_id = conv.get_testee_id()
            if config.INTERNAL_STORAGE_CHANNEL == "json":
                try:
                    self.result_dict[testee_id]['Conversations'][i + 1] = results
                except KeyError:
                    # First conversation for this testee: create its sub-dict.
                    self.result_dict[testee_id] = {'Conversations': {i + 1: results}}
            elif config.INTERNAL_STORAGE_CHANNEL == "dataframes":
                try:
                    self.result_dict[testee_id] = pd.concat(
                        [self.result_dict[testee_id], results])
                except KeyError:
                    # Was a bare `except:`; only a missing testee key is
                    # expected here — anything else should propagate.
                    self.result_dict[testee_id] = results
        return self.result_dict

    def analyse(self, conv: Conversation):
        """Run the detoxifier on all of the testee's messages and return the scores."""
        results = self.detoxify.predict(conv.filter_msgs(role="Testee"))
        if config.INTERNAL_STORAGE_CHANNEL == "dataframes":
            results = pd.DataFrame(results)
        return results

    def get_id(self):
        """Return the id of this test."""
        return self.test_id

    def export_dataframe_to_sqlite(self):
        """TODO: unfinished — DataFrame.to_sql requires at least a table name
        and a connection/engine; this call raises TypeError as written."""
        for testee in self.result_dict:
            result_df = self.result_dict[testee]
            result_df.to_sql()

    def export_json_to_sqlite(self):
        """Export the json-style results to the sqlite database.

        For every tested GDM a row is inserted into test_cases (test id,
        gdm id and datetime of the run); then, per conversation and per
        toxicity type, one row per message value is inserted into
        TOX_results.
        """
        for gdm_id in list(self.result_dict.keys()):
            date_time = datetime.datetime.now()
            cursor = aux_functions.conn.cursor()
            test_id = "{}:{}:{}".format(gdm_id, self.test_id, date_time)
            cursor.execute(
                """
                INSERT
                INTO test_cases(test_id, gdm_id, date_time_run)
                VALUES (?, ?, ?);
                """,
                [test_id, gdm_id, date_time]
            )
            # Successful insert
            aux_functions.conn.commit()
            # Per conversation, predictions for the seven toxicity types were
            # produced; each type holds one value per testee message.
            for conv_nbr in self.result_dict[gdm_id]['Conversations']:
                for toxic_type in self.result_dict[gdm_id]['Conversations'][conv_nbr]:
                    for toxic_val in self.result_dict[gdm_id]['Conversations'][conv_nbr][toxic_type]:
                        cursor = aux_functions.conn.cursor()
                        cursor.execute(
                            """
                            INSERT
                            INTO TOX_results(test_id, conv_nbr, toxicity_type, toxicity_level)
                            VALUES (?, ?, ?, ?);
                            """,
                            [test_id, conv_nbr, toxic_type, toxic_val]
                        )
                        # Successful insert
                        aux_functions.conn.commit()
class VocabularySizeTest(AbstractConvTest, ABC):
    """VOCSZ test: tracks the vocabulary a GDM uses across its messages.

    Every testee word is counted, mapped onto a word-frequency ranking
    read from a frequency-list file, and words missing from that list are
    collected separately as "non-frequent" words.
    """

    def __init__(self):
        self.test_id = 'VOCSZ'
        # Per-testee vocabulary bookkeeping (counter, rank list, rare words).
        self.vocabulary = {}
        self.excluded_words = []
        self.excluded_tokens = conversation.set_of_excluded_tokens()
        self.contractions = self.specify_contractions()
        self.frequency_dict_word2rank = self.read_frequency_dict()
        self.frequency_dict_rank2word = self.read_frequency_dict_rank2word()
        self.result_dict = {}
        # Marker written into a word array for positions that must be dropped.
        self.token_indicating_removal = "%%"

    @staticmethod
    def specify_contractions():
        """Return the contraction mapping, e.g. "it's" -> "it is / it has"."""
        return contractions.contractions

    @staticmethod
    def _read_ranked_words():
        """Read the frequency-list file and return its words in rank order.

        Each line is "<word>\\t<count>"; only the word is needed — its rank
        is given by the line position (shared helper for the two public
        read_frequency_* methods, which previously parsed the file twice
        with duplicated code; the second column was mutated but never used).
        """
        with open('miscellaneous .txt-files/count_1w.txt') as f:
            return [line.split('\t', 1)[0] for line in f]

    @staticmethod
    def read_frequency_dict():
        """Build the word -> {'rank': n} mapping (rank 1 = most frequent)."""
        return {word: {'rank': i + 1}
                for i, word in enumerate(VocabularySizeTest._read_ranked_words())}

    @staticmethod
    def read_frequency_dict_rank2word():
        """Return the frequency list as a plain list: index (rank - 1) -> word."""
        return VocabularySizeTest._read_ranked_words()

    def analyse_conversations(self, conversations: list):
        """Analyse every conversation and store the results per testee and
        1-based conversation number; returns the accumulated result dict."""
        for i, conv in enumerate(conversations):
            results = self.analyse(conv)
            testee_id = conv.get_testee_id()
            try:
                self.result_dict[testee_id]['Conversations'][i + 1] = results
            except KeyError:
                # First conversation for this testee: create its sub-dict.
                self.result_dict[testee_id] = {'Conversations': {i + 1: results}}
        return self.result_dict

    def analyse2(self, conv: Conversation):
        """My attempt at a quicker version /Alex"""
        def is_testee(message: Message):
            return message.agent_id == "Testee"
        # Join all testee messages into one large lower-cased string.
        testee_messages = filter(is_testee, conv)
        testee_strings = (str(m) for m in testee_messages)
        testee_text = ' '.join(testee_strings).lower()
        # Tokenizer that matches word characters and apostrophes.
        tokenizer = RegexpTokenizer(r"[\w']+")
        spans = tokenizer.span_tokenize(testee_text)
        rawtokens = (testee_text[begin:end] for (begin, end) in spans)

        def get_words(tokens):
            """Expand contractions; only the first listed expansion is used.

            NOTE(review): analyse() expands *all* alternatives via
            find_contraction — confirm which behaviour is intended.
            """
            for token in tokens:
                if token in self.contractions:
                    # Token is a contracted word.
                    words = self.contractions[token].split("/")[0]
                    yield from words.split(" ")
                else:
                    yield token

        # Count the (expanded) words; rank splitting is not done here.
        return Counter(get_words(rawtokens))

    def analyse(self, conv: Conversation):
        """Count the words the testee used in *conv*.

        Per testee message the text is split into words, cleaned of
        punctuation-like tokens, contractions are expanded, and the words
        are added both to a plain counter and to the frequency-rank
        bookkeeping. Returns a (shallow) copy of the per-testee dict.
        """
        testee_id = conv.get_testee_id()
        self.vocabulary[testee_id] = {
            'word_counter': Counter(),
            'frequency_word_list': {},
            'non_frequent_words': Counter()
        }
        for message in conv:
            # Only the testee's own messages count towards its vocabulary.
            if message.get_role() != "Testee":
                continue
            word_array = self.preprocess_word_array(str(message).split())
            self.vocabulary[testee_id]['word_counter'] += Counter(word_array)
            # Update the rank -> frequency mapping for these words.
            self.add_to_freq_list(word_array, testee_id)
        return self.vocabulary[testee_id].copy()

    def preprocess_word_array(self, word_array):
        """Prepare *word_array* for counting.

        Lower-cases each word, strips excluded tokens from it, drops words
        declared as excluded, and replaces contractions with their expanded
        words. Dropped positions are first overwritten with
        self.token_indicating_removal and removed at the end.
        """
        fixed_word_array = word_array.copy()
        for i, raw_word in enumerate(word_array):
            word = conversation.clean_from_excluded_tokens(raw_word.lower())
            if word in self.excluded_words or word in self.excluded_tokens:
                fixed_word_array[i] = self.token_indicating_removal
                continue
            if word in self.contractions:
                # Append every word of the expanded contraction and mark the
                # original position for removal.
                fixed_word_array.extend(self.find_contraction(word))
                fixed_word_array[i] = self.token_indicating_removal
                continue
            fixed_word_array[i] = word
        # Drop all marked positions. (Fix: this previously compared against a
        # hard-coded "%%" literal instead of the configured marker.)
        while self.token_indicating_removal in fixed_word_array:
            fixed_word_array.remove(self.token_indicating_removal)
        return fixed_word_array

    def add_to_freq_list(self, word_list, testee_id):
        """Update the rank -> frequency mapping for every word in *word_list*.

        Words absent from the frequency list are counted under
        'non_frequent_words' instead, marking them as irregular.
        """
        for word in word_list:
            try:
                # KeyError if the word is absent from the frequency list.
                rank = self.frequency_dict_word2rank[word]['rank']
                self.vocabulary[testee_id]['frequency_word_list'][rank] = \
                    self.vocabulary[testee_id]['word_counter'][word]
            except KeyError:
                # Accumulate the unseen word in the non-frequent counter.
                self.vocabulary[testee_id]['non_frequent_words'] += Counter([word])

    def get_id(self):
        """Return the ID of this test."""
        return self.test_id

    def find_contraction(self, word):
        """Return the full expansion of contraction *word* as a flat word list.

        A contraction may have several alternative expansions separated by
        '/'; words from every alternative are included.
        """
        word_array = []
        for expansion in self.contractions[word].split('/'):
            word_array += expansion.split()
        return word_array

    def export_json_to_sqlite(self):
        """Export the results to the sqlite database.

        Per GDM a test_cases row is inserted; then, per conversation, each
        counted word rank is inserted once per occurrence (frequency 1 per
        row, to fit Grafana histograms), and the non-frequent words are
        inserted with their total frequency.
        """
        for gdm_id in list(self.result_dict.keys()):
            date_time = datetime.datetime.now()
            cursor = aux_functions.conn.cursor()
            test_id = "{}:{}:{}".format(gdm_id, self.test_id, date_time)
            cursor.execute(
                """
                INSERT
                INTO test_cases(test_id, gdm_id, date_time_run)
                VALUES (?, ?, ?);
                """,
                [test_id, gdm_id, date_time]
            )
            # Successful insert
            aux_functions.conn.commit()
            for conv_nbr in self.result_dict[gdm_id]['Conversations']:
                # One row per occurrence of each ranked word.
                for word_rank in self.result_dict[gdm_id]['Conversations'][conv_nbr]['frequency_word_list']:
                    for _ in range(
                            self.result_dict[gdm_id]['Conversations'][conv_nbr]['frequency_word_list'][word_rank]):
                        cursor = aux_functions.conn.cursor()
                        # Store the word alongside its rank so the db can be
                        # double-checked by hand.
                        word = self.frequency_dict_rank2word[word_rank - 1]
                        cursor.execute(
                            """
                            INSERT
                            INTO VOCSZ_frequency_list(test_id, conv_nbr, word, word_rank, frequency)
                            VALUES (?, ?, ?, ?, ?);
                            """,
                            [test_id, conv_nbr, word, word_rank, 1]
                        )
                        # Successful insert
                        aux_functions.conn.commit()
                # Also transfer the non-frequent words with their counts.
                for non_freq_word in self.result_dict[gdm_id]['Conversations'][conv_nbr]['non_frequent_words']:
                    cursor = aux_functions.conn.cursor()
                    cursor.execute(
                        """
                        INSERT
                        INTO VOCSZ_non_frequent_list(test_id, conv_nbr, word, frequency)
                        VALUES (?, ?, ?, ?);
                        """,
                        [test_id, conv_nbr, non_freq_word,
                         self.result_dict[gdm_id]['Conversations'][conv_nbr]['non_frequent_words'][non_freq_word]]
                    )
                    # Successful insert
                    aux_functions.conn.commit()
class CoherentResponseTest(AbstractConvTest, ABC):
    """COHER test: scores coherence between consecutive responses using
    BERT next-sentence prediction (NSP)."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.test_id = 'COHER'
        self.bert_type = 'bert-base-uncased'
        self.bert_tokenizer = BertTokenizer.from_pretrained(self.bert_type)
        self.bert_model = BertForNextSentencePrediction.from_pretrained(self.bert_type).to(self.device)
        self.result_dict = {}

    def analyse_conversations(self, conversations: list):
        """Analyse every conversation and store the results per testee and
        1-based conversation number; returns the accumulated result dict."""
        for i, conv in enumerate(conversations):
            results = self.analyse(conv)
            testee_id = conv.get_testee_id()
            try:
                self.result_dict[testee_id]['Conversations'][i + 1] = results
            except KeyError:
                # First conversation for this testee: create its sub-dict.
                self.result_dict[testee_id] = {'Conversations': {i + 1: results}}
        return self.result_dict

    def analyse(self, conv: Conversation):
        """Score every (preceding message, testee message) pair of *conv*.

        Returns a list of dicts, one per scored testee message, holding the
        two compared messages and the positive NSP prediction.
        """
        results = list()
        messages_testee = conv.filter_msgs("Testee")
        messages_other_agent = conv.filter_gdm_preceding_msgs()
        # All pairs are scored in a single batch for performance, instead of
        # one NSP call per message pair.
        ns_predictions = self.batch_nsp(first_sentences=messages_other_agent,
                                        second_sentences=messages_testee)
        # NOTE(review): the upper bound len(conv) - 1 excludes the final
        # message — a testee message that ends the conversation is never
        # reported. Confirm this matches filter_gdm_preceding_msgs' pairing
        # before changing the bound.
        for i in range(1, len(conv) - 1):
            message = conv[i]
            if message.get_role() == 'Testee':
                # Pop index 0: predictions arrive in conversation order, and
                # element 0 of each entry is the positive prediction.
                next_sent_prediction = ns_predictions.pop(0)[0]
                results.append({
                    'Previous message': str(conv[i - 1]),
                    'Testee message': str(message),
                    'NSP-prediction': next_sent_prediction,
                })
        return results

    def get_id(self):
        """Return the ID of this test."""
        return self.test_id

    def next_sent_prediction(self, string1, string2):
        """Predict whether *string2* coherently follows *string1* (NSP-BERT)."""
        # Fix: move the encodings to the model's device — previously they
        # stayed on CPU, which fails when the model runs on CUDA (batch_nsp
        # already did this correctly).
        inputs = self.bert_tokenizer(string1, string2, return_tensors='pt').to(self.device)
        outputs = self.bert_model(**inputs)
        return self.softmax(outputs.logits.tolist()[0])

    def batch_nsp(self, first_sentences: list, second_sentences: list):
        """Score NSP for two parallel lists of sentences in one batch.

        Returns a list of [positive, negative] probability pairs, one per
        sentence pair.
        """
        text_pairs = list(zip(first_sentences, second_sentences))
        encodings = self.bert_tokenizer.batch_encode_plus(
            text_pairs, return_tensors="pt", padding=True).to(self.device)
        outputs = self.bert_model(**encodings)
        probs = outputs.logits.softmax(dim=-1)
        return probs.tolist()

    @staticmethod
    def softmax(vector):
        """Softmax for interpreting the logits produced by NSP-BERT."""
        e = exp(vector)
        return e / e.sum()

    def export_json_to_sqlite(self):
        """Export the results to the sqlite database.

        Per GDM a test_cases row is inserted; then, per conversation and per
        scored message pair, the negative prediction is inserted into
        COHER_results.
        """
        for gdm_id in list(self.result_dict.keys()):
            date_time = datetime.datetime.now()
            cursor = aux_functions.conn.cursor()
            test_id = "{}:{}:{}".format(gdm_id, self.test_id, date_time)
            cursor.execute(
                """
                INSERT
                INTO test_cases(test_id, gdm_id, date_time_run)
                VALUES (?, ?, ?);
                """,
                [test_id, gdm_id, date_time]
            )
            # Successful insert
            aux_functions.conn.commit()
            for conv_nbr in self.result_dict[gdm_id]['Conversations']:
                for tested_response_dict in self.result_dict[gdm_id]['Conversations'][conv_nbr]:
                    cursor = aux_functions.conn.cursor()
                    cursor.execute(
                        """
                        INSERT
                        INTO COHER_results(test_id, conv_nbr, neg_pred)
                        VALUES (?, ?, ?);
                        """,
                        [test_id, conv_nbr, 1 - tested_response_dict['NSP-prediction']]
                    )
                    # Successful insert
                    aux_functions.conn.commit()
class ReadabilityIndexTest(AbstractConvTest, ABC):
    """READIND test: computes a readability index over the testee's messages."""

    def __init__(self):
        self.test_id = 'READIND'
        self.excluded_tokens = conversation.set_of_excluded_tokens()
        self.result_dict = {}

    def analyse_conversations(self, conversations: list):
        """Analyse every conversation and store the results per testee and
        1-based conversation number; returns the accumulated result dict."""
        for i, conv in enumerate(conversations):
            results = self.analyse(conv)
            testee_id = conv.get_testee_id()
            try:
                self.result_dict[testee_id]['Conversations'][i + 1] = results
            except KeyError:
                # First conversation for this testee: create its sub-dict.
                self.result_dict[testee_id] = {'Conversations': {i + 1: results}}
        return self.result_dict

    def analyse(self, conv: Conversation):
        """Count sentences, words and long words in the testee's messages and
        derive the readability index.

        The index is words-per-sentence plus the percentage of words longer
        than six characters (a LIX-style formula).
        """
        results = {
            'amount_sentences': 0,
            'amount_words': 0,
            'amount_words_grt_6': 0
        }
        for message in conv:
            if message.get_role() != 'Testee':
                continue
            results['amount_sentences'] += conversation.count_sentences_within_string(str(message))
            for word in str(message).split():
                word = conversation.clean_from_excluded_tokens(word)
                # clean_from_excluded_tokens returns "" for pure punctuation
                # tokens such as "," or "?" (e.g. in "Hello , how are you ?"),
                # which must not be counted as words.
                if word not in self.excluded_tokens and len(word) > 0:
                    results['amount_words'] += 1
                    if len(word) > 6:
                        results['amount_words_grt_6'] += 1
        if results['amount_sentences'] and results['amount_words']:
            results['readability_index'] = results['amount_words'] / results['amount_sentences'] + \
                results['amount_words_grt_6'] / results['amount_words'] * 100
        else:
            # Fix: previously this raised ZeroDivisionError when the testee
            # produced no countable sentences or words.
            results['readability_index'] = 0.0
        return results

    def get_id(self):
        """Return the ID of this test."""
        return self.test_id

    def export_json_to_sqlite(self):
        """Export the results to the sqlite database.

        Per GDM a test_cases row is inserted; then, per conversation, the
        computed readability index is inserted into READIND_results.
        """
        for gdm_id in list(self.result_dict.keys()):
            date_time = datetime.datetime.now()
            cursor = aux_functions.conn.cursor()
            test_id = "{}:{}:{}".format(gdm_id, self.test_id, date_time)
            cursor.execute(
                """
                INSERT
                INTO test_cases(test_id, gdm_id, date_time_run)
                VALUES (?, ?, ?);
                """,
                [test_id, gdm_id, date_time]
            )
            # Successful insert
            aux_functions.conn.commit()
            for conv_nbr in self.result_dict[gdm_id]['Conversations']:
                cursor = aux_functions.conn.cursor()
                conv = self.result_dict[gdm_id]['Conversations'][conv_nbr]
                cursor.execute(
                    """
                    INSERT
                    INTO READIND_results(test_id, conv_nbr, readab_index)
                    VALUES (?, ?, ?);
                    """,
                    [test_id, conv_nbr, conv['readability_index']]
                )
                # Successful insert
                aux_functions.conn.commit()
# ----------------------- Injected tests
""" Below are the implemented injected tests. """