Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Implementation of Iterative Feedback Loop for Translation Improvement #38

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/example_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


if __name__ == "__main__":
source_lang, target_lang, country = "English", "Spanish", "Mexico"
source_lang, target_lang, country, max_iterations, min_score_threshold = "English", "Spanish", "Mexico", 2, 90

relative_path = "sample-texts/sample-short1.txt"
script_dir = os.path.dirname(os.path.abspath(__file__))
Expand All @@ -21,6 +21,10 @@
target_lang=target_lang,
source_text=source_text,
country=country,
max_iterations=max_iterations,
min_score_threshold=min_score_threshold


)

print(f"Translation:\n\n{translation}")
154 changes: 105 additions & 49 deletions src/translation_agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import re
from typing import List, Union

import openai
import tiktoken
from dotenv import load_dotenv
from icecream import ic
from langchain_text_splitters import RecursiveCharacterTextSplitter

from typing import Tuple

load_dotenv() # read local .env file
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
Expand All @@ -17,6 +18,28 @@
# discrete chunks to translate one chunk at a time


def extract_suggestions_and_score(text):
"""
Extract the suggestions and score from the given text which are enclosed within <SUGGESTIONS> and <SCORE> tags.

Args:
text (str): The text containing the suggestions and score.

Returns:
Tuple[str, int]: A tuple containing the extracted suggestion text and the score as a int.
"""

score_match = re.search(r'<SCORE>(.*?)</SCORE>', text, re.DOTALL)
suggestions_match = re.search(r'<SUGGESTIONS>(.*?)</SUGGESTIONS>', text, re.DOTALL)

if score_match and suggestions_match:
score_str = score_match.group(1).strip()
score = int(score_str) if score_str.isdigit() else 0
reflection = suggestions_match.group(1).strip()
return reflection, score
else:
return "", 0

def get_completion(
prompt: str,
system_message: str = "You are a helpful assistant.",
Expand Down Expand Up @@ -103,7 +126,7 @@ def one_chunk_reflect_on_translation(
source_text: str,
translation_1: str,
country: str = "",
) -> str:
) -> Tuple[str, int]:
"""
Use an LLM to reflect on the translation, treating the entire text as one chunk.

Expand All @@ -115,12 +138,13 @@ def one_chunk_reflect_on_translation(
country (str): Country specified for the target language.

Returns:
str: The LLM's reflection on the translation, providing constructive criticism and suggestions for improvement.
Tuple[str, int]: The LLM's reflection on the translation, providing constructive criticism and suggestions for improvement and the score of the translation.
"""

system_message = f"You are an expert linguist specializing in translation from {source_lang} to {target_lang}. \
You will be provided with a source text and its translation and your goal is to improve the translation."

You will be provided with a source text and its translation and your goal is to improve the translation. In addition to providing constructive criticism and suggestions for improvement, \
rate the translation quality on a scale from 0 to 100."

