compatible with torch.utils.data.DataLoader #154

Open

wants to merge 7 commits into main
Changes from 5 commits
52 changes: 42 additions & 10 deletions pix2tex/dataset/dataset.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import IterableDataset, DataLoader
 import numpy as np
 import imagesize
 import logging
@@ -15,10 +16,10 @@
 
 from pix2tex.utils.utils import in_model_path
 from pix2tex.dataset.transforms import train_transform, test_transform
+import math
 
 
-
-class Im2LatexDataset:
+class Im2LatexDataset(IterableDataset):
     keep_smaller_batches = False
     shuffle = True
     batchsize = 16
@@ -33,6 +34,7 @@ class Im2LatexDataset:
     eos_token_id = 2
     transform = train_transform
     data = defaultdict(lambda: [])
+    permutation = None
 
     def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, batchsize=16, max_seq_len=1024,
                  max_dimensions=(1024, 512), min_dimensions=(32, 32), pad=False, keep_smaller_batches=False, test=False):
@@ -42,7 +44,7 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
             equations (str, optional): Path to equations. Defaults to None.
             images (str, optional): Directory where images are saved. Defaults to None.
             tokenizer (str, optional): Path to saved tokenizer. Defaults to None.
-            shuffle (bool, opitonal): Defaults to True.
+            shuffle (bool, optional): Defaults to True.
             batchsize (int, optional): Defaults to 16.
             max_seq_len (int, optional): Defaults to 1024.
             max_dimensions (tuple(int, int), optional): Maximal dimensions the model can handle
@@ -75,32 +77,39 @@ def __init__(self, equations=None, images=None, tokenizer=None, shuffle=True, ba
                     self.data[(width, height)].append((eqs[self.indices[i]], im))
             except KeyboardInterrupt:
                 pass
+        # formula&image pairs grouped by image size
         self.data = dict(self.data)
         self._get_size()
 
+        self._shuffle()
         iter(self)
 
     def __len__(self):
-        return self.size
+        return self.size  # total number of batches given the batchsize
 
     def __iter__(self):
         self.i = 0
         self.transform = test_transform if self.test else train_transform
         self.pairs = []
         for k in self.data:
             info = np.array(self.data[k], dtype=object)
-            p = torch.randperm(len(info)) if self.shuffle else torch.arange(len(info))
             for i in range(0, len(info), self.batchsize):
-                batch = info[p[i:i+self.batchsize]]
+                batch = info[i:i+self.batchsize]
                 if len(batch.shape) == 1:
                     batch = batch[None, :]
                 if len(batch) < self.batchsize and not self.keep_smaller_batches:
                     continue
                 self.pairs.append(batch)
-        if self.shuffle:
-            self.pairs = np.random.permutation(np.array(self.pairs, dtype=object))
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            # configure the dataset to only process the split workload
+            per_worker = int(math.ceil(self.size/float(worker_info.num_workers)))
+            worker_id = worker_info.id
+            self.start = worker_id * per_worker
+            self.end = min(self.start + per_worker, self.size)
         else:
-            self.pairs = np.array(self.pairs, dtype=object)
+            self.start, self.end = 0, self.size
+
+        self.pairs = np.array(self.pairs, dtype=object)[self.permutation[self.start:self.end]]
         self.size = len(self.pairs)
         return self

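Note on the hunk above: the get_worker_info() branch is the standard sharding recipe for torch.utils.data.IterableDataset. Without it, every worker process would run the full __iter__ and serve each batch num_workers times. A minimal, self-contained sketch of the same pattern (the RangeDataset name and the 0..10 range are illustrative, not part of this PR):

import math
import torch
from torch.utils.data import DataLoader, IterableDataset


class RangeDataset(IterableDataset):
    """Yields the integers in [start, end); each worker takes a disjoint slice."""

    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        start, end = self.start, self.end
        info = torch.utils.data.get_worker_info()
        if info is not None:  # running inside a worker process
            per_worker = int(math.ceil((end - start) / float(info.num_workers)))
            start = self.start + info.id * per_worker
            end = min(start + per_worker, self.end)
        return iter(range(start, end))


if __name__ == '__main__':
    # batch_size=None disables automatic batching, mirroring the Dataloader
    # wrapper added later in this file; every item appears exactly once
    # because each worker yields only its own slice.
    loader = DataLoader(RangeDataset(0, 10), num_workers=2, batch_size=None)
    print(sorted(loader))  # [0, 1, 2, ..., 9]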
@@ -121,6 +130,8 @@ def prepare_data(self, batch):
         """
 
         eqs, ims = batch.T
+        # for im in ims:
+        #     print(im)
         tok = self.tokenizer(list(eqs), return_token_type_ids=False)
         # pad with bos and eos token
         for k, p in zip(tok, [[self.bos_token_id, self.eos_token_id], [1, 1]]):
@@ -155,6 +166,15 @@ def _get_size(self):
         for k in self.data:
             div, mod = divmod(len(self.data[k]), self.batchsize)
             self.size += div  # + (1 if mod > 0 else 0)
+        if self.permutation is None or len(self.permutation) != self.size:
+            self._shuffle()
+
+    def _shuffle(self):
+        if self.shuffle:
+            self.permutation = np.random.permutation(self.size)
+        else:
+            self.permutation = np.arange(self.size)
+        return self
 
     def load(self, filename, args=[]):
         """returns a pickled version of a dataset
@@ -169,6 +189,7 @@ def load(self, filename, args=[]):
                 filename = os.path.realpath(tempf)
         with open(filename, 'rb') as file:
             x = pickle.load(file)
+        x._get_size()
         return x
 
     def combine(self, x):
@@ -219,6 +240,17 @@ def update(self, **kwargs):
         iter(self)
 
 
+class Dataloader(DataLoader):
+    def __init__(self, dataset: Im2LatexDataset, batch_size=1, shuffle=False, *args, **kwargs):
+        self.dataset = dataset
+        self.dataset.update(batchsize=batch_size, shuffle=shuffle, *args, **kwargs)
+        super().__init__(self.dataset, *args, shuffle=False, batch_size=None, **kwargs)
+
+    def __iter__(self):
+        self.dataset._shuffle()
+        return super().__iter__()
+
 
 def generate_tokenizer(equations, output, vocab_size):
     from tokenizers import Tokenizer, pre_tokenizers
     from tokenizers.models import BPE
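Two details of the Dataloader wrapper deserve a note. It passes batch_size=None to the parent torch.utils.data.DataLoader because Im2LatexDataset already yields fully collated batches of batchsize pairs, and automatic batching would stack those pre-built batches a second time. It also forces shuffle=False in the super().__init__ call, since PyTorch rejects shuffle=True for iterable-style datasets; shuffling is instead delegated to the dataset's own permutation via _shuffle(). A sketch of the intended usage (the pickle path is a placeholder):

from pix2tex.dataset.dataset import Dataloader, Im2LatexDataset

dataset = Im2LatexDataset().load('dataset/train.pkl')  # placeholder path
loader = Dataloader(dataset, batch_size=16, shuffle=True, num_workers=4)
for seq, im in loader:
    pass  # seq: tokenized LaTeX batch, im: batch of image tensors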
6 changes: 3 additions & 3 deletions pix2tex/eval.py
@@ -1,4 +1,4 @@
-from pix2tex.dataset.dataset import Im2LatexDataset
+from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
 import argparse
 import logging
 import yaml
@@ -28,12 +28,12 @@ def detokenize(tokens, tokenizer):
 
 
 @torch.no_grad()
-def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: int = None, name: str = 'test'):
+def evaluate(model: Model, dataset: Dataloader, args: Munch, num_batches: int = None, name: str = 'test'):
     """evaluates the model. Returns bleu score on the dataset
 
     Args:
         model (torch.nn.Module): the model
-        dataset (Im2LatexDataset): test dataset
+        dataset (Dataloader): test dataset
         args (Munch): arguments
         num_batches (int): How many batches to evaluate on. Defaults to None (all batches).
         name (str, optional): name of the test e.g. val or test for wandb. Defaults to 'test'.
1 change: 1 addition & 0 deletions pix2tex/model/settings/config-vit.yaml
@@ -1,4 +1,5 @@
 gpu_devices: null #[0,1,2,3,4,5,6,7]
+num_workers: 0
 betas:
 - 0.9
 - 0.999
1 change: 1 addition & 0 deletions pix2tex/model/settings/config.yaml
@@ -1,4 +1,5 @@
 gpu_devices: null #[0,1,2,3,4,5,6,7]
+num_workers: 0
 backbone_layers:
 - 2
 - 3
4 changes: 4 additions & 0 deletions pix2tex/model/settings/debug.yaml
@@ -65,3 +65,7 @@ pad: False
 pad_token: 0
 bos_token: 1
 eos_token: 2
+
+#devices(GPU&CPU)
+num_workers: 0
+gpu_devices: null #[0,1,2,3,4,5,6,7]
20 changes: 10 additions & 10 deletions pix2tex/train.py
@@ -1,4 +1,4 @@
-from pix2tex.dataset.dataset import Im2LatexDataset
+from pix2tex.dataset.dataset import Im2LatexDataset, Dataloader
 import os
 import argparse
 import logging
@@ -16,12 +16,12 @@
 
 
 def train(args):
-    dataloader = Im2LatexDataset().load(args.data)
-    dataloader.update(**args, test=False)
-    valdataloader = Im2LatexDataset().load(args.valdata)
+    train_dataset = Im2LatexDataset().load(args.data)
+    train_dataloader = Dataloader(train_dataset, **args, test=False)
+    val_dataset = Im2LatexDataset().load(args.valdata)
     valargs = args.copy()
     valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
-    valdataloader.update(**valargs)
+    val_dataloader = Dataloader(val_dataset, **valargs)
     device = args.device
     model = get_model(args)
     if torch.cuda.is_available() and not args.no_cuda:
@@ -47,7 +47,7 @@ def save_models(e, step=0):
     try:
         for e in range(args.epoch, args.epochs):
             args.epoch = e
-            dset = tqdm(iter(dataloader))
+            dset = tqdm(iter(train_dataloader))
             for i, (seq, im) in enumerate(dset):
                 if seq is not None and im is not None:
                     opt.zero_grad()
@@ -63,20 +63,20 @@ def save_models(e, step=0):
                 dset.set_description('Loss: %.4f' % total_loss)
                 if args.wandb:
                     wandb.log({'train/loss': total_loss})
-                if (i+1+len(dataloader)*e) % args.sample_freq == 0:
-                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
+                if (i+1+len(train_dataloader)*e) % args.sample_freq == 0:
+                    bleu_score, edit_distance, token_accuracy = evaluate(model, val_dataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
                     if bleu_score > max_bleu and token_accuracy > max_token_acc:
                         max_bleu, max_token_acc = bleu_score, token_accuracy
                         save_models(e, step=i)
             if (e+1) % args.save_freq == 0:
-                save_models(e, step=len(dataloader))
+                save_models(e, step=len(train_dataloader))
             if args.wandb:
                 wandb.log({'train/epoch': e+1})
     except KeyboardInterrupt:
         if e >= 2:
             save_models(e, step=i)
         raise KeyboardInterrupt
-    save_models(e, step=len(dataloader))
+    save_models(e, step=len(train_dataloader))
 
 
 if __name__ == '__main__':
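A behavioural note on the loop above: len(train_dataloader) falls through to Im2LatexDataset.__len__ (the number of batches, since batch_size=None disables the DataLoader's own batching), so the sample_freq arithmetic is unchanged from the pre-PR code. And because Dataloader.__iter__ calls dataset._shuffle() before delegating to DataLoader.__iter__, each epoch draws a fresh permutation in the parent process; with the default non-persistent workers, the reshuffled dataset is then copied into each worker when the epoch starts. A sketch of the resulting epoch flow (train_step and num_epochs are placeholders, not names from this PR):

loader = Dataloader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
for epoch in range(num_epochs):
    # iter(loader) redraws dataset.permutation in the parent process, then
    # starts workers, each of which receives a copy of the freshly shuffled
    # dataset and iterates its own disjoint slice of the batch permutation.
    for seq, im in loader:
        train_step(seq, im)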
1 change: 1 addition & 0 deletions pix2tex/utils/utils.py
@@ -55,6 +55,7 @@ def parse_args(args, **kwargs) -> Munch:
     args.update(kwargs)
     args.wandb = not kwargs.debug and not args.debug
     args.device = get_device(args, kwargs.no_cuda)
+    args.num_workers = args.get('num_workers', 0)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
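The num_workers fallback added to parse_args keeps older config files working: a YAML without the key defaults to 0, i.e. single-process loading, which matches the num_workers: 0 defaults added to the bundled configs above.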