-
Notifications
You must be signed in to change notification settings - Fork 0
/
ui.py
118 lines (95 loc) · 5.22 KB
/
ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
from smain import main
from summary.util.check_db import excute_sqlite_sql
from summary.util.mp3_from_mp4 import get_media_files
from summary.util.text_from_mp3 import get_whisper_model, get_whisper_text
from summary.config import (model_size_or_path,
file_default_path, company_name, create_table_sql)
from fastapi import FastAPI
# FastAPI application. The Gradio UI built by create_chain_app() is mounted
# onto it at the bottom of this module (gr.mount_gradio_app).
app = FastAPI()
# Maps the summary-type labels shown in the UI dropdowns to the internal
# mode identifiers that doIt() forwards to main().
SummaryType = dict(
    [
        ("总体摘要", "SumMp4All"),   # "overall summary"
        ("章节摘要", "SumMp4Step"),  # "per-chapter summary"
        ("总体纪要", "SumTextAll"),  # "overall minutes"
    ]
)

# Maps the "re-run?" dropdown labels ("yes" / "no") to booleans.
reRunType = dict([("是", True), ("否", False)])
def doIt(summary_type, file_Path, file_get_type='upload', file_Info=None, re_run=False):
    """Run the transcription + summary pipeline (``main``) for one media file.

    Wired to the "Generate" buttons of both UI tabs.

    Args:
        summary_type: A key of ``SummaryType`` (a UI label such as
            '总体纪要'); translated to the internal mode id.
        file_Path: The media file. With ``file_get_type == 'choose'`` it is a
            file name relative to ``file_default_path``; otherwise it is
            passed to ``main`` unchanged (whatever the upload component
            supplies).
        file_get_type: ``'choose'`` to resolve ``file_Path`` under
            ``file_default_path``; any other value uses it as-is.
        file_Info: Free-text media keywords, forwarded to ``main``.
        re_run: Either a ``reRunType`` label ('是' / '否') as sent by the UI,
            or a plain bool for direct callers.

    Returns:
        Whatever ``main`` returns (the UI binds it to two text boxes).

    Raises:
        gr.Error: In 'choose' mode when no file was selected.
        KeyError: If ``summary_type`` or a string ``re_run`` is not a known label.
    """
    if file_get_type == 'choose':
        # An empty dropdown would otherwise fail with an opaque
        # TypeError from `str + None`; surface a readable message in the UI.
        if not file_Path:
            raise gr.Error('请先选择媒体文件')
        file_Path = file_default_path + "/" + file_Path
    summaryType = SummaryType[summary_type]
    # The UI sends '是'/'否', but the signature's default is the bool False,
    # which is not a key of reRunType; accept bools directly.
    reRun = re_run if isinstance(re_run, bool) else reRunType[re_run]
    print(summaryType, file_Info, reRun)
    # NOTE(review): the model is requested on every call; if
    # get_whisper_model does not cache internally, consider caching it here.
    whisperModel = get_whisper_model(model_size_or_path)
    return main(summaryType, file_Path, file_Info, whisperModel, reRun)
from typing import Literal
@app.post("/media")
def getTextApi(audio_path: str,
               initial_prompt: str,
               mode: Literal['timeline', 'normal', 'subtitle']
               ):
    """Transcribe a media file on the server via the project's Whisper helpers.

    Exposed as ``POST /media``. FastAPI treats these plain scalar parameters
    as query parameters.

    Args:
        audio_path: Path of the media file on the server's filesystem.
        initial_prompt: Prompt text forwarded to ``get_whisper_text``.
        mode: Output layout requested from ``get_whisper_text``.

    Returns:
        The value produced by ``get_whisper_text``, unchanged.
    """
    # NOTE(review): audio_path comes straight from the HTTP caller and is not
    # confined to file_default_path, so any readable server file can be
    # targeted. Consider validating it if this endpoint is exposed publicly.
    model = get_whisper_model(model_size_or_path)
    result = get_whisper_text(
        model,
        audio_path,
        initial_prompt=initial_prompt,
        mode=mode,
    )
    return result
