Skip to content

Commit 1cdec83

Browse files
authored
[Developer QoL]: Use nicer Candle Error APIs (#767)
1 parent f85ddcd commit 1cdec83

File tree

15 files changed

+38
-43
lines changed

15 files changed

+38
-43
lines changed

Cargo.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ license = "MIT"
2525

2626
[workspace.dependencies]
2727
anyhow = "1.0.80"
28-
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
29-
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e" }
28+
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
29+
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486" }
3030
serde = "1.0.197"
3131
serde_json = "1.0.114"
3232
indexmap = { version = "2.2.5", features = ["serde"] }

mistralrs-core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ candle-core.workspace = true
1717
candle-nn.workspace = true
1818
serde.workspace = true
1919
serde_json.workspace = true
20-
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "91e0c6e", optional = true }
20+
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.6.0", rev = "ad84486", optional = true }
2121
dirs = "5.0.1"
2222
hf-hub = "0.3.2"
2323
thiserror = "1.0.57"

mistralrs-core/src/device_map.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ impl DeviceMapper for LayerDeviceMapper {
201201
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
202202
dtype
203203
.try_into_dtype(&self.mappings.iter().collect::<Vec<_>>())
204-
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
204+
.map_err(candle_core::Error::msg)
205205
}
206206
}
207207

@@ -249,6 +249,6 @@ impl DeviceMapper for DummyDeviceMapper {
249249
fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
250250
dtype
251251
.try_into_dtype(&[&self.nm_device])
252-
.map_err(|e| candle_core::Error::Msg(format!("{e:?}")))
252+
.map_err(candle_core::Error::msg)
253253
}
254254
}

mistralrs-core/src/engine/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ impl Engine {
538538
let prompt = get_mut_arcmutex!(self.pipeline)
539539
.tokenizer()
540540
.encode(text, true)
541-
.map_err(|e| anyhow::Error::msg(e.to_string()));
541+
.map_err(anyhow::Error::msg);
542542
handle_seq_error!(prompt, request.response)
543543
.get_ids()
544544
.to_vec()

mistralrs-core/src/pipeline/amoe.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
367367
0.0,
368368
vec![],
369369
)
370-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
370+
.map_err(candle_core::Error::msg)?;
371371

372372
let dummy_group = Arc::new(tokio::sync::Mutex::new(SequenceGroup::new(
373373
1, false, false, 0,
@@ -402,7 +402,7 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
402402
true,
403403
Vec::new(),
404404
)
405-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
405+
.map_err(candle_core::Error::msg)?;
406406
let images = image_urls.as_ref().map(|urls| {
407407
urls.iter()
408408
.map(|url| -> anyhow::Result<DynamicImage> {
@@ -511,26 +511,23 @@ impl AnyMoePipelineMixin for AnyMoePipeline {
511511
candle_core::bail!("`loss_csv_path` must have an extension `csv`.");
512512
}
513513

514-
let mut writer =
515-
csv::Writer::from_path(path).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
514+
let mut writer = csv::Writer::from_path(path).map_err(candle_core::Error::msg)?;
516515

517516
let mut header = vec![format!("Step")];
518517
header.extend((0..all_losses[0].len()).map(|i| format!("Gating layer {i}")));
519518
writer
520519
.write_record(&header)
521-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
520+
.map_err(candle_core::Error::msg)?;
522521

523522
for (i, row) in all_losses.into_iter().enumerate() {
524523
let mut new_row = vec![format!("Step {i}")];
525524
new_row.extend(row.iter().map(|x| format!("{x:.4}")));
526525
writer
527526
.write_record(&new_row)
528-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
527+
.map_err(candle_core::Error::msg)?;
529528
}
530529

531-
writer
532-
.flush()
533-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
530+
writer.flush().map_err(candle_core::Error::msg)?;
534531
}
535532

536533
Ok(Some(AnyMoeTrainingResult {

mistralrs-core/src/pipeline/isq.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ pub trait IsqModel {
225225
let pool = rayon::ThreadPoolBuilder::new()
226226
.num_threads(minimum_max_threads)
227227
.build()
228-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
228+
.map_err(candle_core::Error::msg)?;
229229

230230
pool.install(|| {
231231
use indicatif::ParallelProgressIterator;

mistralrs-core/src/pipeline/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ pub trait Pipeline:
270270
let InputProcessorOutput {
271271
inputs,
272272
seq_indices,
273-
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
273+
} = inputs.map_err(candle_core::Error::msg)?;
274274
if i == 0 {
275275
match pre_op {
276276
CacheInstruction::In(ref adapter_inst) => {
@@ -404,7 +404,7 @@ pub trait Pipeline:
404404
let InputProcessorOutput {
405405
inputs,
406406
seq_indices,
407-
} = inputs.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
407+
} = inputs.map_err(candle_core::Error::msg)?;
408408

409409
let raw_logits = self.forward_inputs(inputs)?;
410410

mistralrs-core/src/pipeline/normal.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -602,16 +602,16 @@ impl AnyMoePipelineMixin for NormalPipeline {
602602
) -> candle_core::Result<()> {
603603
let mut vbs = Vec::new();
604604
// Precompile regex here
605-
let regex = Regex::new(match_regex).map_err(|e| candle_core::Error::Msg(e.to_string()))?;
605+
let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
606606
for model_id in model_ids {
607607
let model_id_str = &model_id;
608608
let model_id = Path::new(&model_id);
609609

610610
let api = ApiBuilder::new()
611611
.with_progress(!silent)
612-
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
612+
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
613613
.build()
614-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
614+
.map_err(candle_core::Error::msg)?;
615615
let revision = revision.clone().unwrap_or("main".to_string());
616616
let api = api.repo(Repo::with_revision(
617617
model_id_str.clone(),
@@ -651,9 +651,9 @@ impl AnyMoePipelineMixin for NormalPipeline {
651651

652652
let api = ApiBuilder::new()
653653
.with_progress(!silent)
654-
.with_token(get_token(token).map_err(|e| candle_core::Error::Msg(e.to_string()))?)
654+
.with_token(get_token(token).map_err(candle_core::Error::msg)?)
655655
.build()
656-
.map_err(|e| candle_core::Error::Msg(e.to_string()))?;
656+
.map_err(candle_core::Error::msg)?;
657657
let revision = revision.clone().unwrap_or("main".to_string());
658658
let api = api.repo(Repo::with_revision(
659659
model_id_str.clone(),

mistralrs-core/src/pipeline/processing.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pub trait Processor {
4747
let encoding = pipeline
4848
.tokenizer()
4949
.encode(prompt, true)
50-
.map_err(|e| anyhow::Error::msg(e.to_string()))?;
50+
.map_err(anyhow::Error::msg)?;
5151
Ok(encoding.get_ids().to_vec())
5252
}
5353
fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;

0 commit comments

Comments
 (0)