Skip to content

Commit

Permalink
feat: playwright, spatial parsing, markdown for web search
Browse files Browse the repository at this point in the history
Co-authored-by: Aaditya Sahay <[email protected]>
  • Loading branch information
Saghen and Aaditya-Sahay committed May 3, 2024
1 parent 50febad commit 8c3db9a
Show file tree
Hide file tree
Showing 33 changed files with 1,719 additions and 449 deletions.
292 changes: 290 additions & 2 deletions package-lock.json

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"@types/jsdom": "^21.1.1",
"@types/minimist": "^1.2.5",
"@types/parquetjs": "^0.10.3",
"@types/sbd": "^1.0.5",
"@types/uuid": "^9.0.8",
"@typescript-eslint/eslint-plugin": "^6.x",
"@typescript-eslint/parser": "^6.x",
Expand All @@ -50,9 +51,11 @@
},
"type": "module",
"dependencies": {
"@cliqz/adblocker-playwright": "^1.27.2",
"@huggingface/hub": "^0.5.1",
"@huggingface/inference": "^2.6.3",
"@iconify-json/bi": "^1.1.21",
"@playwright/browser-chromium": "^1.43.1",
"@resvg/resvg-js": "^2.6.0",
"@xenova/transformers": "^2.16.1",
"autoprefixer": "^10.4.14",
Expand All @@ -74,10 +77,12 @@
"parquetjs": "^0.11.2",
"pino": "^9.0.0",
"pino-pretty": "^11.0.0",
"playwright": "^1.40.0",
"postcss": "^8.4.31",
"saslprep": "^1.0.3",
"satori": "^0.10.11",
"satori-html": "^0.3.2",
"sbd": "^1.0.19",
"serpapi": "^1.1.1",
"sharp": "^0.33.2",
"tailwind-scrollbar": "^3.0.0",
Expand Down
6 changes: 3 additions & 3 deletions src/lib/components/chat/ChatMessage.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,18 @@
{#if webSearchSources?.length}
<div class="mt-4 flex flex-wrap items-center gap-x-2 gap-y-1.5 text-sm">
<div class="text-gray-400">Sources:</div>
{#each webSearchSources as { link, title, hostname }}
{#each webSearchSources as { link, title }}
<a
class="flex items-center gap-2 whitespace-nowrap rounded-lg border bg-white px-2 py-1.5 leading-none hover:border-gray-300 dark:border-gray-800 dark:bg-gray-900 dark:hover:border-gray-700"
href={link}
target="_blank"
>
<img
class="h-3.5 w-3.5 rounded"
src="https://www.google.com/s2/favicons?sz=64&domain_url={hostname}"
src="https://www.google.com/s2/favicons?sz=64&domain_url={new URL(link).hostname}"
alt="{title} favicon"
/>
<div>{hostname.replace(/^www\./, "")}</div>
<div>{new URL(link).hostname.replace(/^www\./, "")}</div>
</a>
{/each}
</div>
Expand Down
7 changes: 6 additions & 1 deletion src/lib/server/embeddingEndpoints/hfApi/embeddingHfApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ export async function embeddingEndpointHfApi(
"Content-Type": "application/json",
...(authorization ? { Authorization: authorization } : {}),
},
body: JSON.stringify({ inputs: batchInputs }),
body: JSON.stringify({
inputs: {
source_sentence: batchInputs[0],
sentences: batchInputs.slice(1),
},
}),
});

if (!response.ok) {
Expand Down
60 changes: 41 additions & 19 deletions src/lib/server/isURLLocal.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,48 @@
import { Address6, Address4 } from "ip-address";

import dns from "node:dns";

export async function isURLLocal(URL: URL): Promise<boolean> {
const isLocal = new Promise<boolean>((resolve, reject) => {
dns.lookup(URL.hostname, (err, address, family) => {
if (err) {
reject(err);
}
if (family === 4) {
const addr = new Address4(address);
resolve(addr.isInSubnet(new Address4("127.0.0.0/8")));
} else if (family === 6) {
const addr = new Address6(address);
resolve(
addr.isLoopback() || addr.isInSubnet(new Address6("::1/128")) || addr.isLinkLocal()
);
} else {
reject(new Error("Unknown IP family"));
}
const dnsLookup = (hostname: string): Promise<{ address: string; family: number }> => {
return new Promise((resolve, reject) => {
dns.lookup(hostname, (err, address, family) => {
if (err) return reject(err);
resolve({ address, family });
});
});
};

export async function isURLLocal(URL: URL): Promise<boolean> {
const { address, family } = await dnsLookup(URL.hostname);

if (family === 4) {
const addr = new Address4(address);
const localSubnet = new Address4("127.0.0.0/8");
return addr.isInSubnet(localSubnet);
}

if (family === 6) {
const addr = new Address6(address);
return addr.isLoopback() || addr.isInSubnet(new Address6("::1/128")) || addr.isLinkLocal();
}

throw Error("Unknown IP family");
}

export function isURLStringLocal(url: string) {
try {
const urlObj = new URL(url);
return isURLLocal(urlObj);
} catch (e) {
// assume local if URL parsing fails
return true;
}
}

return isLocal;
// TODO: move this to a generic url helper
export function isURL(url: string) {
try {
new URL(url);
return true;
} catch (e) {
return false;
}
}
10 changes: 4 additions & 6 deletions src/lib/server/preprocessMessages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@ export async function preprocessMessages(
return await Promise.all(
structuredClone(messages).map(async (message, idx) => {
const webSearchContext = webSearch?.contextSources
.map(({ context }) => context)
.flat()
.sort((a, b) => a.idx - b.idx)
.map(({ text }) => text)
.join(" ");
.map(({ context }) => context.trim())
.join("\n\n----------\n\n");

// start by adding websearch to the last message
if (idx === messages.length - 1 && webSearch && webSearchContext?.trim()) {
const lastQuestion = messages.findLast((el) => el.from === "user")?.content ?? "";
Expand All @@ -27,7 +25,7 @@ export async function preprocessMessages(
.map((el) => el.content);
const currentDate = format(new Date(), "MMMM d, yyyy");

message.content = `I searched the web using the query: ${webSearch.searchQuery}.
message.content = `I searched the web using the query: ${webSearch.searchQuery}.
Today is ${currentDate} and here are the results:
=====================
${webSearchContext}
Expand Down
29 changes: 9 additions & 20 deletions src/lib/server/sentenceSimilarity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ import type { EmbeddingBackendModel } from "$lib/server/embeddingModels";
import type { Embedding } from "$lib/server/embeddingEndpoints/embeddingEndpoints";

// see here: https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/README.md?plain=1#L34
function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
export function innerProduct(embeddingA: Embedding, embeddingB: Embedding) {
return 1.0 - dot(embeddingA, embeddingB);
}

export async function findSimilarSentences(
export async function getSentenceSimilarity(
embeddingModel: EmbeddingBackendModel,
query: string,
sentences: string[],
{ topK = 5 }: { topK: number }
): Promise<Embedding> {
sentences: string[]
): Promise<{ distance: number; embedding: Embedding; idx: number }[]> {
const inputs = [
`${embeddingModel.preQuery}${query}`,
...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
Expand All @@ -24,19 +23,9 @@ export async function findSimilarSentences(
const queryEmbedding: Embedding = output[0];
const sentencesEmbeddings: Embedding[] = output.slice(1);

const distancesFromQuery: { distance: number; index: number }[] = [...sentencesEmbeddings].map(
(sentenceEmbedding: Embedding, index: number) => {
return {
distance: innerProduct(queryEmbedding, sentenceEmbedding),
index,
};
}
);

distancesFromQuery.sort((a, b) => {
return a.distance - b.distance;
});

// Return the indexes of the closest topK sentences
return distancesFromQuery.slice(0, topK).map((item) => item.index);
return sentencesEmbeddings.map((sentenceEmbedding, idx) => ({
distance: innerProduct(queryEmbedding, sentenceEmbedding),
embedding: sentenceEmbedding,
idx,
}));
}
72 changes: 72 additions & 0 deletions src/lib/server/websearch/embed/embed.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import type { WebSearchScrapedSource, WebSearchUsedSource } from "$lib/types/WebSearch";
import type { EmbeddingBackendModel } from "../../embeddingModels";
import { getSentenceSimilarity, innerProduct } from "../../sentenceSimilarity";
import { MarkdownElementType, type MarkdownElement } from "../markdown/types";
import { stringifyMarkdownElement } from "../markdown/utils/stringify";
import { flattenTree } from "./tree";

const MIN_CHARS = 3000;
const SOFT_MAX_CHARS = 8000;

export async function findContextSources(
sources: WebSearchScrapedSource[],
prompt: string,
embeddingModel: EmbeddingBackendModel
) {
const sourcesMarkdownElems = sources.map((source) => flattenTree(source.page.markdownTree));
const markdownElems = sourcesMarkdownElems.flat();

const embeddings = await getSentenceSimilarity(
embeddingModel,
prompt,
markdownElems
.map(stringifyMarkdownElement)
// Safety in case the stringified markdown elements are too long
// but chunking should have happened earlier
.map((elem) => elem.slice(0, embeddingModel.chunkCharLength))
);

const topEmbeddings = embeddings
.sort((a, b) => a.distance - b.distance)
.filter((embedding) => markdownElems[embedding.idx].type !== MarkdownElementType.Header);

let totalChars = 0;
const selectedMarkdownElems = new Set<MarkdownElement>();
const selectedEmbeddings: number[][] = [];
for (const embedding of topEmbeddings) {
const elem = markdownElems[embedding.idx];

// Ignore elements that are too similar to already selected elements
const tooSimilar = selectedEmbeddings.some(
(selectedEmbedding) => innerProduct(selectedEmbedding, embedding.embedding) < 0.01
);
if (tooSimilar) continue;

// Add element
if (!selectedMarkdownElems.has(elem)) {
selectedMarkdownElems.add(elem);
selectedEmbeddings.push(embedding.embedding);
totalChars += elem.content.length;
}

// Add element's parent (header)
if (elem.parent && !selectedMarkdownElems.has(elem.parent)) {
selectedMarkdownElems.add(elem.parent);
totalChars += elem.parent.content.length;
}

if (totalChars > SOFT_MAX_CHARS) break;
if (totalChars > MIN_CHARS && embedding.distance > 0.25) break;
}

const contextSources = sourcesMarkdownElems
.map<WebSearchUsedSource>((elems, idx) => {
const sourceSelectedElems = elems.filter((elem) => selectedMarkdownElems.has(elem));
const context = sourceSelectedElems.map(stringifyMarkdownElement).join("\n");
const source = sources[idx];
return { ...source, context };
})
.filter((contextSource) => contextSource.context.length > 0);

return contextSources;
}
6 changes: 6 additions & 0 deletions src/lib/server/websearch/embed/tree.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import type { MarkdownElement } from "../markdown/types";

export function flattenTree(elem: MarkdownElement): MarkdownElement[] {
if ("children" in elem) return [elem, ...elem.children.flatMap(flattenTree)];
return [elem];
}
98 changes: 98 additions & 0 deletions src/lib/server/websearch/markdown/fromHtml.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import { collapseString, sanitizeString } from "./utils/nlp";
import { stringifyHTMLElements, stringifyHTMLElementsUnformatted } from "./utils/stringify";
import { MarkdownElementType, tagNameMap, type HeaderElement, type MarkdownElement } from "./types";
import type { SerializedHTMLElement } from "../scrape/types";

type ConversionState = {
defaultType:
| MarkdownElementType.Paragraph
| MarkdownElementType.BlockQuote
| MarkdownElementType.UnorderedListItem
| MarkdownElementType.OrderedListItem;
listDepth: number;
blockQuoteDepth: number;
};
export function htmlElementToMarkdownElements(
parent: HeaderElement,
elem: SerializedHTMLElement | string,
prevState: ConversionState = {
defaultType: MarkdownElementType.Paragraph,
listDepth: 0,
blockQuoteDepth: 0,
}
): MarkdownElement | MarkdownElement[] {
// Found text so create an element based on the previous state
if (typeof elem === "string") {
if (elem.trim().length === 0) return [];
if (
prevState.defaultType === MarkdownElementType.UnorderedListItem ||
prevState.defaultType === MarkdownElementType.OrderedListItem
) {
return {
parent,
type: prevState.defaultType,
content: elem,
depth: prevState.listDepth,
};
}
if (prevState.defaultType === MarkdownElementType.BlockQuote) {
return {
parent,
type: prevState.defaultType,
content: elem,
depth: prevState.blockQuoteDepth,
};
}
return { parent, type: prevState.defaultType, content: elem };
}

const type = tagNameMap[elem.tagName] ?? MarkdownElementType.Paragraph;

// Update the state based on the current element
const state: ConversionState = { ...prevState };
if (type === MarkdownElementType.UnorderedList || type === MarkdownElementType.OrderedList) {
state.listDepth += 1;
state.defaultType =
type === MarkdownElementType.UnorderedList
? MarkdownElementType.UnorderedListItem
: MarkdownElementType.OrderedListItem;
}
if (type === MarkdownElementType.BlockQuote) {
state.defaultType = MarkdownElementType.BlockQuote;
state.blockQuoteDepth += 1;
}

// Headers
if (type === MarkdownElementType.Header) {
return {
parent,
type,
level: Number(elem.tagName[1]),
content: collapseString(stringifyHTMLElements(elem.content)),
children: [],
};
}

// Code blocks
if (type === MarkdownElementType.CodeBlock) {
return {
parent,
type,
content: sanitizeString(stringifyHTMLElementsUnformatted(elem.content)),
};
}

// Typical case, we want to flatten the DOM and only create elements when we see text
return elem.content.flatMap((el) => htmlElementToMarkdownElements(parent, el, state));
}

export function mergeAdjacentElements(elements: MarkdownElement[]): MarkdownElement[] {
return elements.reduce<MarkdownElement[]>((acc, elem) => {
const last = acc[acc.length - 1];
if (last && last.type === MarkdownElementType.Paragraph && last.type === elem.type) {
last.content += elem.content;
return acc;
}
return [...acc, elem];
}, []);
}

0 comments on commit 8c3db9a

Please sign in to comment.