
Commit c2c70a7

working with latest
1 parent 93bcbaf commit c2c70a7

2 files changed: +66, -41 lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions

@@ -6,8 +6,8 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-#llm = { git = "https://github.com/rustformers/llm.git" }
-llm = "0.1.1"
+llm = { git = "https://github.com/rustformers/llm.git" }
+#llm = "0.1.1"
 rand = "0.8.5"
 actix-files = "0.6.2"
 actix-web = "4"
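
Switching `llm` from the crates.io 0.1.1 release to the git repository tracks the latest upstream API (the `load_dynamic` and `VocabularySource` calls used in `src/main.rs` below), but it also means `cargo update` can pull in breaking changes. A sketch of pinning the dependency to a fixed revision instead, with a placeholder standing in for a real commit hash:

[dependencies]
# Hypothetical pin; substitute the actual rustformers/llm revision this code was built against.
llm = { git = "https://github.com/rustformers/llm.git", rev = "<commit-sha>" }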

src/main.rs

Lines changed: 64 additions & 39 deletions
@@ -8,32 +8,38 @@ use once_cell::sync::{Lazy, OnceCell};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::error::Error;
-use std::io::Write;
 use std::sync::mpsc::{sync_channel, SyncSender};
 use std::sync::Arc;
 use std::thread;
+use std::{convert::Infallible, io::Write, path::PathBuf};
 use tokio::sync::mpsc::{channel, Receiver};
 use tokio::sync::Mutex;
 
 #[derive(Parser, Debug, Clone)]
-#[command(author, version, about, long_about = None)]
-pub enum Args {
-    /// Use a LLaMA model
-    #[command()]
-    Llama(Box<BaseArgs>),
-}
-
-#[derive(Parser, Debug, Clone)]
-pub struct BaseArgs {
-    #[arg(short, long)]
-    model: String,
-
-    #[arg(short, long)]
+struct Args {
+    model_architecture: llm::ModelArchitecture,
+    model_path: PathBuf,
+    #[arg(long, short = 'v')]
+    vocabulary_path: Option<PathBuf>,
+    #[arg(long, short = 'r')]
+    vocabulary_repository: Option<String>,
+    #[arg(long, short = 'h')]
     host: String,
-
-    #[arg(short, long)]
+    #[arg(long, short = 'p')]
     port: u16,
 }
+impl Args {
+    pub fn to_vocabulary_source(&self) -> llm::VocabularySource {
+        match (&self.vocabulary_path, &self.vocabulary_repository) {
+            (Some(_), Some(_)) => {
+                panic!("Cannot specify both --vocabulary-path and --vocabulary-repository");
+            }
+            (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()),
+            (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()),
+            (None, None) => llm::VocabularySource::Model,
+        }
+    }
+}
 
 const END: &str = "<<END>>";
 static TX_INFER: OnceCell<Arc<Mutex<SyncSender<String>>>> = OnceCell::new();
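
The CLI surface changes here: the `llama` subcommand with a `--model` flag is replaced by a flat `Args` struct whose architecture and model path are positional, plus optional tokenizer sources. One hedged caveat: clap reserves `-h` for `--help` by default, so `short = 'h'` on `host` may clash when the parser is built. Below is a minimal sketch of how `to_vocabulary_source` resolves the options, assuming the rustformers `llm` crate exposes a `ModelArchitecture::Llama` variant; the repository name is a placeholder.

use std::path::PathBuf;

// Hypothetical Args value mirroring the struct in this commit (normally produced by Args::parse()).
let args = Args {
    model_architecture: llm::ModelArchitecture::Llama, // assumed variant name
    model_path: PathBuf::from("models/ggml-model-q4_0.bin"),
    vocabulary_path: None,
    vocabulary_repository: Some("example-org/example-tokenizer".to_string()), // placeholder repo
    host: "127.0.0.1".to_string(),
    port: 8080,
};

// Only --vocabulary-repository is set, so the HuggingFace remote source is chosen;
// with neither option the vocabulary embedded in the model file is used instead.
assert!(matches!(
    args.to_vocabulary_source(),
    llm::VocabularySource::HuggingFaceRemote(_)
));
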
@@ -72,33 +78,54 @@ async fn chat(chat_request: web::Query<ChatRequest>) -> Result<impl Responder, B
         .streaming(Box::pin(stream_tasks)))
 }
 
-fn infer<M: llm::KnownModel + 'static>(
-    args: &BaseArgs,
+fn infer(
+    args: &Args,
     rx_infer: std::sync::mpsc::Receiver<String>,
     tx_callback: tokio::sync::mpsc::Sender<String>,
 ) -> Result<()> {
-    let llm_model = llm::load::<llm::models::Llama>(
-        std::path::Path::new(&args.model),
+
+    let vocabulary_source = args.to_vocabulary_source();
+    let model_architecture = args.model_architecture;
+    let model_path = &args.model_path;
+    let now = std::time::Instant::now();
+
+    let llm_model = llm::load_dynamic(
+        model_architecture,
+        &model_path,
+        vocabulary_source,
         Default::default(),
         llm::load_progress_callback_stdout,
     )
-    .unwrap_or_else(|err| panic!("Failed to load model: {err}"));
+    .unwrap_or_else(|err| {
+        panic!("Failed to load {model_architecture} model from {model_path:?}: {err}")
+    });
+
+    println!(
+        "Model fully loaded! Elapsed: {}ms",
+        now.elapsed().as_millis()
+    );
 
     while let Ok(msg) = rx_infer.recv() {
         let prompt = msg.to_string();
         let mut session = llm_model.start_session(Default::default());
+
         let res = session.infer::<std::convert::Infallible>(
-            &llm_model,
+            llm_model.as_ref(),
             &mut rand::thread_rng(),
             &llm::InferenceRequest {
-                prompt: &prompt,
+                prompt: Some(prompt).as_deref().unwrap().into(),
+                parameters: &llm::InferenceParameters::default(),
                 play_back_previous_tokens: false,
-                ..Default::default()
+                maximum_token_count: None,
             },
             &mut Default::default(),
-            |t| {
-                tx_callback.blocking_send(t.to_string());
-                Ok(())
+            |r| match r {
+                llm::InferenceResponse::PromptToken(t)
+                | llm::InferenceResponse::InferredToken(t) => {
+                    tx_callback.blocking_send(t.to_string());
+                    Ok(llm::InferenceFeedback::Continue)
+                }
+                _ => Ok(llm::InferenceFeedback::Continue),
            },
         );
 
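The inference callback contract also changes with the newer API: instead of a plain token string, the closure now receives an `llm::InferenceResponse` and must return an `llm::InferenceFeedback` deciding whether generation continues (the commit always returns `Continue`, and the `Result` of `blocking_send` is dropped). Below is a sketch of an alternative callback that stops early, assuming the crate also provides an `InferenceFeedback::Halt` variant; the token budget and counter are hypothetical and not part of this commit.

// Hypothetical streaming cap; not part of this commit.
let mut tokens_emitted = 0usize;
let max_stream_tokens = 256usize;

let callback = |r: llm::InferenceResponse| match r {
    llm::InferenceResponse::PromptToken(t)
    | llm::InferenceResponse::InferredToken(t) => {
        // Forward the token to the HTTP streaming side, as the commit does,
        // but stop if the receiver has gone away instead of ignoring the send result.
        if tx_callback.blocking_send(t).is_err() {
            return Ok(llm::InferenceFeedback::Halt); // assumed variant: ends generation
        }
        tokens_emitted += 1;
        if tokens_emitted >= max_stream_tokens {
            Ok(llm::InferenceFeedback::Halt)
        } else {
            Ok(llm::InferenceFeedback::Continue)
        }
    }
    _ => Ok::<_, std::convert::Infallible>(llm::InferenceFeedback::Continue),
};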

@@ -111,8 +138,8 @@ fn infer<M: llm::KnownModel + 'static>(
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
-    let cli_args = Args::parse();
-    println!("{cli_args:#?}");
+    let args = Args::parse();
+    println!("{args:#?}");
 
     let (tx_infer, rx_infer) = sync_channel::<String>(3);
     let (tx_callback, rx_callback) = channel::<String>(3);
@@ -121,23 +148,21 @@ async fn main() -> std::io::Result<()> {
     RX_CALLBACK.set(Arc::new(Mutex::new(rx_callback))).unwrap();
 
     //"/home/jovyan/rust-src/llm-ui/models/ggml-model-q4_0.binA
-    let c_args = cli_args.clone();
+    let c_args = args.clone();
     thread::spawn(move || {
-        match &cli_args {
-            Args::Llama(args) => {
-                infer::<llm::models::Llama>(&args, rx_infer, tx_callback);
-            }
-        };
+        infer(&args, rx_infer, tx_callback);
     });
 
-    let (host, port) = match &c_args {
-        Args::Llama(args) => (args.host.to_string(), args.port),
-    };
+    let host = c_args.host.to_string();
+    let port: u16 = c_args.port;
 
     HttpServer::new(|| {
         App::new()
             .route("/api/chat", web::get().to(chat))
-            .service(fs::Files::new("/", std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("static")))
+            .service(fs::Files::new(
+                "/",
+                std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("static"),
+            ))
     })
     .bind((host, port))?
     .run()
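
One small observation on the spawned worker: `infer` returns a `Result<()>`, but the value is discarded inside `thread::spawn`, so a load or inference failure only surfaces through the panics inside `infer` itself. A minimal sketch of reporting the error instead, using the same signature as above:

thread::spawn(move || {
    // Same call as in the commit, but the error is logged rather than silently dropped.
    if let Err(err) = infer(&args, rx_infer, tx_callback) {
        eprintln!("inference thread exited with error: {err}");
    }
});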
