@@ -8,32 +8,38 @@ use once_cell::sync::{Lazy, OnceCell};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::error::Error;
-use std::io::Write;
 use std::sync::mpsc::{sync_channel, SyncSender};
 use std::sync::Arc;
 use std::thread;
+use std::{convert::Infallible, io::Write, path::PathBuf};
 use tokio::sync::mpsc::{channel, Receiver};
 use tokio::sync::Mutex;
 
 #[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
-pub enum Args {
-    /// Use a LLaMA model
-    #[command()]
-    Llama(Box<BaseArgs>),
-}
-
-#[derive(Parser, Debug, Clone)]
-pub struct BaseArgs {
-    #[arg(short, long)]
-    model: String,
-
-    #[arg(short, long)]
+struct Args {
+    model_architecture: llm::ModelArchitecture,
+    model_path: PathBuf,
+    #[arg(long, short = 'v')]
+    vocabulary_path: Option<PathBuf>,
+    #[arg(long, short = 'r')]
+    vocabulary_repository: Option<String>,
+    #[arg(long, short = 'h')]
     host: String,
-
-    #[arg(short, long)]
+    #[arg(long, short = 'p')]
     port: u16,
 }
+impl Args {
+    pub fn to_vocabulary_source(&self) -> llm::VocabularySource {
+        match (&self.vocabulary_path, &self.vocabulary_repository) {
+            (Some(_), Some(_)) => {
+                panic!("Cannot specify both --vocabulary-path and --vocabulary-repository");
+            }
+            (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
+            (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()),
+            (None, None) => llm::VocabularySource::Model,
+        }
+    }
+}
 
 const END: &str = "<<END>>";
 static TX_INFER: OnceCell<Arc<Mutex<SyncSender<String>>>> = OnceCell::new();
@@ -72,33 +78,54 @@ async fn chat(chat_request: web::Query<ChatRequest>) -> Result<impl Responder, B
         .streaming(Box::pin(stream_tasks)))
 }
 
-fn infer<M: llm::KnownModel + 'static>(
-    args: &BaseArgs,
+fn infer(
+    args: &Args,
     rx_infer: std::sync::mpsc::Receiver<String>,
     tx_callback: tokio::sync::mpsc::Sender<String>,
 ) -> Result<()> {
-    let llm_model = llm::load::<llm::models::Llama>(
-        std::path::Path::new(&args.model),
+
+    let vocabulary_source = args.to_vocabulary_source();
+    let model_architecture = args.model_architecture;
+    let model_path = &args.model_path;
+    let now = std::time::Instant::now();
+
+    let llm_model = llm::load_dynamic(
+        model_architecture,
+        &model_path,
+        vocabulary_source,
         Default::default(),
         llm::load_progress_callback_stdout,
     )
-    .unwrap_or_else(|err| panic!("Failed to load model: {err}"));
+    .unwrap_or_else(|err| {
+        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
+    });
+
+    println!(
+        "Model fully loaded! Elapsed: {}ms",
+        now.elapsed().as_millis()
+    );
 
     while let Ok(msg) = rx_infer.recv() {
         let prompt = msg.to_string();
         let mut session = llm_model.start_session(Default::default());
+
         let res = session.infer::<std::convert::Infallible>(
-            &llm_model,
+            llm_model.as_ref(),
             &mut rand::thread_rng(),
             &llm::InferenceRequest {
-                prompt: &prompt,
+                prompt: Some(prompt).as_deref().unwrap().into(),
+                parameters: &llm::InferenceParameters::default(),
                 play_back_previous_tokens: false,
-                ..Default::default()
+                maximum_token_count: None,
             },
             &mut Default::default(),
-            |t| {
-                tx_callback.blocking_send(t.to_string());
-                Ok(())
+            |r| match r {
+                llm::InferenceResponse::PromptToken(t)
+                | llm::InferenceResponse::InferredToken(t) => {
+                    tx_callback.blocking_send(t.to_string());
+                    Ok(llm::InferenceFeedback::Continue)
+                }
+                _ => Ok(llm::InferenceFeedback::Continue),
            },
         );
 
@@ -111,8 +138,8 @@ fn infer<M: llm::KnownModel + 'static>(
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
-    let cli_args = Args::parse();
-    println!("{cli_args:#?}");
+    let args = Args::parse();
+    println!("{args:#?}");
 
     let (tx_infer, rx_infer) = sync_channel::<String>(3);
     let (tx_callback, rx_callback) = channel::<String>(3);
@@ -121,23 +148,21 @@ async fn main() -> std::io::Result<()> {
     RX_CALLBACK.set(Arc::new(Mutex::new(rx_callback))).unwrap();
 
     //"/home/jovyan/rust-src/llm-ui/models/ggml-model-q4_0.binA
-    let c_args = cli_args.clone();
+    let c_args = args.clone();
     thread::spawn(move || {
-        match &cli_args {
-            Args::Llama(args) => {
-                infer::<llm::models::Llama>(&args, rx_infer, tx_callback);
-            }
-        };
+        infer(&args, rx_infer, tx_callback);
     });
 
-    let (host, port) = match &c_args {
-        Args::Llama(args) => (args.host.to_string(), args.port),
-    };
+    let host = c_args.host.to_string();
+    let port: u16 = c_args.port;
 
     HttpServer::new(|| {
         App::new()
             .route("/api/chat", web::get().to(chat))
-            .service(fs::Files::new("/", std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("static")))
+            .service(fs::Files::new(
+                "/",
+                std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("static"),
+            ))
     })
     .bind((host, port))?
     .run()
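
Illustrative sketch (not part of the diff above): how the new Args::to_vocabulary_source resolves the two optional vocabulary flags, assuming the llm crate's ModelArchitecture::Llama variant and a placeholder tokenizer repository name.

// Editor's sketch only, not included in the commit. Drops into the same file
// as the Args struct above; the model path and repository name are hypothetical.
#[cfg(test)]
mod vocabulary_source_sketch {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn repository_flag_maps_to_hugging_face_remote() {
        let args = Args {
            model_architecture: llm::ModelArchitecture::Llama,
            model_path: PathBuf::from("models/ggml-model-q4_0.bin"),
            vocabulary_path: None,
            vocabulary_repository: Some("some-org/some-tokenizer".to_string()),
            host: "127.0.0.1".to_string(),
            port: 8080,
        };

        // Only --vocabulary-repository is set, so the vocabulary comes from the
        // remote Hugging Face repo; with neither flag set, to_vocabulary_source
        // falls back to llm::VocabularySource::Model.
        assert!(matches!(
            args.to_vocabulary_source(),
            llm::VocabularySource::HuggingFaceRemote(_)
        ));
    }
}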