From 51bbb5bdc5771f03313d01b93d66019e74068aaa Mon Sep 17 00:00:00 2001 From: yuhaowu Date: Mon, 20 Nov 2023 17:39:05 +0800 Subject: [PATCH] rm numpy.float type for numpy >= 1.24 --- moverscore_v2.py | 12 ++++++------ requirements.txt | 1 + 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/moverscore_v2.py b/moverscore_v2.py index 52fd426..56af7ba 100644 --- a/moverscore_v2.py +++ b/moverscore_v2.py @@ -159,8 +159,8 @@ def word_mover_score(refs, hyps, idf_dict_ref, idf_dict_hyp, stop_words=[], n_gr distance_matrix = batched_cdist_l2(raw, raw).double().cpu().numpy() for i in range(batch_size): - c1 = np.zeros(raw.shape[1], dtype=np.float) - c2 = np.zeros(raw.shape[1], dtype=np.float) + c1 = np.zeros(raw.shape[1], dtype=float) + c2 = np.zeros(raw.shape[1], dtype=float) c1[:len(ref_idf[i])] = ref_idf[i] c2[len(ref_idf[i]):] = hyp_idf[i] @@ -169,7 +169,7 @@ def word_mover_score(refs, hyps, idf_dict_ref, idf_dict_hyp, stop_words=[], n_gr dst = distance_matrix[i] _, flow = emd_with_flow(c1, c2, dst) - flow = np.array(flow, dtype=np.float32) + flow = np.array(flow, dtype=float) score = 1./(1. + np.sum(flow * dst))#1 - np.sum(flow * dst) preds.append(score) @@ -198,8 +198,8 @@ def plot_example(is_flow, reference, translation, device='cuda:0'): i = 0 - c1 = np.zeros(raw.shape[1], dtype=np.float) - c2 = np.zeros(raw.shape[1], dtype=np.float) + c1 = np.zeros(raw.shape[1], dtype=float) + c2 = np.zeros(raw.shape[1], dtype=float) c1[:len(ref_idf[i])] = ref_idf[i] c2[len(ref_idf[i]):] = hyp_idf[i] @@ -210,7 +210,7 @@ def plot_example(is_flow, reference, translation, device='cuda:0'): if is_flow: _, flow = emd_with_flow(c1, c2, dst) - new_flow = np.array(flow, dtype=np.float32) + new_flow = np.array(flow, dtype=float) res = new_flow[:len(ref_tokens[i]), len(ref_idf[i]): (len(ref_idf[i])+len(hyp_tokens[i]))] else: res = 1./(1. + dst[:len(ref_tokens[i]), len(ref_idf[i]): (len(ref_idf[i])+len(hyp_tokens[i]))]) diff --git a/requirements.txt b/requirements.txt index cd17c28..4ff2e43 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch>=1.0.0 pyemd==0.5.1 pytorch-transformers>=1.1.0 +numpy>=1.24