-
Notifications
You must be signed in to change notification settings - Fork 0
/
mammotab_cpa.py
67 lines (59 loc) · 2.83 KB
/
mammotab_cpa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import csv
import gzip
from gzip import BadGzipFile
import json
from tqdm import tqdm
from lamapi import LamAPI
from collections import Counter
# Shared LamAPI client used by extract_relations to query candidate predicates.
lamapi = LamAPI()
# Memo of entity-pair -> predicate list, keyed "ent1_ent2", so repeated
# entity pairs across rows/tables do not trigger repeated LamAPI calls.
cache = {}
def extract_relations(entities):
    """Infer a CPA (column-property annotation) for each ordered column pair.

    Parameters
    ----------
    entities : list[list[str]]
        Rows of entity annotations; entities[r][c] is the entity annotated
        in row r, column c, or '' when the cell has no annotation.

    Returns
    -------
    dict[str, str]
        Maps "col1_col2" (with col1 < col2) to the most frequent predicate
        LamAPI returned for the entity pairs drawn from those two columns.
    """
    cpa_annotations = {}
    for entity_row in entities:
        for index, cell_annotation in enumerate(entity_row):
            for index_2, cell_annotation_2 in enumerate(entity_row):
                # Only ordered pairs of distinct, annotated cells.
                # BUG FIX: the original tested `cell_annotation != ''` twice
                # and never checked cell_annotation_2, letting pairs with an
                # empty second entity through. (`index < index_2` already
                # implies the indices differ.)
                if index >= index_2 or cell_annotation == '' or cell_annotation_2 == '':
                    continue
                pair_key = f"{cell_annotation}_{cell_annotation_2}"
                col_key = f"{index}_{index_2}"
                if pair_key not in cache:
                    result = lamapi.predicates(
                        [[cell_annotation, cell_annotation_2]])
                    # Skip (and do not cache) failed lookups, mirroring the
                    # original's best-effort behavior.
                    if not result or "message" in result:
                        continue
                    cache[pair_key] = result[f"{cell_annotation} {cell_annotation_2}"]
                # BUG FIX: copy before storing. The original put the cached
                # list object itself into cpa_annotations, so a later
                # .extend() mutated the cache entry in place, corrupting the
                # memo for every subsequent lookup of the same pair.
                predicates = list(cache[pair_key])
                if col_key not in cpa_annotations:
                    cpa_annotations[col_key] = predicates
                else:
                    cpa_annotations[col_key].extend(predicates)
    # Keep only the single most frequent predicate per column pair.
    final_cpa_dict = {}
    for key, value in cpa_annotations.items():
        final_cpa_dict[key] = Counter(value).most_common(1)[0][0]
    return final_cpa_dict
# Walk every gzipped, enriched wiki-table dump, compute column-pair
# predicate annotations, and collect one CPA ground-truth row per pair.
cpa = []
for folder in tqdm(os.listdir("./wiki_tables_enriched/")):
    for filename in os.listdir(f"./wiki_tables_enriched/{folder}/"):
        # BUG FIX: the original reused the name `table` for both the archive
        # filename and the table id below, shadowing the outer loop variable.
        with gzip.open(f"./wiki_tables_enriched/{folder}/{filename}", 'rb') as dump:
            try:
                text = dump.read()
            except BadGzipFile:
                # Skip corrupted archives rather than aborting the whole run.
                continue
        # Decode/parse outside the `with`: the bytes are fully read already,
        # keeping the try body and the open file's lifetime minimal.
        data = json.loads(text.decode('utf8'))
        for table in data["tables"]:
            entities = list(data["tables"][table]["entity"])
            for key, predicate in extract_relations(entities).items():
                col_1, col_2 = key.split("_")
                cpa.append([table, col_1, col_2, predicate])

# Write the ground truth as CSV rows: table_id, col_1, col_2, predicate.
# (`file` renamed to `out` so it no longer shadows the gzip handle's name.)
with open('cpa_gt.csv', 'w', newline='', encoding='utf8') as out:
    writer = csv.writer(out)
    writer.writerows(cpa)