Skip to content

Commit

Permalink
setting default baseline to 0.5
Browse files Browse the repository at this point in the history
  • Loading branch information
urialon committed Jun 6, 2023
1 parent 10abf4c commit 0285e0e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
12 changes: 7 additions & 5 deletions code_bert_score/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def score(
baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
else:
baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()

all_preds = (all_preds - baselines) / (1 - baselines)
else:
print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
# print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
baselines = 0.5
all_preds = (all_preds - baselines) / (1 - baselines)

out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2], all_preds[..., 3] # P, R, F, F3

Expand Down Expand Up @@ -252,9 +252,11 @@ def plot_example(
baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
if os.path.isfile(baseline_path):
baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
sim = (sim - baselines[2].item()) / (1 - baselines[2].item())
baseline = baselines[2].item()
else:
print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
# print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
baseline = 0.5
sim = (sim - baseline) / (1 - baseline)

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
Expand Down
15 changes: 9 additions & 6 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,19 @@ def print_results(predictions, refs, pred_results):
with open('idf_dicts/java_idf.pkl', 'rb') as f:
java_idf = pickle.load(f)

# pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python")
# pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf)
# print_results(predictions, refs, pred_results)
pred_results = code_bert_score.score([''],['a'], sources=["a"], lang="python")
pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf)
print_results(predictions, refs, pred_results)

# print('When providing the context: "find the index of target in this.elements"')
# pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2)
# print_results(predictions, refs, pred_results)
print('When providing the context: "find the index of target in this.elements"')
pred_results = code_bert_score.score(cands=predictions, refs=refs, no_punc=True, lang='java', idf=java_idf, sources=['find the index of target in this.elements'] * 2)
print_results(predictions, refs, pred_results)


with open('idf_dicts/python_idf.pkl', 'rb') as f:
python_idf = pickle.load(f)
pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], no_punc=True, lang='python', idf=python_idf)
print(pred_results)

pred_results = code_bert_score.score(cands=['math.sqrt(x)'], refs=[['x ** 0.5']], rescale_with_baseline=True, lang='en')
print(pred_results)

0 comments on commit 0285e0e

Please sign in to comment.