if country != "":
reflection_prompt = f"""Your task is to carefully read a source text and a translation from {source_lang} to {target_lang}, and then give constructive criticism and helpful suggestions to improve the translation. \
The final style and tone of the translation should match the style of {target_lang} colloquially spoken in {country}.
Expand All @@ -143,7 +167,7 @@ def one_chunk_reflect_on_translation(

Write a list of specific, helpful and constructive suggestions for improving the translation.
Each suggestion should address one specific part of the translation.
Output only the suggestions and nothing else."""
Output the suggestions and score, delimited by XML tags <SUGGESTIONS></SUGGESTIONS> and <SCORE></SCORE>."""

else:
reflection_prompt = f"""Your task is to carefully read a source text and a translation from {source_lang} to {target_lang}, and then give constructive criticisms and helpful suggestions to improve the translation. \
Expand All @@ -166,10 +190,10 @@ def one_chunk_reflect_on_translation(

Write a list of specific, helpful and constructive suggestions for improving the translation.
Each suggestion should address one specific part of the translation.
Output only the suggestions and nothing else."""
Output the suggestions and score, delimited by XML tags <SUGGESTIONS></SUGGESTIONS> and <SCORE></SCORE>."""

reflection = get_completion(reflection_prompt, system_message=system_message)
return reflection
return extract_suggestions_and_score(reflection)


def one_chunk_improve_translation(
Expand Down Expand Up @@ -229,7 +253,7 @@ def one_chunk_improve_translation(


def one_chunk_translate_text(
source_lang: str, target_lang: str, source_text: str, country: str = ""
source_lang: str, target_lang: str, source_text: str, country: str = "", max_iterations: int = 1, min_score_threshold = 75
) -> str:
"""
Translate a single chunk of text from the source language to the target language.
Expand All @@ -243,21 +267,30 @@ def one_chunk_translate_text(
target_lang (str): The target language for the translation.
source_text (str): The text to be translated.
country (str): Country specified for the target language.
max_iterations (int): The maximum number of iterations for the translation process.
min_score_threshold (int): The minimum score threshold for the translation.
Returns:
str: The improved translation of the source text.
"""
translation_1 = one_chunk_initial_translation(
source_lang, target_lang, source_text
)

reflection = one_chunk_reflect_on_translation(
source_lang, target_lang, source_text, translation_1, country
)
translation_2 = one_chunk_improve_translation(
source_lang, target_lang, source_text, translation_1, reflection
)
iteration = 0
score = 0

translation = one_chunk_initial_translation(source_lang, target_lang, source_text)

while iteration < max_iterations and score < min_score_threshold:
reflection, score = one_chunk_reflect_on_translation(
source_lang, target_lang, source_text, translation, country
)
if score < min_score_threshold:
translation = one_chunk_improve_translation(
source_lang, target_lang, source_text, translation, reflection
)
ic(f"Iteration {iteration + 1}, Score: {score}")
iteration += 1

return translation

return translation_2


def num_tokens_in_string(
Expand Down Expand Up @@ -350,7 +383,7 @@ def multichunk_reflect_on_translation(
source_text_chunks: List[str],
translation_1_chunks: List[str],
country: str = "",
) -> List[str]:
) -> List[Tuple[str, int]]:
"""
Provides constructive criticism and suggestions for improving a partial translation.

Expand All @@ -362,11 +395,14 @@ def multichunk_reflect_on_translation(
country (str): Country specified for the target language.

Returns:
List[str]: A list of reflections containing suggestions for improving each translated chunk.
List[Tuple[str, int]]: A list of tuples, where each tuple contains:
- The reflection and suggestions for improving the translation.
- The score indicating the quality of the translation.
"""

system_message = f"You are an expert linguist specializing in translation from {source_lang} to {target_lang}. \
You will be provided with a source text and its translation and your goal is to improve the translation."
You will be provided with a source text and its translation and your goal is to improve the translation.In addition to providing constructive criticism and suggestions for improvement, \
rate the translation quality on a scale from 0 to 100."

if country != "":
reflection_prompt = """Your task is to carefully read a source text and part of a translation of that text from {source_lang} to {target_lang}, and then give constructive criticism and helpful suggestions for improving the translation.
Expand Down Expand Up @@ -398,7 +434,7 @@ def multichunk_reflect_on_translation(

Write a list of specific, helpful and constructive suggestions for improving the translation.
Each suggestion should address one specific part of the translation.
Output only the suggestions and nothing else."""
Output the suggestions and score, delimited by XML tags <SUGGESTIONS></SUGGESTIONS> and <SCORE></SCORE>."""

else:
reflection_prompt = """Your task is to carefully read a source text and part of a translation of that text from {source_lang} to {target_lang}, and then give constructive criticism and helpful suggestions for improving the translation.
Expand Down Expand Up @@ -429,7 +465,7 @@ def multichunk_reflect_on_translation(

Write a list of specific, helpful and constructive suggestions for improving the translation.
Each suggestion should address one specific part of the translation.
Output only the suggestions and nothing else."""
Output the suggestions and score, delimited by XML tags <SUGGESTIONS></SUGGESTIONS> and <SCORE></SCORE>."""

reflection_chunks = []
for i in range(len(source_text_chunks)):
Expand Down Expand Up @@ -460,7 +496,7 @@ def multichunk_reflect_on_translation(
)

reflection = get_completion(prompt, system_message=system_message)
reflection_chunks.append(reflection)
reflection_chunks.append(extract_suggestions_and_score(reflection))

return reflection_chunks

Expand Down Expand Up @@ -552,7 +588,7 @@ def multichunk_improve_translation(


def multichunk_translation(
source_lang, target_lang, source_text_chunks, country: str = ""
source_lang, target_lang, source_text_chunks, country: str = "", max_iterations: int = 1, min_score_threshold = 75
):
"""
Improves the translation of multiple text chunks based on the initial translation and reflection.
Expand All @@ -564,31 +600,49 @@ def multichunk_translation(
translation_1_chunks (List[str]): The list of initial translations for each source text chunk.
reflection_chunks (List[str]): The list of reflections on the initial translations.
country (str): Country specified for the target language
max_iterations (int): The maximum number of iterations for the translation process.
min_score_threshold (int): The minimum score threshold for the translation.
Returns:
List[str]: The list of improved translations for each source text chunk.
"""

translation_1_chunks = multichunk_initial_translation(
source_lang, target_lang, source_text_chunks
)

reflection_chunks = multichunk_reflect_on_translation(
source_lang,
target_lang,
source_text_chunks,
translation_1_chunks,
country,
)

translation_2_chunks = multichunk_improve_translation(
source_lang,
target_lang,
source_text_chunks,
translation_1_chunks,
reflection_chunks,
)
iteration = 0
score = 0

return translation_2_chunks
translation_chunks = multichunk_initial_translation(source_lang, target_lang, source_text_chunks)
final_translation_chunks = translation_chunks
improving_chunk_indices = list(range(len(source_text_chunks)))

while iteration < max_iterations:
reflection_chunks = multichunk_reflect_on_translation(
source_lang, target_lang, source_text_chunks, translation_chunks, country
)

improving_chunk_indices_temp = []

for indice in improving_chunk_indices:
_ , score = reflection_chunks[indice]
if score < min_score_threshold:
improving_chunk_indices_temp.append(indice)
else:
final_translation_chunks[indice] = translation_chunks[indice]

if not improving_chunk_indices_temp:
break

improving_chunk_indices = improving_chunk_indices_temp
translation_chunks = [translation_chunks[indice] for indice in improving_chunk_indices]
reflection_chunks = [reflection_chunks[indice] for indice in improving_chunk_indices]
source_text_chunks = [source_text_chunks[indice] for indice in improving_chunk_indices]

translation_chunks = multichunk_improve_translation(
source_lang, target_lang, source_text_chunks, translation_chunks, reflection_chunks
)

ic(f"Iteration {iteration + 1}, Score: {score}")
iteration += 1

return final_translation_chunks


def calculate_chunk_size(token_count: int, token_limit: int) -> int:
Expand Down Expand Up @@ -638,8 +692,10 @@ def translate(
source_text,
country,
max_tokens=MAX_TOKENS_PER_CHUNK,
max_iterations=1,
min_score_threshold=75
):
"""Translate the source_text from source_lang to target_lang."""
"""Translate the source_text from source_lang to target_lang with iterative feedback to improve translation quality based on score thresholds."""

num_tokens_in_text = num_tokens_in_string(source_text)

Expand All @@ -649,7 +705,7 @@ def translate(
ic("Translating text as a single chunk")

final_translation = one_chunk_translate_text(
source_lang, target_lang, source_text, country
source_lang, target_lang, source_text, country, max_iterations, min_score_threshold
)

return final_translation
Expand All @@ -672,7 +728,7 @@ def translate(
source_text_chunks = text_splitter.split_text(source_text)

translation_2_chunks = multichunk_translation(
source_lang, target_lang, source_text_chunks, country
source_lang, target_lang, source_text_chunks, country, max_iterations, min_score_threshold
)

return "".join(translation_2_chunks)