Skip to content

Commit ec89a66

Browse files
authored
[rust] Avoid panic in error case (#3133)
1 parent 6efe660 commit ec89a66

File tree

3 files changed

+36
-17
lines changed

3 files changed

+36
-17
lines changed

extensions/tokenizers/rust/src/models/mod.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ mod distilbert;
44
use crate::ndarray::as_data_type;
55
use crate::{cast_handle, to_handle, to_string_array};
66
use bert::{BertConfig, BertModel};
7-
use candle_core::DType;
7+
use candle_core::{DType, Error};
88
use candle_core::{Device, Result, Tensor};
99
use candle_nn::VarBuilder;
1010
use distilbert::{DistilBertConfig, DistilBertModel};
@@ -43,7 +43,10 @@ fn load_model<'local>(
4343

4444
// Load config
4545
let config: String = std::fs::read_to_string(model_path.join("config.json"))?;
46-
let config: Config = serde_json::from_str(&config).unwrap();
46+
let config: Config = match serde_json::from_str(&config) {
47+
Ok(conf) => conf,
48+
Err(err) => return Err(Error::wrap(err)),
49+
};
4750

4851
// Get candle device
4952
let device = if candle_core::utils::cuda_is_available() {
@@ -55,7 +58,7 @@ fn load_model<'local>(
5558
}?;
5659

5760
// Get candle dtype
58-
let dtype = as_data_type(dtype).unwrap();
61+
let dtype = as_data_type(dtype)?;
5962

6063
let safetensors_path = model_path.join("model.safetensors");
6164
let vb = if safetensors_path.exists() {

extensions/tokenizers/rust/src/ndarray/mod.rs

+22-14
Original file line numberDiff line numberDiff line change
@@ -295,22 +295,30 @@ fn as_device<'local>(env: &mut JNIEnv<'local>, device_type: JString, _: usize) -
295295
match device_type.as_str() {
296296
"cpu" => Ok(Device::Cpu),
297297
"gpu" => {
298-
let mut device = CUDA_DEVICE.lock().unwrap();
299-
if let Some(device) = device.as_ref() {
300-
return Ok(device.clone());
301-
};
302-
let d = Device::new_cuda(0).unwrap();
303-
*device = Some(d.clone());
304-
Ok(d)
298+
if candle_core::utils::cuda_is_available() {
299+
let mut device = CUDA_DEVICE.lock().unwrap();
300+
if let Some(device) = device.as_ref() {
301+
return Ok(device.clone());
302+
};
303+
let d = Device::new_cuda(0).unwrap();
304+
*device = Some(d.clone());
305+
Ok(d)
306+
} else {
307+
Err(Error::Msg(String::from("CUDA is not available.")))
308+
}
305309
}
306310
"mps" => {
307-
let mut device = METAL_DEVICE.lock().unwrap();
308-
if let Some(device) = device.as_ref() {
309-
return Ok(device.clone());
310-
};
311-
let d = Device::new_metal(0).unwrap();
312-
*device = Some(d.clone());
313-
Ok(d)
311+
if candle_core::utils::metal_is_available() {
312+
let mut device = METAL_DEVICE.lock().unwrap();
313+
if let Some(device) = device.as_ref() {
314+
return Ok(device.clone());
315+
};
316+
let d = Device::new_metal(0).unwrap();
317+
*device = Some(d.clone());
318+
Ok(d)
319+
} else {
320+
Err(Error::Msg(String::from("metal is not available.")))
321+
}
314322
}
315323
_ => Err(Error::Msg(format!("Invalid device type: {}", device_type))),
316324
}

extensions/tokenizers/src/main/java/ai/djl/engine/rust/RsModel.java

+8
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ public void load(Path modelPath, String prefix, Map<String, ?> options)
4949
"Model directory doesn't exist: " + modelPath.toAbsolutePath());
5050
}
5151
modelDir = modelPath.toAbsolutePath();
52+
Path config = modelDir.resolve("config.json");
53+
if (!Files.isRegularFile(config)) {
54+
throw new FileNotFoundException("config.json file not found");
55+
}
56+
Path file = modelDir.resolve("model.safetensors");
57+
if (!Files.isRegularFile(file)) {
58+
throw new FileNotFoundException("model.safetensors file not found");
59+
}
5260
long handle = RustLibrary.loadModel(modelDir.toString(), dataType.ordinal());
5361
block = new RsSymbolBlock((RsNDManager) manager, handle);
5462
}

0 commit comments

Comments
 (0)