Skip to content

Commit 08fc1b1

Browse files
committed
feat: store full pages
1 parent 7a2ae2b commit 08fc1b1

File tree

8 files changed

+119
-83
lines changed

8 files changed

+119
-83
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.venv
22
.direnv
33
__pycache__
4+
*.xml

Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@ WORKDIR /workspace
44
COPY requirements.txt .
55
RUN pip install -r requirements.txt
66

7-
ENTRYPOINT [ "python" ]
8-
CMD [ "main.py" ]
7+
# ENTRYPOINT [ "python" ]
8+
# CMD [ "main.py" ]
9+
10+
CMD ["sleep", "infinity"]

chunker.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
def chunk(s, chunkSize=256, overlap=64):
    """Split *s* into overlapping chunks of space-separated words.

    A window of up to ``chunkSize`` words is taken every
    ``chunkSize - overlap`` words, so consecutive chunks share ``overlap``
    words. The last window may consist only of words that already appear
    at the end of the previous chunk.

    Args:
        s: Text to split. Words are separated on single spaces only, so an
            empty string yields ``[""]``.
        chunkSize: Maximum number of words per chunk; must be positive.
        overlap: Number of words shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunkSize``.

    Returns:
        A list of chunks, each a string of words re-joined with one space.

    Raises:
        ValueError: If ``chunkSize`` is not positive or ``overlap`` is
            outside ``[0, chunkSize)``.
    """
    if chunkSize <= 0:
        raise ValueError("chunkSize must be positive, got %r" % (chunkSize,))
    # A step of zero makes range() raise a cryptic error, a negative step
    # silently produces no chunks at all, and a negative overlap would skip
    # words between windows -- reject all three explicitly.
    if not 0 <= overlap < chunkSize:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < chunkSize, got overlap=%r, chunkSize=%r"
            % (overlap, chunkSize)
        )

    words = s.split(" ")
    step = chunkSize - overlap
    return [" ".join(words[i:i + chunkSize]) for i in range(0, len(words), step)]
4+
5+
if __name__ == "__main__":
    # Self-test for chunk(). Every case is run with chunkSize=10 and
    # overlap=2, i.e. a new window starts every 8 words.
    cases = [
        # (input text, expected list of chunks)
        ("", [""]),
        ("hi", ["hi"]),
        ("this is a test", ["this is a test"]),
        (
            "this is a long test with more than ten words so that we can test overlap",
            [
                "this is a long test with more than ten words",
                "ten words so that we can test overlap",
            ],
        ),
    ]

    print("Testing chunk function.")
    # Unpack each case directly; the name `text` avoids shadowing the
    # builtin `input`.
    for text, expected in cases:
        actual = chunk(text, 10, 2)
        print("\nInput: %s \nExpected: %s\nActual: %s" % (str(text), str(expected), str(actual)))
        assert actual == expected, "%s != %s" % (actual, expected)

    print('\nAll tests passed.')
38+

docker-compose.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@ services:
1212
- ollama-models:/var/lib/ollama/models
1313
- ollama-config:/root/.ollama # stores private key
1414

15-
# GPU access for ROCm (AMD GPUs)
15+
# GPU access for ROCm (AMD)
1616
devices:
1717
- /dev/dri:/dev/dri
18-
- /dev/kfd:/dev/kfd
18+
# - /dev/kfd:/dev/kfd
1919

20-
# GPU access for CUDA (NVIDIA)
20+
# GPU access for CUDA (NVIDIA) - untested
2121
# deploy:
2222
# resources:
2323
# reservations:

import_dump.py

Lines changed: 47 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -3,110 +3,96 @@
33
import tqdm
44
import re
55

6-
import ml
6+
import models
77
import postgres
8+
import chunker
9+
810