def _add_summary_controls(media_input, file_get_type):
    """Add the keyword/type/re-run/Generate row and output boxes, wired to doIt.

    Call inside an open ``gr.Tab`` so the components land in that tab.

    Args:
        media_input: Component whose value is passed to ``doIt`` as ``file_Path``.
        file_get_type: 'upload' or 'choose'; forwarded to ``doIt`` through a
            hidden textbox.
    """
    with gr.Row():
        input_textbox = gr.Textbox(label='媒体关键字', scale=10)
        # Choices come from SummaryType so they always match doIt's lookup table.
        input_type = gr.Dropdown(label='生成摘要类型', choices=list(SummaryType),
                                 value='总体纪要',
                                 scale=2)
        rerun_type = gr.Dropdown(label='是否重跑', choices=['否', '是'], value='否', scale=1)
        submit_button = gr.Button(value='Generate', variant='primary', scale=4)
    with gr.Row():
        output_textbox1 = gr.Textbox(lines=10, label='转录文本', scale=4)
        output_textbox2 = gr.Textbox(lines=10, label='摘要文本', scale=6)
    submit_button.click(fn=doIt,
                        inputs=[input_type, media_input, gr.Textbox(file_get_type, visible=False),
                                input_textbox,
                                rerun_type],
                        outputs=[output_textbox1, output_textbox2])


def _preview_paths(file_name):
    """Return the (video, audio) preview paths for a file in file_default_path."""
    if not file_name:
        # Nothing selected: clear both previews instead of pointing at ".../None".
        return None, None
    path = f'{file_default_path}/{file_name}'
    return path, path


@app.get("/")
def _redirect_root_to_ui():
    """Send visitors of ``/`` to the Gradio UI mounted at ``/<company_name>``.

    ``create_chain_app`` returns a ``gr.Blocks`` object, which is not a valid
    HTTP response, so it is not itself registered as a route.
    """
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/" + company_name)


def create_chain_app():
    """Build the Gradio UI with two tabs that both feed ``doIt``.

    * 上传转录 ("upload & transcribe"): upload a media file, preview it, and
      run the pipeline on the uploaded file.
    * 选择转录 ("choose & transcribe"): pick a file from ``file_default_path``,
      preview it, and run the pipeline on it.

    Returns:
        The ``gr.Blocks`` app (not yet launched or mounted).
    """
    with gr.Blocks(title=company_name) as demo:
        with gr.Tab(label='上传转录'):
            with gr.Row():
                media_upload_block = gr.File(file_count='single', file_types=['audio', 'video'],
                                             label='上传媒体文件', scale=6)
                # NOTE(review): this sets an attribute on the component instance;
                # confirm the installed gradio version actually reads it (upstream
                # gradio configures its cache dir via the GRADIO_TEMP_DIR env var).
                media_upload_block.GRADIO_CACHE = file_default_path
                video_component = gr.Video(label='视频预览', scale=3)
                audio_component = gr.Audio(label='音频预览', scale=3)
            _add_summary_controls(media_upload_block, 'upload')
            # Feed the uploaded file to both previews. Two separate listeners are
            # used so that one preview failing does not block the other.
            media_upload_block.upload(fn=lambda x: x, inputs=[media_upload_block], outputs=[video_component])
            media_upload_block.upload(fn=lambda x: x, inputs=[media_upload_block], outputs=[audio_component])
        with gr.Tab(label='选择转录'):
            with gr.Row():
                # NOTE(review): the file list is read once, when the UI is built;
                # files added later do not appear until the app is rebuilt.
                media_files = get_media_files(file_default_path)
                media_select_block = gr.Dropdown(label='选择媒体文件',
                                                 choices=media_files, scale=6)
                video_preview = gr.Video(label='视频预览', scale=3)
                audio_preview = gr.Audio(label='音频预览', scale=3)
                # Preview button loads the selected file into both players.
                preview_button = gr.Button(value='预览', variant='secondary', scale=1)
                preview_button.click(fn=_preview_paths,
                                     inputs=[media_select_block],
                                     outputs=[video_preview, audio_preview])
            _add_summary_controls(media_select_block, 'choose')
    return demo
# Build the Gradio UI once, at import time.
# NOTE(review): this module-level name shadows the stdlib `io` module name.
io = create_chain_app()
# The UI is served under /<company_name> on the FastAPI app.
CUSTOM_PATH = "/" + company_name
# Run the configured table-creation SQL at import time, before serving
# (presumably idempotent, e.g. CREATE TABLE IF NOT EXISTS — confirm in summary.config).
excute_sqlite_sql(create_table_sql)
# Mount the Gradio UI onto the FastAPI app and rebind `app` to the result.
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)