Optimize Index Encoder for constant-time search #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 2 commits · Jun 17, 2025

Changes from all commits:
chebai_graph/preprocessing/property_encoder.py (44 changes: 32 additions & 12 deletions)

```diff
@@ -1,8 +1,12 @@
 import abc
 import os
-import torch
 from typing import Optional
+
+import torch
+import sys
+from itertools import islice
+import inspect
 
 
 class PropertyEncoder(abc.ABC):
     def __init__(self, property, **kwargs):
```
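Among the new imports, `inspect` changes how the default indices directory is resolved in the next hunk: `os.path.dirname(__file__)` always points at `property_encoder.py` itself, while `inspect.getfile(self.__class__)` points at the module that defines the concrete subclass. A minimal sketch of the difference, using a hypothetical subclass:

```python
import inspect
import os


class BaseEncoder:
    def default_indices_dir(self):
        # Resolved against the file that defines type(self), not this file.
        return os.path.dirname(inspect.getfile(self.__class__))


# A subclass defined in another package, e.g. my_pkg/encoders.py:
#
#     class MyEncoder(BaseEncoder): ...
#
# would get my_pkg/ as its default indices dir, whereas
# os.path.dirname(__file__) would always return the directory of the
# base module, no matter which subclass is instantiated.
```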
```diff
@@ -36,11 +40,13 @@ class IndexEncoder(PropertyEncoder):
     def __init__(self, property, indices_dir=None, **kwargs):
         super().__init__(property, **kwargs)
         if indices_dir is None:
-            indices_dir = os.path.dirname(__file__)
+            indices_dir = os.path.dirname(inspect.getfile(self.__class__))
         self.dirname = indices_dir
         # load already existing cache
         with open(self.index_path, "r") as pk:
-            self.cache = [x.strip() for x in pk]
+            self.cache: dict[str, int] = {
+                token.strip(): idx for idx, token in enumerate(pk)
+            }
         self.index_length_start = len(self.cache)
         self.offset = 0
 
```
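The cache itself changes from a list of tokens to a dict mapping each token to its line index. With a list, the membership test in `encode` and the `list.index` call both scan linearly; with a dict, both become average O(1) hash lookups, which is the constant-time search this PR is named for. A minimal sketch of the loading step, assuming a one-token-per-line index file (the file name here is made up):

```python
# Build a token -> line-number mapping the same way the new __init__ does.
index_path = "tokens.txt"  # hypothetical index file

with open(index_path, "w") as f:
    f.write("C\nN\nO\n")

with open(index_path, "r") as f:
    cache = {token.strip(): idx for idx, token in enumerate(f)}

print(cache)         # {'C': 0, 'N': 1, 'O': 2}
print("N" in cache)  # True  -- hash lookup, no list scan
print(cache["N"])    # 1     -- replaces cache.index("N")
```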
```diff
@@ -64,19 +70,33 @@ def index_path(self):
 
     def on_finish(self):
         """Save cache"""
-        with open(self.index_path, "w") as pk:
-            new_length = len(self.cache) - self.index_length_start
-            pk.writelines([f"{c}\n" for c in self.cache])
-            print(
-                f"saved index of property {self.property.name} to {self.index_path}, "
-                f"index length: {len(self.cache)} (new: {new_length})"
-            )
+        total_tokens = len(self.cache)
+        if total_tokens > self.index_length_start:
+            print("New tokens were added to the cache; saving them to the index file...")
+
+            assert sys.version_info >= (
+                3,
+                7,
+            ), "This code requires Python 3.7 or higher."
+            # From Python 3.7 on, the standard dict preserves insertion order and is iterated in that order:
+            # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
+            # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
+            new_tokens = list(islice(self.cache, self.index_length_start, total_tokens))
+
+            with open(self.index_path, "a") as pk:
+                pk.writelines([f"{c}\n" for c in new_tokens])
+                print(
+                    f"Appended {len(new_tokens)} new tokens to the index of property {self.property.name} at {self.index_path}."
+                )
+                print(
+                    f"The index of property {self.property.name} now holds {total_tokens} tokens."
+                )
 
     def encode(self, token):
         """Returns a unique number for each token, automatically adds new tokens to the cache."""
         if not str(token) in self.cache:
-            self.cache.append(str(token))
-        return torch.tensor([self.cache.index(str(token)) + self.offset])
+            self.cache[str(token)] = len(self.cache)
+        return torch.tensor([self.cache[str(token)] + self.offset])
 
 
 class OneHotEncoder(IndexEncoder):
```
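`on_finish` also changes strategy: instead of rewriting the whole index file, it opens the file in append mode and writes only the tokens added since loading. That relies on dict iteration following insertion order (guaranteed from Python 3.7, hence the assert), so the keys from `index_length_start` onward are exactly the new tokens. A minimal sketch of that slicing step:

```python
from itertools import islice

cache = {"C": 0, "N": 1, "O": 2}  # state right after loading the index file
index_length_start = len(cache)

cache["S"] = len(cache)  # tokens added later by encode()
cache["P"] = len(cache)

# Iterating a dict yields its keys in insertion order, so the slice
# [index_length_start:] contains exactly the tokens added after loading.
new_tokens = list(islice(cache, index_length_start, len(cache)))
print(new_tokens)  # ['S', 'P']
```

Appending rather than rewriting also keeps existing line numbers stable, which matters because each token's encoded value is its line index in the file.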