Skip to content

Commit f9ad268

Browse files
authored
Fix reading modules.json for Dense modules in local models (#738)
1 parent 9ef569d commit f9ad268

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

backends/candle/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,9 +511,8 @@ impl CandleBackend {
511511
}
512512
};
513513

514-
// Load Dense layers from the provided Dense paths
515514
let mut dense_layers = Vec::new();
516-
if let Some(dense_paths) = &dense_paths {
515+
if let Some(dense_paths) = dense_paths {
517516
if !dense_paths.is_empty() {
518517
tracing::info!("Loading Dense module/s from path/s: {dense_paths:?}");
519518

backends/src/lib.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,40 @@ async fn init_backend(
433433
tracing::info!("Dense modules downloaded in {:?}", start.elapsed());
434434
Some(dense_paths)
435435
} else {
436-
None
436+
// TODO(alvarobartt): eventually detach the Sentence Transformers module handling
437+
// to prevent from duplicated code here and there
438+
// For local models, try to parse modules.json and handle dense_path logic
439+
let modules_json_path = model_path.join("modules.json");
440+
if modules_json_path.exists() {
441+
match parse_dense_paths_from_modules(&modules_json_path).await {
442+
Ok(module_paths) => match module_paths.len() {
443+
0 => Some(vec![]),
444+
1 => {
445+
let path_to_use = if let Some(ref user_path) = dense_path {
446+
if user_path != &module_paths[0] {
447+
tracing::info!("`{}` found in `modules.json`, but using provided `--dense-path={user_path}` instead", module_paths[0]);
448+
}
449+
user_path.clone()
450+
} else {
451+
module_paths[0].clone()
452+
};
453+
Some(vec![path_to_use])
454+
}
455+
_ => {
456+
if dense_path.is_some() {
457+
tracing::warn!("A value for `--dense-path` was provided, but since there's more than one subsequent Dense module, then the provided value will be ignored.");
458+
}
459+
Some(module_paths)
460+
}
461+
},
462+
Err(err) => {
463+
tracing::warn!("Failed to parse local modules.json: {err}");
464+
None
465+
}
466+
}
467+
} else {
468+
None
469+
}
437470
};
438471

439472
let backend = CandleBackend::new(

0 commit comments

Comments
 (0)