Skip to content

Commit edb64d6

Browse files
authored
Reduce the scope of the unsafe block across the code base (#319)
1 parent c70413b commit edb64d6

File tree

19 files changed

+1649
-1831
lines changed

19 files changed

+1649
-1831
lines changed

src/algorithm/mod.rs

Lines changed: 135 additions & 137 deletions
Large diffs are not rendered by default.

src/blas/mod.rs

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -130,20 +130,20 @@ pub fn gemm<T>(
130130
) where
131131
T: HasAfEnum + FloatingPoint,
132132
{
133-
unsafe {
134-
let mut out = output.get();
135-
let err_val = af_gemm(
133+
let mut out = unsafe { output.get() };
134+
let err_val = unsafe {
135+
af_gemm(
136136
&mut out as *mut af_array,
137137
optlhs as c_uint,
138138
optrhs as c_uint,
139139
alpha.as_ptr() as *const c_void,
140140
lhs.get(),
141141
rhs.get(),
142142
beta.as_ptr() as *const c_void,
143-
);
144-
HANDLE_ERROR(AfError::from(err_val));
145-
output.set(out);
146-
}
143+
)
144+
};
145+
HANDLE_ERROR(AfError::from(err_val));
146+
output.set(out);
147147
}
148148

149149
/// Matrix multiple of two Arrays
@@ -162,18 +162,18 @@ pub fn matmul<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatPro
162162
where
163163
T: HasAfEnum + FloatingPoint,
164164
{
165-
unsafe {
166-
let mut temp: af_array = std::ptr::null_mut();
167-
let err_val = af_matmul(
165+
let mut temp: af_array = std::ptr::null_mut();
166+
let err_val = unsafe {
167+
af_matmul(
168168
&mut temp as *mut af_array,
169169
lhs.get(),
170170
rhs.get(),
171171
optlhs as c_uint,
172172
optrhs as c_uint,
173-
);
174-
HANDLE_ERROR(AfError::from(err_val));
175-
temp.into()
176-
}
173+
)
174+
};
175+
HANDLE_ERROR(AfError::from(err_val));
176+
temp.into()
177177
}
178178

179179
/// Calculate the dot product of vectors.
@@ -194,18 +194,18 @@ pub fn dot<T>(lhs: &Array<T>, rhs: &Array<T>, optlhs: MatProp, optrhs: MatProp)
194194
where
195195
T: HasAfEnum + FloatingPoint,
196196
{
197-
unsafe {
198-
let mut temp: af_array = std::ptr::null_mut();
199-
let err_val = af_dot(
197+
let mut temp: af_array = std::ptr::null_mut();
198+
let err_val = unsafe {
199+
af_dot(
200200
&mut temp as *mut af_array,
201201
lhs.get(),
202202
rhs.get(),
203203
optlhs as c_uint,
204204
optrhs as c_uint,
205-
);
206-
HANDLE_ERROR(AfError::from(err_val));
207-
temp.into()
208-
}
205+
)
206+
};
207+
HANDLE_ERROR(AfError::from(err_val));
208+
temp.into()
209209
}
210210

211211
/// Transpose of a matrix.
@@ -220,12 +220,10 @@ where
220220
///
221221
/// Transposed Array.
222222
pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
223-
unsafe {
224-
let mut temp: af_array = std::ptr::null_mut();
225-
let err_val = af_transpose(&mut temp as *mut af_array, arr.get(), conjugate);
226-
HANDLE_ERROR(AfError::from(err_val));
227-
temp.into()
228-
}
223+
let mut temp: af_array = std::ptr::null_mut();
224+
let err_val = unsafe { af_transpose(&mut temp as *mut af_array, arr.get(), conjugate) };
225+
HANDLE_ERROR(AfError::from(err_val));
226+
temp.into()
229227
}
230228

231229
/// Inplace transpose of a matrix.
@@ -236,10 +234,8 @@ pub fn transpose<T: HasAfEnum>(arr: &Array<T>, conjugate: bool) -> Array<T> {
236234
/// - `conjugate` is a boolean that indicates if the transpose operation needs to be a conjugate
237235
/// transpose
238236
pub fn transpose_inplace<T: HasAfEnum>(arr: &mut Array<T>, conjugate: bool) {
239-
unsafe {
240-
let err_val = af_transpose_inplace(arr.get(), conjugate);
241-
HANDLE_ERROR(AfError::from(err_val));
242-
}
237+
let err_val = unsafe { af_transpose_inplace(arr.get(), conjugate) };
238+
HANDLE_ERROR(AfError::from(err_val));
243239
}
244240

245241
/// Sets the cuBLAS math mode for the internal handle.

src/core/arith.rs

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,10 @@ where
108108
type Output = Array<T>;
109109

110110
fn not(self) -> Self::Output {
111-
unsafe {
112-
let mut temp: af_array = std::ptr::null_mut();
113-
let err_val = af_not(&mut temp as *mut af_array, self.get());
114-
HANDLE_ERROR(AfError::from(err_val));
115-
temp.into()
116-
}
111+
let mut temp: af_array = std::ptr::null_mut();
112+
let err_val = unsafe { af_not(&mut temp as *mut af_array, self.get()) };
113+
HANDLE_ERROR(AfError::from(err_val));
114+
temp.into()
117115
}
118116
}
119117

