Merge pull request #26 from jabirshabbir/master
Pull request for moverscore_v2 GPU modification
andyweizhao authored Jan 18, 2023
2 parents 9c362cc + ad8dd8c commit 0459a3b
Showing 1 changed file with 14 additions and 9 deletions.
moverscore_v2.py: 23 changes (14 additions & 9 deletions)
@@ -15,14 +15,16 @@

 from transformers import AutoTokenizer, AutoModel
 
+device = 'cuda'
+
 if os.environ.get('MOVERSCORE_MODEL'):
     model_name = os.environ.get('MOVERSCORE_MODEL')
 else:
     model_name = 'distilbert-base-uncased'
 tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=True)
 model = AutoModel.from_pretrained(model_name, output_hidden_states=True, output_attentions=True)
 model.eval()
 
+model.to(device)
 
 def truncate(tokens):
     if len(tokens) > tokenizer.model_max_length - 2:
@@ -71,7 +73,7 @@ def bert_encode(model, x, attention_mask):
 # stop_words = set(f.read().strip().split(' '))
 
 def collate_idf(arr, tokenize, numericalize, idf_dict,
-                pad="[PAD]"):
+                pad="[PAD]",device='cuda:0'):
 
     tokens = [["[CLS]"]+truncate(tokenize(a))+["[SEP]"] for a in arr]
     arr = [numericalize(a) for a in tokens]
@@ -82,15 +84,18 @@ def collate_idf(arr, tokenize, numericalize, idf_dict,

     padded, lens, mask = padding(arr, pad_token, dtype=torch.long)
     padded_idf, _, _ = padding(idf_weights, pad_token, dtype=torch.float)
+    padded = padded.to(device=device)
+    mask = mask.to(device=device)
+    lens = lens.to(device=device)
 
     return padded, padded_idf, lens, mask, tokens

 def get_bert_embedding(all_sens, model, tokenizer, idf_dict,
-                       batch_size=-1):
+                       batch_size=-1,device='cuda:0'):
 
     padded_sens, padded_idf, lens, mask, tokens = collate_idf(all_sens,
         tokenizer.tokenize, tokenizer.convert_tokens_to_ids,
-        idf_dict)
+        idf_dict,device=device)
 
     if batch_size == -1: batch_size = len(all_sens)
 
@@ -120,14 +125,14 @@ def batched_cdist_l2(x1, x2):
     ).add_(x1_norm).clamp_min_(1e-30).sqrt_()
     return res
 
-def word_mover_score(refs, hyps, idf_dict_ref, idf_dict_hyp, stop_words=[], n_gram=1, remove_subwords = True, batch_size=256):
+def word_mover_score(refs, hyps, idf_dict_ref, idf_dict_hyp, stop_words=[], n_gram=1, remove_subwords = True, batch_size=256,device='cuda:0'):
     preds = []
     for batch_start in range(0, len(refs), batch_size):
         batch_refs = refs[batch_start:batch_start+batch_size]
         batch_hyps = hyps[batch_start:batch_start+batch_size]
 
-        ref_embedding, ref_lens, ref_masks, ref_idf, ref_tokens = get_bert_embedding(batch_refs, model, tokenizer, idf_dict_ref)
-        hyp_embedding, hyp_lens, hyp_masks, hyp_idf, hyp_tokens = get_bert_embedding(batch_hyps, model, tokenizer, idf_dict_hyp)
+        ref_embedding, ref_lens, ref_masks, ref_idf, ref_tokens = get_bert_embedding(batch_refs, model, tokenizer, idf_dict_ref,device=device)
+        hyp_embedding, hyp_lens, hyp_masks, hyp_idf, hyp_tokens = get_bert_embedding(batch_hyps, model, tokenizer, idf_dict_hyp,device=device)
 
         ref_embedding = ref_embedding[-1]
         hyp_embedding = hyp_embedding[-1]
@@ -177,8 +182,8 @@ def plot_example(is_flow, reference, translation, device='cuda:0'):
     idf_dict_ref = defaultdict(lambda: 1.)
     idf_dict_hyp = defaultdict(lambda: 1.)
 
-    ref_embedding, ref_lens, ref_masks, ref_idf, ref_tokens = get_bert_embedding([reference], model, tokenizer, idf_dict_ref)
-    hyp_embedding, hyp_lens, hyp_masks, hyp_idf, hyp_tokens = get_bert_embedding([translation], model, tokenizer, idf_dict_hyp)
+    ref_embedding, ref_lens, ref_masks, ref_idf, ref_tokens = get_bert_embedding([reference], model, tokenizer, idf_dict_ref,device=device)
+    hyp_embedding, hyp_lens, hyp_masks, hyp_idf, hyp_tokens = get_bert_embedding([translation], model, tokenizer, idf_dict_hyp,device=device)
 
     ref_embedding = ref_embedding[-1]
     hyp_embedding = hyp_embedding[-1]
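For context, a minimal usage sketch of the patched API follows. The module, function, and parameter names come from the diff above; the example sentences, the uniform IDF dictionaries, and the CPU fallback are illustrative assumptions, not part of the commit.

# Hypothetical usage sketch, assuming the moverscore_v2.py from this commit is importable.
from collections import defaultdict

import torch

from moverscore_v2 import word_mover_score

# Illustrative inputs; any parallel lists of reference/hypothesis strings work.
refs = ["the cat sat on the mat"]
hyps = ["a cat was sitting on the mat"]

# Uniform IDF weights, mirroring the defaultdicts used in plot_example above.
idf_dict_ref = defaultdict(lambda: 1.)
idf_dict_hyp = defaultdict(lambda: 1.)

# Pick a target for the device keyword introduced by this PR.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

scores = word_mover_score(refs, hyps, idf_dict_ref, idf_dict_hyp,
                          stop_words=[], n_gram=1, remove_subwords=True,
                          batch_size=256, device=device)
print(scores)  # one score per reference/hypothesis pair

Note that the module still runs device = 'cuda' and model.to(device) at import time, so the new device keyword only relocates the input tensors; on a machine without CUDA the import itself would still fail unless that top-level assignment is changed as well.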
