import json
import os.path
import logging
import time

from langchain.vectorstores import FAISS
from langchain import PromptTemplate

from utils.references import References
from utils.knowledge import Knowledge
from utils.file_operations import make_archive, copy_templates
from utils.tex_processing import create_copies
from utils.gpt_interaction import GPTModel, get_gpt_responses
from utils.prompts import SYSTEM
from utils.embeddings import EMBEDDINGS
TOTAL_TOKENS = 0
TOTAL_PROMPTS_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0


def log_usage(usage, generating_target, print_out=True):
    """Accumulate token counts for one generation step and log a usage summary."""
    global TOTAL_TOKENS
    global TOTAL_PROMPTS_TOKENS
    global TOTAL_COMPLETION_TOKENS

    prompts_tokens = usage['prompt_tokens']
    completion_tokens = usage['completion_tokens']
    total_tokens = usage['total_tokens']

    TOTAL_TOKENS += total_tokens
    TOTAL_PROMPTS_TOKENS += prompts_tokens
    TOTAL_COMPLETION_TOKENS += completion_tokens

    message = f">>USAGE>> For generating {generating_target}, {total_tokens} tokens have been used " \
              f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
              f"{TOTAL_TOKENS} tokens have been used in total."
    if print_out:
        print(message)
    logging.info(message)


def _generation_setup(title, template="Default",
                      tldr=False, max_kw_refs=20, bib_refs=None, max_tokens_ref=2048,  # for reference generation
                      knowledge_database=None, max_tokens_kd=2048, query_counts=10):
    """Initialize a draft: copy the template, then collect keywords, references, domain knowledge, and components."""
    llm = GPTModel(model="gpt-3.5-turbo-16k")
    bibtex_path, destination_folder = copy_templates(template, title)
    logging.basicConfig(level=logging.INFO, filename=os.path.join(destination_folder, "generation.log"))

    # Generate keywords and assign each one a maximum number of references to fetch.
    keywords, usage = llm(systems=SYSTEM["keywords"], prompts=title, return_json=True)
    log_usage(usage, "keywords")
    keywords = {keyword: max_kw_refs for keyword in keywords}
    print("Keywords: \n", keywords)

    # Generate references.
    ref = References(title, bib_refs)
    ref.collect_papers(keywords, tldr=tldr)
    references = ref.to_prompts(max_tokens=max_tokens_ref)
    all_paper_ids = ref.to_bibtex(bibtex_path)

    # Produce domain knowledge from the local knowledge database, if one exists.
    prompts = f"Title: {title}"
    preliminaries_kw, _ = llm(systems=SYSTEM["preliminaries"], prompts=prompts)
    # Default to an empty string so the draft still builds when no database is found.
    domain_knowledge = ''
    db_path = f"utils/knowledge_databases/{knowledge_database}"
    db_config_path = os.path.join(db_path, "db_meta.json")
    db_index_path = os.path.join(db_path, "faiss_index")
    if os.path.isdir(db_path):
        try:
            with open(db_config_path, "r", encoding="utf-8") as f:
                db_config = json.load(f)
            model_name = db_config["embedding_model"]
            embeddings = EMBEDDINGS[model_name]
            db = FAISS.load_local(db_index_path, embeddings)
            knowledge = Knowledge(db=db)
            knowledge.collect_knowledge(preliminaries_kw, max_query=query_counts)
            domain_knowledge = knowledge.to_prompts(max_tokens_kd)
        except Exception as e:
            logging.warning(f"Failed to load the knowledge database: {e}")

    # Ask the model to propose the components of the survey.
    prompts = f"Title: {title}"
    system_prompt = "You are an assistant designed to propose necessary components of a survey paper. Your response should follow the JSON format."
    components, usage = llm(systems=system_prompt, prompts=prompts, return_json=True)
    log_usage(usage, "components")
    print(f"The paper information has been initialized. References are saved to {bibtex_path}.")
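
    # Bundle everything the downstream generation steps need.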
    paper = {
        "title": title,
        "references": references,
        "bibtex": bibtex_path,
        "components": components,
        "domain_knowledge": domain_knowledge,
    }
    return paper, destination_folder, all_paper_ids


def section_generation(paper, section, save_to_path, model, research_field="machine learning"):
    """
    The main pipeline for generating a section:
    1. Generate prompts.
    2. Get responses from the AI assistant.
    3. Extract the section text.
    4. Save the text to a .tex file.
    :return: usage
    """
    title = paper["title"]
    references = paper["references"]
    components = paper["components"]
    instruction = r"- Discuss three to five main fields related to this paper. For each field, select five to ten key publications from the references. For each reference, analyze its strengths and weaknesses in one or two sentences. Present the related works in a logical manner, often chronologically. Consider using a taxonomy or categorization to structure the discussion. Do not use \section{...} or \subsection{...}; use \paragraph{...} to list related fields."
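
    # Assemble the full prompt from modular sub-prompts: task, instructions, references, and output format.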
    fundamental_subprompt = "Your task is to write the {section} section of the paper with the title '{title}'. This paper has the following content: {components}\n"
    instruction_subprompt = "\n" \
                            "Your response should follow these instructions:\n" \
                            "{instruction}\n"
    ref_instruction_subprompt = "- Read the references. " \
                                "Every time you use information from the references, you need to cite it appropriately (using \\citep or \\citet). " \
                                "For example of \\citep: the sentence where you use information from lei2022adaptive \\citep{{lei2022adaptive}}. " \
                                "For example of \\citet: \\citet{{lei2022adaptive}} claims some information.\n" \
                                "- Avoid citing the same reference twice in the same paragraph.\n" \
                                "\n" \
                                "References:\n" \
                                "{references}"
    output_subprompt = "Ensure that it can be directly compiled by LaTeX."

    review_prompts = PromptTemplate(
        input_variables=["title", "components", "instruction", "section", "references"],
        template=fundamental_subprompt + instruction_subprompt + ref_instruction_subprompt + output_subprompt)
    prompts = review_prompts.format(title=title,
                                    components=components,
                                    instruction=instruction,
                                    section=section,
                                    references=references)

    SECTION_GENERATION_SYSTEM = PromptTemplate(input_variables=["research_field"],
                                               template="You are an assistant designed to write academic papers in the field of {research_field} using LaTeX.")
    output, usage = get_gpt_responses(SECTION_GENERATION_SYSTEM.format(research_field=research_field), prompts,
                                      model=model, temperature=0.4)

    # Keep the original fixed offset: drop the first 25 characters, presumably a boilerplate prefix from the model.
    output = output[25:]
    tex_file = os.path.join(save_to_path, f"{section}.tex")
    with open(tex_file, "w", encoding="utf-8") as f:
        f.write(output)
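
    # Optional post-processing: convert the generated LaTeX into Markdown and/or a Chinese translation.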
    use_md = True
    use_chinese = True
    if use_md:
        system_md = 'You are a translator between LaTeX and Markdown. Here is a LaTeX file whose content is: \n \n ' + output
        prompts_md = 'Transfer the LaTeX content to Markdown format carefully, and pay attention to the correctness of the citation format (use the number). Directly output the new content without any other reply. Add the reference papers at the end of the paper, with a line break between every two references. The title should be ' + paper['title']
        output_md, usage_md = get_gpt_responses(system_md, prompts_md,
                                                model=model, temperature=0.4)
        md_file = os.path.join(save_to_path, "survey.md")
        with open(md_file, "w", encoding="utf-8") as m:
            m.write(output_md)

        if use_chinese:
            system_md_chi = 'You are a translator between English and Chinese. Here is an English file whose content is: \n \n ' + output
            prompts_md_chi = 'Transfer the English to Chinese without changing anything else. Directly output the new content without any other reply. Keep the reference papers unchanged.'
            output_md_chi, usage_md_chi = get_gpt_responses(system_md_chi, prompts_md_chi,
                                                            model=model, temperature=0.4)
            md_file_chi = os.path.join(save_to_path, "survey_chinese.md")
            with open(md_file_chi, "w", encoding="utf-8") as c:
                c.write(output_md_chi)
    return usage


def generate_draft(title, tldr=True, max_kw_refs=20, bib_refs=None, max_tokens_ref=2048,
                   knowledge_database=None, max_tokens_kd=2048, query_counts=10,
                   section='related works', model="gpt-3.5-turbo-16k", template="Default",
                   save_zip=None):
    """Generate a draft of the given section and pack the result into a .zip archive."""
    print("================START================")
    paper, destination_folder, _ = _generation_setup(title, template, tldr, max_kw_refs, bib_refs,
                                                     max_tokens_ref=max_tokens_ref, max_tokens_kd=max_tokens_kd,
                                                     query_counts=query_counts,
                                                     knowledge_database=knowledge_database)

    # Main components.
    print("================PROCESSING================")
    usage = section_generation(paper, section, destination_folder, model=model)
    log_usage(usage, section)
    create_copies(destination_folder)
    print("\nPROCESSING COMPLETE\n")
    print("The draft has been generated in " + destination_folder)
    return make_archive(destination_folder, title + ".zip")


if __name__ == "__main__":
    import openai

    openai.api_key = "your key"
    openai.api_base = 'https://api.openai.com/v1'
    # openai.proxy = "socks5h://localhost:7890"  # if using a proxy/VPN

    target_title = "Reinforcement Learning for Robot Control"
    generate_draft(target_title, knowledge_database="ml_textbook_test", max_kw_refs=20)