-
Notifications
You must be signed in to change notification settings - Fork 0
/
Pruning.py
64 lines (59 loc) · 2.6 KB
/
Pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from Utility import *
from timeit import default_timer as timer
# METODO CHE ESGUE LE POTATURE
def pruning(rules, validation_df, training_df):
improved = True
count = 0
start = timer()
old_original = make_predictions_rule(validation_df, rules)
new_accuracy = old_original
while improved:
count += 1
print(count)
rules, new_accuracy, improved = prune_rule(validation_df, rules, new_accuracy)
print(new_accuracy)
# se la accuracy sul validation è maggiore o ugaule di quella sul traim mi stoppo
accuracy_train_df = make_predictions_rule(training_df, rules)
if new_accuracy >= accuracy_train_df:
improved = False
end = timer()
print("# DI POTATURE:", count-1)
print("TEMPO DI ESECUZIONE:", (end - start))
return old_original, rules, new_accuracy
# METODO CHE TROVA LA MIGLIORE POTATURA DA FARE E LA ESEGUE
def prune_rule(df, rules, accuracy):
copy_rules = copy.deepcopy(rules)
prune_values = []
improved = False
best_accuracy = accuracy
for index_rule in range(0, len(rules)):
current_rule = rules[index_rule]
# ciclo su tutti gli elementi e provo a potarne uno alla volta
for index_element in range(0, len(current_rule) - 1):
copy_rule = copy.deepcopy(current_rule)
copy_rules = copy.deepcopy(rules)
if len(copy_rule) > 2:
copy_rule.pop(index_element)
# rimuovo elemento da regola e sostituisco la regola nella lista delle regole
copy_rules[index_rule] = copy_rule
copy_rules = np.unique(copy_rules)
accuracy = make_predictions_rule(df, copy_rules)
prune_values.append([index_rule, index_element, accuracy])
# Trovo potatura che porta a maggiore accuratezza, creando un dataframe per fare argmax della colonna
prune_values = pd.DataFrame(prune_values, columns=["index_rule", "index_element", "accuracy"])
max = prune_values.accuracy.argmax()
best_prune = prune_values.iloc[max]
if best_prune[2] > best_accuracy:
best_prune = prune_values.iloc[max]
best_accuracy = best_prune[2]
improved = True
copy_rules = copy.deepcopy(rules)
# Apporto la modifica alla regola
if improved:
copy_rules[int(best_prune[0])].pop(int(best_prune[1]))
# elimino duplicati
copy_rules = np.unique(copy_rules)
else: # non ho trovato nessun miglioramento
print("NESSUN MIGLIORAMENTO")
end = timer()
return copy_rules, best_accuracy, improved