Commit 9f974c9: small adjustments
lukas-blecher committed May 21, 2022
1 parent dbf75d9 commit 9f974c9
Showing 4 changed files with 33 additions and 29 deletions.
pix2tex/model/settings/config-vit.yaml (2 changes: 1 addition & 1 deletion)
@@ -29,7 +29,7 @@ max_seq_len: 512
 max_width: 672
 min_height: 32
 min_width: 32
-micro_batchsize: 64
+micro_batchsize: -1
 model_path: checkpoints_add
 name: pix2tex-vit
 num_layers: 4
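
A micro_batchsize of -1 is the sentinel for "no micro-batching": the gpu_memory_check helper added to pix2tex/utils/utils.py below then falls back to the full batchsize when probing memory. A minimal sketch of that resolution logic, reusing the key names from this commit (the standalone helper function is illustrative only):

from munch import Munch

def resolve_micro_batchsize(args: Munch) -> int:
    # Same selection as in gpu_memory_check: -1 (or a missing key) means
    # "use the full training batch size".
    return args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize

print(resolve_micro_batchsize(Munch(batchsize=64, micro_batchsize=-1)))  # 64
print(resolve_micro_batchsize(Munch(batchsize=64, micro_batchsize=16)))  # 16
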
pix2tex/model/settings/config.yaml (2 changes: 1 addition & 1 deletion)
@@ -6,7 +6,7 @@ backbone_layers:
 betas:
 - 0.9
 - 0.999
-batchsize: 10
+batchsize: 64
 bos_token: 1
 channels: 1
 data: dataset/data/train.pkl
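
Both YAML files are loaded into the Munch-based args object that parse_args (changed in pix2tex/utils/utils.py below) post-processes, so the new batchsize: 64 is what the trainer sees. A rough sketch of that loading path, assuming PyYAML and a config that defines the usual width/height keys:

import yaml
from munch import Munch
from pix2tex.utils import parse_args

with open('pix2tex/model/settings/config.yaml') as f:
    params = yaml.load(f, Loader=yaml.FullLoader)

args = parse_args(Munch(params))
print(args.batchsize)  # 64 after this commit
print(args.device)     # e.g. 'cuda:0' on a GPU machine, otherwise 'cpu'
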
pix2tex/train.py (30 changes: 8 additions & 22 deletions)
@@ -12,23 +12,7 @@
 from pix2tex.eval import evaluate
 from pix2tex.models import get_model
 # from pix2tex.utils import *
-from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler
-
-
-def gpu_memory_check(model, args):
-    # check if largest batch can be handled by system
-    try:
-        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
-        for _ in range(5):
-            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
-            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
-            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
-            loss.sum().backward()
-    except RuntimeError:
-        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
-    model.zero_grad()
-    with torch.cuda.device(args.device):torch.cuda.empty_cache()
-    del im, seq
+from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check
 
 
 def train(args):
@@ -40,13 +24,15 @@ def train(args):
     valdataloader.update(**valargs)
     device = args.device
     model = get_model(args)
-    gpu_memory_check(model, args)
-    if args.load_chkpt is not None:
-        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
+    if torch.cuda.is_available() and not args.no_cuda:
+        gpu_memory_check(model, args)
     max_bleu, max_token_acc = 0, 0
     out_path = os.path.join(args.model_path, args.name)
     os.makedirs(out_path, exist_ok=True)
+
+    if args.load_chkpt is not None:
+        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
 
     def save_models(e, step=0):
         torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
         yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
@@ -88,9 +74,9 @@ def save_models(e, step=0):
                 wandb.log({'train/epoch': e+1})
     except KeyboardInterrupt:
         if e >= 2:
-            save_models(e)
+            save_models(e, step=i)
         raise KeyboardInterrupt
-    save_models(e)
+    save_models(e, step=len(dataloader))
 
 
 if __name__ == '__main__':
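
Passing step through to save_models means the checkpoint filename pattern '%s_e%02d_step%02d.pth' now records where in the epoch training stopped: step=i when interrupted with Ctrl+C, step=len(dataloader) at the end of an epoch. A quick illustration with made-up values:

name, e = 'pix2tex-vit', 0
print('%s_e%02d_step%02d.pth' % (name, e + 1, 42))   # pix2tex-vit_e01_step42.pth  (interrupted at step 42)
print('%s_e%02d_step%02d.pth' % (name, e + 1, 937))  # pix2tex-vit_e01_step937.pth (epoch of 937 batches completed)
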
pix2tex/utils/utils.py (28 changes: 23 additions & 5 deletions)
@@ -52,26 +52,44 @@ def seed_everything(seed: int):
 def parse_args(args, **kwargs) -> Munch:
     args = Munch({'epoch': 0}, **args)
     kwargs = Munch({'no_cuda': False, 'debug': False}, **kwargs)
+    args.update(kwargs)
     args.wandb = not kwargs.debug and not args.debug
-    args.device = get_device(args, kwargs)
+    args.device = get_device(args, kwargs.no_cuda)
     args.max_dimensions = [args.max_width, args.max_height]
     args.min_dimensions = [args.get('min_width', 32), args.get('min_height', 32)]
     if 'decoder_args' not in args or args.decoder_args is None:
         args.decoder_args = {}
     return args
 
 
-def get_device(args, kwargs):
+def get_device(args, no_cuda=False):
     device = 'cpu'
     available_gpus = torch.cuda.device_count()
-    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else range(available_gpus)
-    if available_gpus > 0 and not kwargs.no_cuda:
+    args.gpu_devices = args.gpu_devices if args.get('gpu_devices', False) else list(range(available_gpus))
+    if available_gpus > 0 and not no_cuda:
         device = 'cuda:%d' % args.gpu_devices[0] if args.gpu_devices else 0
         assert available_gpus >= len(args.gpu_devices), "Available %d gpu, but specified gpu %s." % (available_gpus, ','.join(map(str, args.gpu_devices)))
-        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))),','.join(map(str, args.gpu_devices)))
+        assert max(args.gpu_devices) < available_gpus, "legal gpu_devices should in [%s], received [%s]" % (','.join(map(str, range(available_gpus))), ','.join(map(str, args.gpu_devices)))
     return device
 
 
+def gpu_memory_check(model, args):
+    # check if largest batch can be handled by system
+    try:
+        batchsize = args.batchsize if args.get('micro_batchsize', -1) == -1 else args.micro_batchsize
+        for _ in range(5):
+            im = torch.empty(batchsize, args.channels, args.max_height, args.min_height, device=args.device).float()
+            seq = torch.randint(0, args.num_tokens, (batchsize, args.max_seq_len), device=args.device).long()
+            loss = model.data_parallel(im, device_ids=args.gpu_devices, tgt_seq=seq)
+            loss.sum().backward()
+    except RuntimeError:
+        raise RuntimeError("The system cannot handle a batch size of %i for the maximum image size (%i, %i). Try to use a smaller micro batchsize." % (batchsize, args.max_height, args.max_width))
+    model.zero_grad()
+    with torch.cuda.device(args.device):
+        torch.cuda.empty_cache()
+    del im, seq
+
+
 def token2str(tokens, tokenizer) -> list:
     if len(tokens.shape) == 1:
         tokens = tokens[None, :]
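
After this change get_device takes a plain no_cuda flag instead of the whole kwargs Munch, and gpu_memory_check lives in utils so other entry points can reuse it. A small usage sketch; the direct import from pix2tex.utils.utils and the empty Munch are illustrative:

from munch import Munch
from pix2tex.utils.utils import get_device

args = Munch()
print(get_device(args, no_cuda=True))   # 'cpu' - the boolean flag forces CPU even when GPUs exist
print(args.gpu_devices)                 # materialized as a list now, e.g. [] or [0, 1]
print(get_device(args, no_cuda=False))  # 'cuda:0' if at least one GPU is visible, otherwise 'cpu'
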
