Skip to content

Commit ad06f8e

Browse files
authored
refactor(examples) Update FedRAG example (#5157)
1 parent 7e18ede commit ad06f8e

File tree

4 files changed

+35
-21
lines changed

4 files changed

+35
-21
lines changed

examples/fedrag/fedrag/client_app.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,11 @@
11
"""fedrag: A Flower Federated RAG app."""
22

3-
import os
4-
53
from flwr.client import ClientApp
64
from flwr.common import ConfigRecord, Context, Message, RecordDict
75

86
from fedrag.retriever import Retriever
97

8+
109
# Flower ClientApp
1110
app = ClientApp()
1211

examples/fedrag/fedrag/llm_querier.py

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,10 @@
55

66
from transformers import AutoTokenizer, AutoModelForCausalLM
77

8+
import os
9+
10+
os.environ["TOKENIZERS_PARALLELISM"] = "false" # to avoid deadlocks during tokenization
11+
812

913
class LLMQuerier:
1014

@@ -16,34 +20,47 @@ def __init__(self, model_name, use_gpu=False):
1620
self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
1721
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
1822

23+
# set pad token if empty
24+
if self.tokenizer.pad_token_id is None:
25+
self.tokenizer.pad_token = self.tokenizer.eos_token
26+
self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids(
27+
self.tokenizer.pad_token
28+
)
29+
1930
def answer(self, question, documents, options, dataset_name, max_new_tokens=10):
2031
# Format options as A) ... B) ... etc.
2132
formatted_options = "\n".join([f"{k}) {v}" for k, v in options.items()])
2233

2334
prompt = self.__format_prompt(
2435
question, documents, formatted_options, dataset_name
2536
)
26-
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(
27-
self.device
28-
)
2937

38+
inputs = self.tokenizer(
39+
prompt, padding=True, return_tensors="pt", truncation=True
40+
).to(self.device)
41+
42+
# Perform element-wise comparison and create attention mask tensor
43+
attention_mask = (inputs.input_ids != self.tokenizer.pad_token_id).long()
3044
outputs = self.model.generate(
3145
inputs.input_ids,
46+
attention_mask=attention_mask,
3247
max_new_tokens=max_new_tokens,
3348
early_stopping=False,
34-
eos_token_id=self.tokenizer.eos_token_id,
49+
pad_token_id=self.tokenizer.pad_token_id, # set explicitly to avoid open-end generation print statement
50+
eos_token_id=self.tokenizer.eos_token_id, # set explicitly to avoid open-end generation print statement
3551
)
3652

37-
full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
38-
return full_response, self.__parse_response(full_response, prompt)
53+
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
54+
generated_answer = self.__extract_answer(generated_text, prompt)
55+
return prompt, generated_answer
3956

4057
@classmethod
4158
def __format_prompt(cls, question, documents, options, dataset_name):
42-
instruction = None
59+
instruction = "You are a helpful medical expert, and your task is to answer a medical question using the relevant documents."
4360
if dataset_name == "pubmedqa":
4461
instruction = "As an expert doctor in clinical science and medical knowledge, can you tell me if the following statement is correct? Answer yes, no, or maybe."
4562
elif dataset_name == "bioasq":
46-
"You are an advanced biomedical AI assistant trained to understand and process medical and scientific texts. Given a biomedical question, your goal is to provide a concise and accurate answer based on relevant scientific literature."
63+
instruction = "You are an advanced biomedical AI assistant trained to understand and process medical and scientific texts. Given a biomedical question, your goal is to provide a concise and accurate answer based on relevant scientific literature."
4764

4865
ctx_documents = "\n".join(
4966
[f"Document {i + 1}: {doc}" for i, doc in enumerate(documents)]
@@ -59,16 +76,16 @@ def __format_prompt(cls, question, documents, options, dataset_name):
5976
Options:
6077
{options}
6178
62-
Please answer with only the correct option: """
79+
Answer only with the correct option: """
6380
return prompt
6481

6582
@classmethod
66-
def __parse_response(cls, full_response, original_prompt):
83+
def __extract_answer(cls, generated_text, original_prompt):
6784
# Extract only the new generated text
68-
response = full_response[len(original_prompt) :].strip()
85+
response = generated_text[len(original_prompt) :].strip()
6986

7087
# Find first occurrence of A-D (case-insensitive)
71-
match = re.search(r"\b([A-Da-d])\b", response)
72-
if match:
73-
return match.group(1).upper()
88+
option = re.search(r"\b([A-Da-d])\b", response)
89+
if option:
90+
return option.group(1).upper()
7491
return None

examples/fedrag/fedrag/retriever.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22

33
import warnings
44

5-
# Suppress FAISS-specific warnings
6-
warnings.filterwarnings("ignore", category=DeprecationWarning, module="faiss")
5+
# Suppress deprecation warnings
6+
warnings.filterwarnings("ignore", category=DeprecationWarning)
77

88
import os
99
import json

examples/fedrag/fedrag/server_app.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
import hashlib
44
import os
5-
import random
65
import time
76
from collections import defaultdict
87
from itertools import cycle
@@ -168,14 +167,13 @@ def main(grid: Grid, context: Context) -> None:
168167
options = q["options"]
169168
answer = q["answer"]
170169

171-
response, predicted_answer = llm_querier.answer(
170+
prompt, predicted_answer = llm_querier.answer(
172171
question, merged_docs, options, dataset_name
173172
)
174173

175174
# If the model did not predict any value,
176175
# then discard the question
177176
if predicted_answer is not None:
178-
predicted_answer = random.choice(list(options.keys()))
179177
expected_answers[dataset_name].append(answer)
180178
predicted_answers[dataset_name].append(predicted_answer)
181179
q_et = time.time()

0 commit comments

Comments
 (0)