-
Notifications
You must be signed in to change notification settings - Fork 3
/
Classify.py
55 lines (45 loc) · 1.75 KB
/
Classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys
from buildARfromlogs import AR,parseLine
from collections import defaultdict
import time
def ARfromString(line):
    """
    Build an AR object from the one-line text form of an Association Rule.

    The line is read as ``<lhs> ==> <rhs> : <confidence>``:
      * <lhs> has its square brackets removed and is cut into edge strings
        on the separator '), ('; each edge string is handed to parseLine.
      * <rhs> is kept as the raw string found before ' : '.
      * <confidence> is converted to float (newlines are dropped first).

    Raises IndexError if the ' ==> ' or ' : ' separator is missing, and
    ValueError if the confidence is not a valid float.
    """
    parts = line.split(' ==> ')
    # Left-hand side: strip the list brackets, then split into single edges.
    edge_strings = parts[0].replace('[', '').replace(']', '').split('), (')
    antecedent = [parseLine(edge) for edge in edge_strings]
    # Right-hand side: "<label> : <confidence>".
    tail = parts[1].split(' : ')
    consequent = tail[0]
    confidence = float(tail[1].replace('\n', ''))
    return AR(antecedent, consequent, confidence)
def predict(ARs, graph, top=10):
    """
    Returns a predicted label for a graph, given a list of Association Rules

    A rule is relevant when graph.subgraphIsomorphism(rule.lhs) is truthy.
      * No relevant rule: the string 'Support of every rule is 0' is returned.
      * All relevant rules agree on one label: that label is returned.
      * Otherwise the `top` relevant rules with the highest confidence each
        cast one vote for their rhs and the label with the most votes wins.
        Ties are resolved in favour of the label whose best rule has the
        highest confidence.

    ARs   -- list of rules; each needs .lhs, .rhs (hashable) and .confidence
    graph -- object providing subgraphIsomorphism(lhs)
    top   -- how many highest confidence Association Rules may vote;
             increase top if your list of Association Rules is large (> 100)

    Raises ValueError if the labels conflict and top is smaller than 1,
    because no rule would be allowed to vote.
    Progress is written to sys.stdout while the rules are being matched.
    """
    total = len(ARs)
    relevantARs = []
    for i, ar in enumerate(ARs):
        # Progress display: return to the start of the terminal line and
        # overwrite the previous percentage.
        sys.stdout.write('\r \r')
        sys.stdout.write("Predicting label " + str(round(float(i) / float(total) * 10000) / 100.0) + "%")
        sys.stdout.flush()
        if graph.subgraphIsomorphism(ar.lhs):
            relevantARs.append(ar)
    # Plain write instead of the Python-2-only print statement; same output.
    sys.stdout.write("\r \rPredicting label 100%\n")

    labels = set(ar.rhs for ar in relevantARs)
    if not labels:
        return 'Support of every rule is 0'
    if len(labels) == 1:
        return labels.pop()

    # Conflicting labels: let the most confident rules vote.
    ranked = sorted(relevantARs, key=lambda rule: rule.confidence, reverse=True)
    # max(top, 0) keeps a negative top from slicing from the end of the list.
    voters = ranked[:max(top, 0)]
    if not voters:
        raise ValueError("top must be at least 1 to vote between conflicting labels")

    groupscores = defaultdict(int)
    # Labels in order of first appearance, i.e. by their best confidence.
    labels_by_confidence = []
    for rule in voters:
        if rule.rhs not in groupscores:
            labels_by_confidence.append(rule.rhs)
        groupscores[rule.rhs] += 1
    # max() returns the first maximal item, so scanning the labels in
    # confidence order makes a tie go to the label with the stronger rule
    # instead of depending on dict iteration order.
    return max(labels_by_confidence, key=lambda label: groupscores[label])