
Commit

Prerelease (LLMLingua): fix the chunk issue and prepare for v0.2.2 (#130)

Co-authored-by: Qianhui Wu <[email protected]>
Co-authored-by: panzs <[email protected]>
Co-authored-by: Xufang Luo <[email protected]>
Co-authored-by: Yuqing Yang <[email protected]>
5 people authored Apr 9, 2024
1 parent 309392a commit a411a3f
Showing 4 changed files with 109 additions and 84 deletions.
169 changes: 92 additions & 77 deletions experiments/llmlingua2/evaluation/eval_meetingbank_qa.py
@@ -32,89 +32,104 @@
 args = parser.parse_args()
 os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
-data = json.load(open(args.load_prompt_from))
-data = data.values() if isinstance(data, dict) else data
-
-print(f"num data: {len(data)}")
-
-model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
-
-results = defaultdict(dict)
-results_list = defaultdict(list)
-if os.path.exists(args.save_path):
-    prev_results = json.load(open(args.save_path))
-    results.update(prev_results)
-if os.path.exists(
-    os.path.join(
-        os.path.dirname(args.save_path),
-        os.path.basename(args.save_path).replace("answer", "answer_list"),
-    )
-):
-    results_list = json.load(
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
-            )
-        )
-    )
-
-prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
-for sample in tqdm(data):
-    sample_idx = int(sample["idx"])
-    if sample_idx in results or str(sample_idx) in results:
-        print(f"{sample_idx}-th already processed.")
-        continue
-    if args.num_sample > 0 and int(sample_idx) > args.num_sample:
-        break
-    transcript = sample[args.load_key]
-    token_ids = tokenizer.encode(transcript)
-    if len(token_ids) > args.n_max_token - args.n_max_token_ans:
-        transcript = tokenizer.decode(
-            token_ids[: args.n_max_token - args.n_max_token_ans]
-        )
-    qa_list = sample["QA_pairs"]
-    q_list = []
-    a_list = []
-    a_list_model = []
-    for qa in qa_list:
-        q = qa["question"]
-        a = qa["answer"]
-        query = prompt.format(transcript=transcript, question=q)
-        answer = query_llm(
-            query,
-            model,
-            args.model_name_or_path,
-            args.n_max_token_ans,
-            tokenizer=tokenizer,
-        )
-        q_list.append(q)
-        a_list.append(a)
-        a_list_model.append(answer)
-
-    results[sample_idx]["transcript"] = transcript
-    results[sample_idx]["questions"] = q_list[:]
-    results[sample_idx]["answers"] = a_list[:]
-    results[sample_idx]["model_answers"] = a_list_model[:]
-
-    results_list["questions"].extend(q_list[:])
-    results_list["answers"].extend(a_list[:])
-    results_list["model_answers"].extend(a_list_model[:])
-
-    json.dump(results, open(args.save_path, "w"), indent=4)
-    json.dump(
-        results_list,
-        open(
-            os.path.join(
-                os.path.dirname(args.save_path),
-                os.path.basename(args.save_path).replace("answer", "answer_list"),
-            ),
-            "w",
-        ),
-        indent=4,
-    )
-
-score_dict = evaluate_with_gt(results_list["answers"], results_list["model_answers"])
+
+
+def predict():
+    data = json.load(open(args.load_prompt_from))
+    data = data.values() if isinstance(data, dict) else data
+
+    print(f"num data: {len(data)}")
+
+    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
+
+    results = defaultdict(dict)
+    results_list = defaultdict(list)
+    if os.path.exists(args.save_path):
+        prev_results = json.load(open(args.save_path))
+        results.update(prev_results)
+    if os.path.exists(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
+        )
+    ):
+        results_list = json.load(
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                )
+            )
+        )
+
+    prompt = "Write a high-quality answer for the given question using the provided meeting transcript (which may be compressed).\n{transcript}\nQuestion:{question}\nAnswer:"
+    for sample in tqdm(data):
+        sample_idx = int(sample["idx"])
+        if sample_idx in results or str(sample_idx) in results:
+            print(f"{sample_idx}-th already processed.")
+            continue
+        if args.num_sample > 0 and int(sample_idx) > args.num_sample:
+            break
+        transcript = sample[args.load_key]
+        token_ids = tokenizer.encode(transcript)
+        if len(token_ids) > args.n_max_token - args.n_max_token_ans:
+            transcript = tokenizer.decode(
+                token_ids[: args.n_max_token - args.n_max_token_ans]
+            )
+        qa_list = sample["QA_pairs"]
+        q_list = []
+        a_list = []
+        a_list_model = []
+        for qa in qa_list:
+            q = qa["question"]
+            a = qa["answer"]
+            query = prompt.format(transcript=transcript, question=q)
+            answer = query_llm(
+                query,
+                model,
+                args.model_name_or_path,
+                args.n_max_token_ans,
+                tokenizer=tokenizer,
+            )
+            q_list.append(q)
+            a_list.append(a)
+            a_list_model.append(answer)
+
+        results[sample_idx]["transcript"] = transcript
+        results[sample_idx]["questions"] = q_list[:]
+        results[sample_idx]["answers"] = a_list[:]
+        results[sample_idx]["model_answers"] = a_list_model[:]
+
+        results_list["questions"].extend(q_list[:])
+        results_list["answers"].extend(a_list[:])
+        results_list["model_answers"].extend(a_list_model[:])
+
+        json.dump(results, open(args.save_path, "w"), indent=4)
+        json.dump(
+            results_list,
+            open(
+                os.path.join(
+                    os.path.dirname(args.save_path),
+                    os.path.basename(args.save_path).replace("answer", "answer_list"),
+                ),
+                "w",
+            ),
+            indent=4,
+        )
+
+
+predict()
+results_list = json.load(
+    open(
+        os.path.join(
+            os.path.dirname(args.save_path),
+            os.path.basename(args.save_path).replace("answer", "answer_list"),
+        )
+    )
+)
+for i, ans in enumerate(results_list["answers"]):
+    results_list["answers"][i] = [results_list["answers"][i]]
+score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])
 json.dump(
     score_dict,
     open(
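
Editor's note: the behavioral change in this file is at the end of the script — the prediction loop moves into predict(), the saved answer_list file is reloaded, each gold answer is wrapped in a one-element list, and evaluate_with_gt now receives (predictions, ground-truth lists) rather than the reversed order. A minimal sketch of that reshaping (not part of the commit; values are hypothetical toy data, not real MeetingBank outputs):

results_list = {
    "questions": ["What did the council approve?"],
    "answers": ["the budget amendment"],
    "model_answers": ["The council approved the budget amendment."],
}
# Each gold answer becomes a one-element list, because the scorer expects a
# list of acceptable ground truths per prediction.
for i, ans in enumerate(results_list["answers"]):
    results_list["answers"][i] = [results_list["answers"][i]]
# The scorer is then called as (predictions, ground_truth_lists):
# score_dict = evaluate_with_gt(results_list["model_answers"], results_list["answers"])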
16 changes: 12 additions & 4 deletions experiments/llmlingua2/evaluation/metrics.py
@@ -156,6 +156,16 @@ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
     return f1_score(prediction_tokens, ground_truth_tokens)
 
 
+def qa_score(prediction, ground_truths):
+    normalized_prediction = normalize_answer2(prediction)
+
+    for ground_truth in ground_truths:
+        normalized_ground_truth = normalize_answer2(ground_truth)
+        if normalized_ground_truth.lower() in normalized_prediction.lower():
+            return 1.0
+    return 0.0
+
+
 import regex
 
 
@@ -207,12 +217,10 @@ def eval_qa_f1_score(pred, ground_truths):
         pred_list = pred_list_truncated
 
     metrics = {
-        "qa_f1_score": 0.0,
-        "best_subspan_em": 0.0,
+        "qa_score": 0.0,
     }
     for pred, gts in zip(pred_list, gt_list):
-        metrics["qa_f1_score"] += eval_qa_f1_score(pred, gts)
-        metrics["best_subspan_em"] += best_subspan_em(pred, gts)
+        metrics["qa_score"] += qa_score(pred, gts)
     # average
     for metric_name, score in metrics.items():
         metrics[metric_name] = score * 100 / len(pred_list)
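
Editor's note: the new qa_score replaces the qa_f1_score / best_subspan_em pair with a containment-style accuracy — 1.0 if any normalized ground truth is a substring of the normalized prediction, else 0.0. A rough self-contained sketch of that behavior; the _normalize helper below is only an assumed stand-in for normalize_answer2 in metrics.py and may differ from it:

import re
import string


def _normalize(text):
    # Hypothetical approximation of normalize_answer2: lowercase, drop
    # punctuation, collapse whitespace. The real helper may behave differently.
    text = "".join(ch for ch in text.lower() if ch not in string.punctuation)
    return re.sub(r"\s+", " ", text).strip()


def qa_score(prediction, ground_truths):
    normalized_prediction = _normalize(prediction)
    for ground_truth in ground_truths:
        if _normalize(ground_truth) in normalized_prediction:
            return 1.0
    return 0.0


# Substring containment after normalization counts as a hit:
print(qa_score("The council approved the Budget Amendment.", ["budget amendment"]))  # 1.0
print(qa_score("The motion was tabled.", ["budget amendment"]))  # 0.0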
6 changes: 4 additions & 2 deletions llmlingua/prompt_compressor.py
@@ -2407,8 +2407,10 @@ def split_string_to_words(input_string):
         keep_words = []
         word_labels = []
         assert len(words) == len(word_probs)
-        for word, word_porb in zip(words, word_probs):
-            if word_porb > threshold:
+        for word, word_prob in zip(words, word_probs):
+            if word_prob > threshold or (
+                threshold == 1.0 and word_prob == threshold
+            ):
                 if (
                     drop_consecutive
                     and word in force_tokens
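
Editor's note: besides renaming the misspelled word_porb, the extra clause handles an edge case — with a strict word_prob > threshold test, a threshold of exactly 1.0 discards every word, including tokens the classifier scores at probability 1.0; the added condition keeps those words. A small illustration of the boundary condition (toy probabilities, not real classifier output):

words = ["Keep", "this", "one"]
word_probs = [1.0, 0.4, 1.0]
threshold = 1.0

# Old behavior: strict comparison drops everything when threshold == 1.0.
kept_old = [w for w, p in zip(words, word_probs) if p > threshold]
# New behavior: words scored exactly 1.0 survive a threshold of 1.0.
kept_new = [
    w
    for w, p in zip(words, word_probs)
    if p > threshold or (threshold == 1.0 and p == threshold)
]
print(kept_old)  # []
print(kept_new)  # ['Keep', 'one']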
2 changes: 1 addition & 1 deletion llmlingua/version.py
@@ -5,7 +5,7 @@
 _MINOR = "2"
 # On master and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "1"
+_PATCH = "2"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