911
def strip_wikitext(s):
10-
HTML_FILTERS = {
11-
'div': ['navbox','navbox-styles','spoken-wikipedia', 'noprint', 'hatnote', 'rt-tooltip', 'reflist'],
12-
'span': ['mw-ext-cite-error'],
13-
'table': ['noprint','ombox'],
14-
'ol': ['breadcrumb-nav-container', 'references'],
15-
'sup': ['reference']
16-
}
17-
REGEX_FILTERS = {
18-
'p': '→.*ersion'
12+
HTML_FILTERS = {
13+
"div": [
14+
"navbox",
15+
"navbox-styles",
16+
"spoken-wikipedia",
17+
"noprint",
18+
"hatnote",
19+
"rt-tooltip",
20+
"reflist",
21+
],
22+
"span": ["mw-ext-cite-error"],
23+
"table": ["noprint", "ombox"],
24+
"ol": ["breadcrumb-nav-container", "references"],
25+
"sup": ["reference"],
1926
}
27+
REGEX_FILTERS = {"p": "→.*ersion"}
2028

2129
def filterHtml(soup):
22-
for figure in soup.find_all('figure'):
30+
for figure in soup.find_all("figure"):
2331
figure.decompose()
2432

2533
for tag, classes in HTML_FILTERS.items():
2634
for className in classes:
27-
for div in soup.find_all(tag, {'class': className}):
35+
for div in soup.find_all(tag, {"class": className}):
2836
div.decompose()
2937

3038
for tag, regex in REGEX_FILTERS.items():
3139
for element in soup.find_all(tag):
32-
if(re.search(regex, str(element)) != None):
40+
if re.search(regex, str(element)) != None:
3341
element.decompose()
3442

3543
return soup
3644

37-
if s is None: return None
45+
if s is None:
46+
return None
3847

39-
soup = bs4.BeautifulSoup(s, 'lxml')
48+
soup = bs4.BeautifulSoup(s, "lxml")
4049
text = filterHtml(soup).get_text()
4150
text = text.strip()
4251

43-
if len(text) == 0: return None
44-
if text.lower().startswith("#redirect"): return None
52+
if len(text) == 0:
53+
return None
54+
if text.lower().startswith("#redirect"):
55+
return None
4556

4657
return text
4758

48-
def chunk(s):
49-
words = s.split(" ")
50-
CHUNK_SIZE = 256
51-
OVERLAP = 64
52-
return [" ".join(words[i:i+CHUNK_SIZE]) for i in range(0, len(words), CHUNK_SIZE - OVERLAP)]
53-
54-
# cases = [
55-
# (
56-
# "",
57-
# [""],
58-
# ),
59-
# (
60-
# "hi",
61-
# ["hi"],
62-
# ),
63-
# (
64-
# "this is a test",
65-
# ["this is a test"],
66-
# ),
67-
# (
68-
# "this is a long test with more than ten words so that we can test overlap",
69-
# [
70-
# "this is a long test with more than ten words",
71-
# "ten words so that we can test overlap",
72-
# ],
73-
# ),
74-
# ]
75-
#
76-
#
77-
# for case in cases:
78-
# print("Testing chunk function.")
79-
# print("Input: %s" % str(case))
80-
# outexp = case[1]
81-
# outactual = chunk(case[0])
82-
# assert outactual == outexp, "%s != %s" % (outactual, outexp)
83-
#
84-
postgres.init(embeddingLength=ml.embeddingLength())
59+
60+
postgres.init(embeddingLength=models.embeddingLength())
8561

8662
with postgres.get_connection().cursor() as cur:
8763
with open("dump.xml", "rb") as f:
8864
dump = mwxml.Dump.from_file(f)
8965

9066
for page in tqdm.tqdm(dump.pages):
9167
title = page.title
92-
if title is None: continue
68+
if title is None:
69+
continue
9370
if re.search("/[a-z][a-z][a-z]?(-[a-z]+)?$", title):
9471
# print(f"skipping {title}")
9572
continue
9673

9774
# Delete existing page chunks, that is, update if we know about it already
98-
cur.execute("DELETE FROM page_text WHERE title = %s;", (title,))
75+
cur.execute("DELETE FROM chunks WHERE title = %s;", (title,))
76+
cur.execute("DELETE FROM pages WHERE title = %s;", (title,))
9977

10078
# We support only one revision in the dump
10179
text = list(page)[0].text
10280

10381
text = strip_wikitext(text)
104-
if text is None: continue
82+
if text is None:
83+
continue
10584

106-
for c in chunk(text):
107-
embedding = ml.embeddingString(c)
108-
cur.execute("INSERT INTO page_text (title, text, embedding) VALUES (%s, %s, %s);",
109-
(title, c, embedding))
85+
cur.execute(
86+
"INSERT INTO pages (title, text) VALUES (%s, %s);",
87+
(title, text),
88+
)
89+
90+
for c in chunker.chunk(text):
91+
embedding = models.embeddingString(c)
92+
cur.execute(
93+
"INSERT INTO chunks (title, text, embedding) VALUES (%s, %s, %s);",
94+
(title, c, embedding),
95+
)
11096

11197
# Commit the transaction
11298
postgres.get_connection().commit()

main.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import ml
1+
import models
22
import postgres
33

4-
# print(ml.chat("What is the meaning of life?"))
5-
64
query = "event 2024 tallinn"
7-
emb = ml.embeddingString(query)
5+
emb = models.embeddingString(query)
86
cur = postgres.get_connection().cursor()
97
# Embeddings are stored in the `chunks` table; `page_text` is no longer created by postgres.init().
cur.execute("SELECT text FROM chunks ORDER BY embedding <-> %s LIMIT 5;", (emb,))
108
res = cur.fetchall()
@@ -20,7 +18,7 @@
2018
print("Prompt: " + prompt + "\n\n\n")
2119

2220
print("********************************************************************************")
23-
result = ml.chat(prompt)
21+
result = models.chat(prompt)
2422
print(result)
2523

2624
print("********************************************************************************")
@@ -37,6 +35,6 @@
3735
"""
3836
print("Prompt 2: " + prompt2)
3937
print("********************************************************************************")
40-
result2 = ml.chat(prompt2)
38+
result2 = models.chat(prompt2)
4139
print(result2)
4240

ml.py renamed to models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
embeddingsModel = "all-minilm" # or: mxbai-embed-large
66
# chatModel = "qwen:0.5b"
7-
chatModel = "gemma:7b"
8-
# chatModel = "mistral:v0.2"
7+
# chatModel = "gemma:7b"
8+
chatModel = "mistral:v0.2"
99

1010

1111
def get_connection():

postgres.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import psycopg
2+
from psycopg.sql import SQL, Literal
23

34
_db: psycopg.Connection | None = None
45

@@ -13,6 +14,9 @@ def get_connection() -> psycopg.Connection:
1314

1415

1516
def init(embeddingLength: int):
17+
if not isinstance(embeddingLength, int) or embeddingLength <= 0:
18+
raise ValueError("Invalid embedding length")
19+
1620
db = get_connection()
1721
cur = db.cursor()
1822

@@ -23,14 +27,21 @@ def init(embeddingLength: int):
2327
db.commit()
2428

2529
cur.execute(
26-
"""
27-
CREATE TABLE IF NOT EXISTS page_text (
28-
id SERIAL PRIMARY KEY,
29-
title VARCHAR(255) NOT NULL,
30-
text TEXT NOT NULL,
31-
embedding vector( %s ) NOT NULL
32-
);
33-
""",
34-
(embeddingLength,),
30+
SQL(
31+
"""
32+
CREATE TABLE IF NOT EXISTS pages (
33+
id SERIAL PRIMARY KEY,
34+
title VARCHAR(255) NOT NULL,
35+
text TEXT NOT NULL
36+
);
37+
38+
CREATE TABLE IF NOT EXISTS chunks (
39+
id SERIAL PRIMARY KEY,
40+
title VARCHAR(255) NOT NULL,
41+
text TEXT NOT NULL,
42+
embedding vector( {} ) NOT NULL
43+
);
44+
"""
45+
).format(Literal(str(embeddingLength)))
3546
)
3647
db.commit()

0 commit comments

Comments
 (0)