diff --git a/.gitignore b/.gitignore
index b5844af..0d68f6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ poetry.lock
floresp-v2.0-rc.3
*cache
wmt
+outputs
diff --git a/app/README.md b/app/README.md
new file mode 100644
index 0000000..1d84349
--- /dev/null
+++ b/app/README.md
@@ -0,0 +1,89 @@
+
+## Translation Agent WebUI
+
+This repository contains a Gradio web UI for a translation agent that utilizes various language models for translation.
+
+### Preview
+
+![webui](image.png)
+
+**Features:**
+
+- **Tokenized Text:** Displays the translated text with tokenization, highlighting the differences between the initial and final translations.
+- **Document Upload:** Supports uploading documents (PDF, TXT, DOCX, MD, and other text formats) for translation.
+- **Multiple API Support:** Integrates with popular LLM providers such as:
+ - Groq
+ - OpenAI
+ - Ollama
+ - Together AI
+ ...
+- **Different LLM for reflection:** You can enable a second endpoint to use a different LLM for the reflection step.
+
+
+**Getting Started**
+
+1. **Install Dependencies:**
+
+   **Linux / Windows**
+   ```bash
+   git clone https://github.com/andrewyng/translation-agent.git
+   cd translation-agent
+   poetry install --with app
+   poetry shell
+   ```
+
+2. **Set API Keys:**
+   - Rename `.env.sample` to `.env`, then add your API keys for each service:
+
+ ```
+ OPENAI_API_KEY="sk-xxxxx" # Keep this field
+ GROQ_API_KEY="xxxxx"
+ TOGETHER_API_KEY="xxxxx"
+ ```
+   - You can also set the API key in the web UI instead.
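+   - No API key is required for the Ollama endpoint; the app passes a placeholder key to a local Ollama server.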
+
+3. **Run the Web UI:**
+
+ **Linux**
+ ```bash
+ python app/app.py
+ ```
+ **Windows**
+ ```bash
+ python .\app\app.py
+ ```
+
+4. **Access the Web UI:**
+ Open your web browser and navigate to `http://127.0.0.1:7860/`.
+
+**Usage:**
+
+1. Select your desired translation API from the Endpoint dropdown menu.
+2. Input the source language, target language, and country (optional).
+3. Input the source text or upload your document file.
+4. Submit to get the translation; the UI displays the translated text with tokenization and highlights the differences between the initial and final translations.
+5. Enable "Additional Endpoint" to use a different LLM for the reflection step.
+6. With the CUSTOM endpoint, you can enter an OpenAI-compatible API base URL.
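+   For example, a self-hosted, OpenAI-compatible server might expose a base URL like `http://localhost:8000/v1` (illustrative).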
+
+**Customization:**
+
+- **Add New LLMs:** Modify `patch.py` to integrate additional LLMs, as sketched below.
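+
+  A hypothetical provider could be wired in by adding a `case` to the `match` statement in `model_load` (the endpoint name, environment variable, and base URL below are illustrative):
+
+  ```python
+  case "MyProvider":  # hypothetical endpoint name
+      client = openai.OpenAI(
+          api_key=api_key if api_key else os.getenv("MYPROVIDER_API_KEY"),
+          base_url="https://api.myprovider.example/v1",  # illustrative URL
+      )
+  ```
+
+  To expose it in the UI, also add the same name to the `Endpoint` dropdown choices in `app.py`.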
+
+**Contributing:**
+
+Contributions are welcome! Feel free to open issues or submit pull requests.
+
+**License:**
+
+This project is licensed under the MIT License.
+
+**DEMO:**
+
+[Huggingface Demo](https://huggingface.co/spaces/vilarin/Translation-Agent-WebUI)
diff --git a/app/app.py b/app/app.py
new file mode 100644
index 0000000..0107536
--- /dev/null
+++ b/app/app.py
@@ -0,0 +1,392 @@
+import os
+import re
+from glob import glob
+
+import gradio as gr
+from process import (
+ diff_texts,
+ extract_docx,
+ extract_pdf,
+ extract_text,
+ model_load,
+ translator,
+ translator_sec,
+)
+
+
+def huanik(
+ endpoint: str,
+ base: str,
+ model: str,
+ api_key: str,
+ choice: str,
+ endpoint2: str,
+ base2: str,
+ model2: str,
+ api_key2: str,
+ source_lang: str,
+ target_lang: str,
+ source_text: str,
+ country: str,
+ max_tokens: int,
+    temperature: float,
+ rpm: int,
+):
+ if not source_text or source_lang == target_lang:
+ raise gr.Error(
+ "Please check that the content or options are entered correctly."
+ )
+
+ try:
+ model_load(endpoint, base, model, api_key, temperature, rpm)
+ except Exception as e:
+ raise gr.Error(f"An unexpected error occurred: {e}") from e
+
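+    # Strip blank lines from the source text before translating.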
+ source_text = re.sub(r"(?m)^\s*$\n?", "", source_text)
+
+ if choice:
+ init_translation, reflect_translation, final_translation = (
+ translator_sec(
+ endpoint2=endpoint2,
+ base2=base2,
+ model2=model2,
+ api_key2=api_key2,
+ source_lang=source_lang,
+ target_lang=target_lang,
+ source_text=source_text,
+ country=country,
+ max_tokens=max_tokens,
+ )
+ )
+
+ else:
+ init_translation, reflect_translation, final_translation = translator(
+ source_lang=source_lang,
+ target_lang=target_lang,
+ source_text=source_text,
+ country=country,
+ max_tokens=max_tokens,
+ )
+
+ final_diff = gr.HighlightedText(
+ diff_texts(init_translation, final_translation),
+ label="Diff translation",
+ combine_adjacent=True,
+ show_legend=True,
+ visible=True,
+ color_map={"removed": "red", "added": "green"},
+ )
+
+ return init_translation, reflect_translation, final_translation, final_diff
+
+
+def update_model(endpoint):
+ endpoint_model_map = {
+ "Groq": "llama3-70b-8192",
+ "OpenAI": "gpt-4o",
+ "TogetherAI": "Qwen/Qwen2-72B-Instruct",
+ "Ollama": "llama3",
+ "CUSTOM": "",
+ }
+ if endpoint == "CUSTOM":
+ base = gr.update(visible=True)
+ else:
+ base = gr.update(visible=False)
+ return gr.update(value=endpoint_model_map[endpoint]), base
+
+
+def read_doc(path):
+ file_type = path.split(".")[-1]
+    if file_type in ["pdf", "txt", "py", "docx", "json", "cpp", "md"]:
+        if file_type == "pdf":
+            content = extract_pdf(path)
+        elif file_type == "docx":
+            content = extract_docx(path)
+        else:
+            content = extract_text(path)
+        return re.sub(r"(?m)^\s*$\n?", "", content)
+    else:
+        raise gr.Error("Unsupported file type.")
+
+
+def enable_sec(choice):
+ if choice:
+ return gr.update(visible=True)
+ else:
+ return gr.update(visible=False)
+
+
+def update_menu(visible):
+ return not visible, gr.update(visible=not visible)
+
+
+def export_txt(strings):
+ if strings:
+ os.makedirs("outputs", exist_ok=True)
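+        # Number output files sequentially based on existing .txt files.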
+ base_count = len(glob(os.path.join("outputs", "*.txt")))
+ file_path = os.path.join("outputs", f"{base_count:06d}.txt")
+ with open(file_path, "w", encoding="utf-8") as f:
+ f.write(strings)
+ return gr.update(value=file_path, visible=True)
+ else:
+ return gr.update(visible=False)
+
+
+def switch(source_lang, source_text, target_lang, output_final):
+ if output_final:
+ return (
+ gr.update(value=target_lang),
+ gr.update(value=output_final),
+ gr.update(value=source_lang),
+ gr.update(value=source_text),
+ )
+ else:
+ return (
+ gr.update(value=target_lang),
+ gr.update(value=source_text),
+ gr.update(value=source_lang),
+ gr.update(value=""),
+ )
+
+
+def close_btn_show():
+ return gr.update(visible=False), gr.update(visible=True)
+
+
+def close_btn_hide(output_diff):
+ if output_diff:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+
+TITLE = """
+<h1>Translation Agent WebUI</h1>
+"""
+
+CSS = """
+ h1 {
+ text-align: center;
+ display: block;
+ height: 10vh;
+ align-content: center;
+ }
+ footer {
+ visibility: hidden;
+ }
+ .menu_btn {
+ width: 48px;
+ height: 48px;
+ max-width: 48px;
+ min-width: 48px;
+ padding: 0px;
+ background-color: transparent;
+ border: none;
+ cursor: pointer;
+ position: relative;
+ box-shadow: none;
+ }
+ .menu_btn::before,
+ .menu_btn::after {
+ content: '';
+ position: absolute;
+ width: 30px;
+ height: 3px;
+ background-color: #4f46e5;
+ transition: transform 0.3s ease;
+ }
+ .menu_btn::before {
+ top: 12px;
+ box-shadow: 0 8px 0 #6366f1;
+ }
+ .menu_btn::after {
+ bottom: 16px;
+ }
+ .menu_btn.active::before {
+ transform: translateY(8px) rotate(45deg);
+ box-shadow: none;
+ }
+ .menu_btn.active::after {
+ transform: translateY(-8px) rotate(-45deg);
+ }
+ .lang {
+ max-width: 100px;
+ min-width: 100px;
+ }
+"""
+
+JS = """
+ function () {
+ const menu_btn = document.getElementById('menu');
+ menu_btn.classList.toggle('active');
+ }
+
+"""
+
+with gr.Blocks(theme="soft", css=CSS, fill_height=True) as demo:
+ with gr.Row():
+ visible = gr.State(value=True)
+ menu_btn = gr.Button(
+ value="", elem_classes="menu_btn", elem_id="menu", size="sm"
+ )
+ gr.HTML(TITLE)
+ with gr.Row():
+ with gr.Column(scale=1) as menubar:
+ endpoint = gr.Dropdown(
+ label="Endpoint",
+ choices=["OpenAI", "Groq", "TogetherAI", "Ollama", "CUSTOM"],
+ value="OpenAI",
+ )
+ choice = gr.Checkbox(
+ label="Additional Endpoint",
+ info="Additional endpoint for reflection",
+ )
+ model = gr.Textbox(
+ label="Model",
+ value="gpt-4o",
+ )
+ api_key = gr.Textbox(
+ label="API_KEY",
+ type="password",
+ )
+ base = gr.Textbox(label="BASE URL", visible=False)
+ with gr.Column(visible=False) as AddEndpoint:
+ endpoint2 = gr.Dropdown(
+ label="Additional Endpoint",
+ choices=[
+ "OpenAI",
+ "Groq",
+ "TogetherAI",
+ "Ollama",
+ "CUSTOM",
+ ],
+ value="OpenAI",
+ )
+ model2 = gr.Textbox(
+ label="Model",
+ value="gpt-4o",
+ )
+ api_key2 = gr.Textbox(
+ label="API_KEY",
+ type="password",
+ )
+ base2 = gr.Textbox(label="BASE URL", visible=False)
+ with gr.Row():
+ source_lang = gr.Textbox(
+ label="Source Lang",
+ value="English",
+ elem_classes="lang",
+ )
+ target_lang = gr.Textbox(
+ label="Target Lang",
+ value="Spanish",
+ elem_classes="lang",
+ )
+ switch_btn = gr.Button(value="🔄️")
+ country = gr.Textbox(
+ label="Country", value="Argentina", max_lines=1
+ )
+ with gr.Accordion("Advanced Options", open=False):
+ max_tokens = gr.Slider(
+                    label="Max Tokens Per Chunk",
+ minimum=512,
+ maximum=2046,
+ value=1000,
+ step=8,
+ )
+ temperature = gr.Slider(
+ label="Temperature",
+ minimum=0,
+ maximum=1.0,
+ value=0.3,
+ step=0.1,
+ )
+ rpm = gr.Slider(
+                    label="Requests Per Minute",
+ minimum=1,
+ maximum=1000,
+ value=60,
+ step=1,
+ )
+
+ with gr.Column(scale=4):
+ source_text = gr.Textbox(
+ label="Source Text",
+ value="If one advances confidently in the direction of his dreams, and endeavors to live the life which he has imagined, he will meet with a success unexpected in common hours.",
+ lines=12,
+ )
+ with gr.Tab("Final"):
+ output_final = gr.Textbox(
+ label="Final Translation", lines=12, show_copy_button=True
+ )
+ with gr.Tab("Initial"):
+ output_init = gr.Textbox(
+ label="Init Translation", lines=12, show_copy_button=True
+ )
+ with gr.Tab("Reflection"):
+ output_reflect = gr.Textbox(
+ label="Reflection", lines=12, show_copy_button=True
+ )
+ with gr.Tab("Diff"):
+ output_diff = gr.HighlightedText(visible=False)
+ with gr.Row():
+ submit = gr.Button(value="Translate")
+ upload = gr.UploadButton(label="Upload", file_types=["text"])
+ export = gr.DownloadButton(visible=False)
+ clear = gr.ClearButton(
+ [source_text, output_init, output_reflect, output_final]
+ )
+ close = gr.Button(value="Stop", visible=False)
+
+ switch_btn.click(
+ fn=switch,
+ inputs=[source_lang, source_text, target_lang, output_final],
+ outputs=[source_lang, source_text, target_lang, output_final],
+ )
+
+ menu_btn.click(
+ fn=update_menu, inputs=visible, outputs=[visible, menubar], js=JS
+ )
+ endpoint.change(fn=update_model, inputs=[endpoint], outputs=[model, base])
+
+ choice.select(fn=enable_sec, inputs=[choice], outputs=[AddEndpoint])
+ endpoint2.change(
+ fn=update_model, inputs=[endpoint2], outputs=[model2, base2]
+ )
+
+ start_ta = submit.click(
+ fn=huanik,
+ inputs=[
+ endpoint,
+ base,
+ model,
+ api_key,
+ choice,
+ endpoint2,
+ base2,
+ model2,
+ api_key2,
+ source_lang,
+ target_lang,
+ source_text,
+ country,
+ max_tokens,
+ temperature,
+ rpm,
+ ],
+ outputs=[output_init, output_reflect, output_final, output_diff],
+ )
+ upload.upload(fn=read_doc, inputs=upload, outputs=source_text)
+ output_diff.change(fn=export_txt, inputs=output_final, outputs=[export])
+
+ submit.click(fn=close_btn_show, outputs=[clear, close])
+ output_diff.change(
+ fn=close_btn_hide, inputs=output_diff, outputs=[clear, close]
+ )
+ close.click(fn=None, cancels=start_ta)
+
+if __name__ == "__main__":
+ demo.queue(api_open=False).launch(show_api=False, share=False)
diff --git a/app/image.png b/app/image.png
new file mode 100644
index 0000000..ea7a587
Binary files /dev/null and b/app/image.png differ
diff --git a/app/patch.py b/app/patch.py
new file mode 100644
index 0000000..89fdd43
--- /dev/null
+++ b/app/patch.py
@@ -0,0 +1,161 @@
+import os
+import time
+from functools import wraps
+from threading import Lock
+from typing import Optional
+
+import gradio as gr
+import openai
+import translation_agent.utils as utils
+
+
+RPM = 60
+MODEL = ""
+TEMPERATURE = 0.3
+# js_mode is currently hidden in the UI; enabling it is planned.
+JS_MODE = False
+ENDPOINT = ""
+
+
+# Add your LLMs here
+def model_load(
+ endpoint: str,
+ base_url: str,
+ model: str,
+ api_key: Optional[str] = None,
+ temperature: float = TEMPERATURE,
+ rpm: int = RPM,
+ js_mode: bool = JS_MODE,
+):
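+    """Configure the module-global client and settings for get_completion."""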
+ global client, RPM, MODEL, TEMPERATURE, JS_MODE, ENDPOINT
+ ENDPOINT = endpoint
+ RPM = rpm
+ MODEL = model
+ TEMPERATURE = temperature
+ JS_MODE = js_mode
+
+ match endpoint:
+ case "OpenAI":
+ client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+ case "Groq":
+ client = openai.OpenAI(
+ api_key=api_key if api_key else os.getenv("GROQ_API_KEY"),
+ base_url="https://api.groq.com/openai/v1",
+ )
+ case "TogetherAI":
+ client = openai.OpenAI(
+ api_key=api_key if api_key else os.getenv("TOGETHER_API_KEY"),
+ base_url="https://api.together.xyz/v1",
+ )
+ case "CUSTOM":
+ client = openai.OpenAI(api_key=api_key, base_url=base_url)
+ case "Ollama":
+ client = openai.OpenAI(
+ api_key="ollama", base_url="http://localhost:11434/v1"
+ )
+ case _:
+ client = openai.OpenAI(
+ api_key=api_key if api_key else os.getenv("OPENAI_API_KEY")
+ )
+
+
+def rate_limit(get_max_per_minute):
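+    """Limit calls to at most get_max_per_minute() per minute."""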
+ def decorator(func):
+ lock = Lock()
+ last_called = [0.0]
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ with lock:
+ max_per_minute = get_max_per_minute()
+ min_interval = 60.0 / max_per_minute
+ elapsed = time.time() - last_called[0]
+ left_to_wait = min_interval - elapsed
+
+ if left_to_wait > 0:
+ time.sleep(left_to_wait)
+
+ ret = func(*args, **kwargs)
+ last_called[0] = time.time()
+ return ret
+
+ return wrapper
+
+ return decorator
+
+
+@rate_limit(lambda: RPM)
+def get_completion(
+ prompt: str,
+ system_message: str = "You are a helpful assistant.",
+ model: str = "gpt-4-turbo",
+ temperature: float = 0.3,
+ json_mode: bool = False,
+) -> str:
+ """
+ Generate a completion using the OpenAI API.
+
+ Args:
+ prompt (str): The user's prompt or query.
+ system_message (str, optional): The system message to set the context for the assistant.
+ Defaults to "You are a helpful assistant.".
+ model (str, optional): The name of the OpenAI model to use for generating the completion.
+ Defaults to "gpt-4-turbo".
+ temperature (float, optional): The sampling temperature for controlling the randomness of the generated text.
+ Defaults to 0.3.
+ json_mode (bool, optional): Whether to return the response in JSON format.
+ Defaults to False.
+
+    Returns:
+        str: The generated text. When json_mode is True, this is a
+            JSON-formatted string rather than a dictionary.
+
+    Note:
+        model, temperature, and json_mode are overridden by the module-level
+        settings configured via model_load().
+    """
+
+    # Settings configured via model_load() take precedence over the arguments.
+    model = MODEL
+    temperature = TEMPERATURE
+    json_mode = JS_MODE
+
+    try:
+        kwargs = dict(
+            model=model,
+            temperature=temperature,
+            top_p=1,
+            messages=[
+                {"role": "system", "content": system_message},
+                {"role": "user", "content": prompt},
+            ],
+        )
+        if json_mode:
+            # Ask the API to return a JSON object instead of free-form text.
+            kwargs["response_format"] = {"type": "json_object"}
+        response = client.chat.completions.create(**kwargs)
+        return response.choices[0].message.content
+    except Exception as e:
+        raise gr.Error(f"An unexpected error occurred: {e}") from e
+
+
+utils.get_completion = get_completion
+
+one_chunk_initial_translation = utils.one_chunk_initial_translation
+one_chunk_reflect_on_translation = utils.one_chunk_reflect_on_translation
+one_chunk_improve_translation = utils.one_chunk_improve_translation
+one_chunk_translate_text = utils.one_chunk_translate_text
+num_tokens_in_string = utils.num_tokens_in_string
+multichunk_initial_translation = utils.multichunk_initial_translation
+multichunk_reflect_on_translation = utils.multichunk_reflect_on_translation
+multichunk_improve_translation = utils.multichunk_improve_translation
+multichunk_translation = utils.multichunk_translation
+calculate_chunk_size = utils.calculate_chunk_size
diff --git a/app/process.py b/app/process.py
new file mode 100644
index 0000000..531d4f0
--- /dev/null
+++ b/app/process.py
@@ -0,0 +1,263 @@
+from difflib import Differ
+
+import docx
+import gradio as gr
+import pymupdf
+from icecream import ic
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from patch import (
+ calculate_chunk_size,
+ model_load,
+ multichunk_improve_translation,
+ multichunk_initial_translation,
+ multichunk_reflect_on_translation,
+ num_tokens_in_string,
+ one_chunk_improve_translation,
+ one_chunk_initial_translation,
+ one_chunk_reflect_on_translation,
+)
+from simplemma import simple_tokenizer
+
+
+progress = gr.Progress()
+
+
+def extract_text(path):
+    with open(path, encoding="utf-8") as f:
+ file_text = f.read()
+ return file_text
+
+
+def extract_pdf(path):
+ doc = pymupdf.open(path)
+ text = ""
+ for page in doc:
+ text += page.get_text()
+ return text
+
+
+def extract_docx(path):
+ doc = docx.Document(path)
+ data = []
+ for paragraph in doc.paragraphs:
+ data.append(paragraph.text)
+ content = "\n\n".join(data)
+ return content
+
+
+def tokenize(text):
+    # Tokenize the text with simplemma
+ words = simple_tokenizer(text)
+ # Check if the text contains spaces
+ if " " in text:
+ # Create a list of words and spaces
+ tokens = []
+ for word in words:
+ tokens.append(word)
+ if not word.startswith("'") and not word.endswith(
+ "'"
+            ):  # Don't add a space around apostrophe tokens
+ tokens.append(" ") # Add space after each word
+ return tokens[:-1] # Remove the last space
+ else:
+ return words
+
+
+def diff_texts(text1, text2):
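+    """Return (token, category) pairs for gr.HighlightedText."""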
+ tokens1 = tokenize(text1)
+ tokens2 = tokenize(text2)
+
+ d = Differ()
+ diff_result = list(d.compare(tokens1, tokens2))
+
+ highlighted_text = []
+ for token in diff_result:
+ word = token[2:]
+ category = None
+ if token[0] == "+":
+ category = "added"
+ elif token[0] == "-":
+ category = "removed"
+ elif token[0] == "?":
+ continue # Ignore the hints line
+
+ highlighted_text.append((word, category))
+
+ return highlighted_text
+
+
+# modified from src.translation_agent.utils.translate
+def translator(
+ source_lang: str,
+ target_lang: str,
+ source_text: str,
+ country: str,
+ max_tokens: int = 1000,
+):
+ """Translate the source_text from source_lang to target_lang."""
+ num_tokens_in_text = num_tokens_in_string(source_text)
+
+ ic(num_tokens_in_text)
+
+ if num_tokens_in_text < max_tokens:
+ ic("Translating text as single chunk")
+
+ progress((1, 3), desc="First translation...")
+ init_translation = one_chunk_initial_translation(
+ source_lang, target_lang, source_text
+ )
+
+ progress((2, 3), desc="Reflection...")
+ reflection = one_chunk_reflect_on_translation(
+ source_lang, target_lang, source_text, init_translation, country
+ )
+
+ progress((3, 3), desc="Second translation...")
+ final_translation = one_chunk_improve_translation(
+ source_lang, target_lang, source_text, init_translation, reflection
+ )
+
+ return init_translation, reflection, final_translation
+
+ else:
+ ic("Translating text as multiple chunks")
+
+ token_size = calculate_chunk_size(
+ token_count=num_tokens_in_text, token_limit=max_tokens
+ )
+
+ ic(token_size)
+
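+        # Split on tiktoken counts so each chunk fits the token budget.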
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+ model_name="gpt-4",
+ chunk_size=token_size,
+ chunk_overlap=0,
+ )
+
+ source_text_chunks = text_splitter.split_text(source_text)
+
+ progress((1, 3), desc="First translation...")
+ translation_1_chunks = multichunk_initial_translation(
+ source_lang, target_lang, source_text_chunks
+ )
+
+ init_translation = "".join(translation_1_chunks)
+
+ progress((2, 3), desc="Reflection...")
+ reflection_chunks = multichunk_reflect_on_translation(
+ source_lang,
+ target_lang,
+ source_text_chunks,
+ translation_1_chunks,
+ country,
+ )
+
+ reflection = "".join(reflection_chunks)
+
+ progress((3, 3), desc="Second translation...")
+ translation_2_chunks = multichunk_improve_translation(
+ source_lang,
+ target_lang,
+ source_text_chunks,
+ translation_1_chunks,
+ reflection_chunks,
+ )
+
+ final_translation = "".join(translation_2_chunks)
+
+ return init_translation, reflection, final_translation
+
+
+def translator_sec(
+ endpoint2: str,
+ base2: str,
+ model2: str,
+ api_key2: str,
+ source_lang: str,
+ target_lang: str,
+ source_text: str,
+ country: str,
+ max_tokens: int = 1000,
+):
+ """Translate the source_text from source_lang to target_lang."""
+ num_tokens_in_text = num_tokens_in_string(source_text)
+
+ ic(num_tokens_in_text)
+
+ if num_tokens_in_text < max_tokens:
+ ic("Translating text as single chunk")
+
+ progress((1, 3), desc="First translation...")
+ init_translation = one_chunk_initial_translation(
+ source_lang, target_lang, source_text
+ )
+
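+        # Switch to the second endpoint for the reflection and improvement steps.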
+ try:
+ model_load(endpoint2, base2, model2, api_key2)
+ except Exception as e:
+ raise gr.Error(f"An unexpected error occurred: {e}") from e
+
+ progress((2, 3), desc="Reflection...")
+ reflection = one_chunk_reflect_on_translation(
+ source_lang, target_lang, source_text, init_translation, country
+ )
+
+ progress((3, 3), desc="Second translation...")
+ final_translation = one_chunk_improve_translation(
+ source_lang, target_lang, source_text, init_translation, reflection
+ )
+
+ return init_translation, reflection, final_translation
+
+ else:
+ ic("Translating text as multiple chunks")
+
+ token_size = calculate_chunk_size(
+ token_count=num_tokens_in_text, token_limit=max_tokens
+ )
+
+ ic(token_size)
+
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+ model_name="gpt-4",
+ chunk_size=token_size,
+ chunk_overlap=0,
+ )
+
+ source_text_chunks = text_splitter.split_text(source_text)
+
+ progress((1, 3), desc="First translation...")
+ translation_1_chunks = multichunk_initial_translation(
+ source_lang, target_lang, source_text_chunks
+ )
+
+ init_translation = "".join(translation_1_chunks)
+
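+        # Switch to the second endpoint for the remaining steps.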
+ try:
+ model_load(endpoint2, base2, model2, api_key2)
+ except Exception as e:
+ raise gr.Error(f"An unexpected error occurred: {e}") from e
+
+ progress((2, 3), desc="Reflection...")
+ reflection_chunks = multichunk_reflect_on_translation(
+ source_lang,
+ target_lang,
+ source_text_chunks,
+ translation_1_chunks,
+ country,
+ )
+
+ reflection = "".join(reflection_chunks)
+
+ progress((3, 3), desc="Second translation...")
+ translation_2_chunks = multichunk_improve_translation(
+ source_lang,
+ target_lang,
+ source_text_chunks,
+ translation_1_chunks,
+ reflection_chunks,
+ )
+
+ final_translation = "".join(translation_2_chunks)
+
+ return init_translation, reflection, final_translation
diff --git a/pyproject.toml b/pyproject.toml
index db621f6..26fb0db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,15 @@ icecream = "^2.1.3"
langchain-text-splitters = "^0.0.1"
python-dotenv = "^1.0.1"
+[tool.poetry.group.app]
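+# Optional web UI dependencies; install with: poetry install --with app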
+optional = true
+
+[tool.poetry.group.app.dependencies]
+simplemma = "^1.0.0"
+gradio = "4.37.2"
+python-docx = "^1.1.2"
+PyMuPDF = "^1.24.7"
+
[tool.poetry.group.dev]
optional = true
@@ -64,7 +73,7 @@ priority = "supplemental"
# Set the maximum line length to 79.
line-length = 79
indent-width = 4
-exclude = [".venv", ".env", ".git", "tests", "eval"]
+exclude = [".venv", ".env", ".git", "tests", "eval", ".jj"]
[tool.ruff.lint]
# Add the `line-too-long` rule to the enforced rule set. By default, Ruff omits rules that
@@ -89,7 +98,7 @@ fixable = ["ALL"]
ignore = ["SIM117"]
[tool.ruff.lint.isort]
-force-single-line = true
+force-single-line = false
lines-after-imports = 2
known-first-party = ["translation-agent"]
diff --git a/src/translation_agent/utils.py b/src/translation_agent/utils.py
index 67dfd5a..549dac3 100755
--- a/src/translation_agent/utils.py
+++ b/src/translation_agent/utils.py
@@ -1,6 +1,5 @@
import os
-from typing import List
-from typing import Union
+from typing import List, Union
import openai
import tiktoken