@@ -124,12 +122,12 @@ macro_rules! unary_func {
124122
/// This is an element wise unary operation.
125123
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array< T::$out_type >
126124
where T::$out_type: HasAfEnum {
127-
unsafe {
125+
128126
let mut temp: af_array = std::ptr::null_mut();
129-
let err_val = $ffi_fn(&mut temp as *mut af_array, input.get());
127+
let err_val = unsafe { $ffi_fn(&mut temp as *mut af_array, input.get()) };
130128
HANDLE_ERROR(AfError::from(err_val));
131129
temp.into()
132-
}
130+
133131
}
134132
)
135133
}
@@ -256,12 +254,12 @@ macro_rules! unary_boolean_func {
256254
///
257255
/// This is an element wise unary operation.
258256
pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array<bool> {
259-
unsafe {
257+
260258
let mut temp: af_array = std::ptr::null_mut();
261-
let err_val = $ffi_fn(&mut temp as *mut af_array, input.get());
259+
let err_val = unsafe { $ffi_fn(&mut temp as *mut af_array, input.get()) };
262260
HANDLE_ERROR(AfError::from(err_val));
263261
temp.into()
264-
}
262+
265263
}
266264
)
267265
}
@@ -291,12 +289,11 @@ macro_rules! binary_func {
291289
A: ImplicitPromote<B>,
292290
B: ImplicitPromote<A>,
293291
{
294-
unsafe {
295-
let mut temp: af_array = std::ptr::null_mut();
296-
let err_val = $ffi_fn(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
297-
HANDLE_ERROR(AfError::from(err_val));
298-
Into::<Array<A::Output>>::into(temp)
299-
}
292+
let mut temp: af_array = std::ptr::null_mut();
293+
let err_val =
294+
unsafe { $ffi_fn(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch) };
295+
HANDLE_ERROR(AfError::from(err_val));
296+
Into::<Array<A::Output>>::into(temp)
300297
}
301298
};
302299
}
@@ -389,12 +386,11 @@ macro_rules! overloaded_binary_func {
389386
A: ImplicitPromote<B>,
390387
B: ImplicitPromote<A>,
391388
{
392-
unsafe {
393-
let mut temp: af_array = std::ptr::null_mut();
394-
let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
395-
HANDLE_ERROR(AfError::from(err_val));
396-
temp.into()
397-
}
389+
let mut temp: af_array = std::ptr::null_mut();
390+
let err_val =
391+
unsafe { $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch) };
392+
HANDLE_ERROR(AfError::from(err_val));
393+
temp.into()
398394
}
399395

400396
#[doc=$doc_str]
@@ -491,12 +487,11 @@ macro_rules! overloaded_logic_func {
491487
A: ImplicitPromote<B>,
492488
B: ImplicitPromote<A>,
493489
{
494-
unsafe {
495-
let mut temp: af_array = std::ptr::null_mut();
496-
let err_val = $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch);
497-
HANDLE_ERROR(AfError::from(err_val));
498-
temp.into()
499-
}
490+
let mut temp: af_array = std::ptr::null_mut();
491+
let err_val =
492+
unsafe { $ffi_name(&mut temp as *mut af_array, lhs.get(), rhs.get(), batch) };
493+
HANDLE_ERROR(AfError::from(err_val));
494+
temp.into()
500495
}
501496

502497
#[doc=$doc_str]
@@ -611,18 +606,18 @@ where
611606
X: ImplicitPromote<Y>,
612607
Y: ImplicitPromote<X>,
613608
{
614-
unsafe {
615-
let mut temp: af_array = std::ptr::null_mut();
616-
let err_val = af_clamp(
609+
let mut temp: af_array = std::ptr::null_mut();
610+
let err_val = unsafe {
611+
af_clamp(
617612
&mut temp as *mut af_array,
618613
inp.get(),
619614
lo.get(),
620615
hi.get(),
621616
batch,
622-
);
623-
HANDLE_ERROR(AfError::from(err_val));
624-
temp.into()
625-
}
617+
)
618+
};
619+
HANDLE_ERROR(AfError::from(err_val));
620+
temp.into()
626621
}
627622

628623
/// Clamp the values of Array
@@ -979,10 +974,8 @@ pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
979974
where
980975
T: HasAfEnum + IntegralType,
981976
{
982-
unsafe {
983-
let mut temp: af_array = std::ptr::null_mut();
984-
let err_val = af_bitnot(&mut temp as *mut af_array, input.get());
985-
HANDLE_ERROR(AfError::from(err_val));
986-
temp.into()
987-
}
977+
let mut temp: af_array = std::ptr::null_mut();
978+
let err_val = unsafe { af_bitnot(&mut temp as *mut af_array, input.get()) };
979+
HANDLE_ERROR(AfError::from(err_val));
980+
temp.into()
988981
}

0 commit comments

Comments
 (0)