
Commit 11688b2

coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] (#2979)
* coreml: fix Whisper to CoreML conversion by disabling SDPA

  This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process.

  Refs: #2783

* coreml: fix audio shape in whisper decoder conversion

  This commit fixes the audio shape in the whisper decoder conversion script. The motivation for this is that the audio shape was incorrect and was causing the conversion to fail.

* coreml : set -e in generate-coreml-interface.sh

  The commit sets the -e flag in the generate-coreml-interface.sh script to make sure the script fails if any command fails.

* coreml : update generated encoder/decoder interfaces

  This commit updates the generated encoder/decoder interfaces for the whisper model, which is the result of running the generate-coreml-interface.sh script.
1 parent 04b9508 commit 11688b2
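
For context, the SDPA workaround from the first bullet can be reproduced outside the conversion script in a few lines. The sketch below is illustrative only (the "tiny" checkpoint and the dummy input are assumptions, not part of this commit); it shows the monkeypatch that has to run before `torch.jit.trace` so tracing uses Whisper's manual attention path on PyTorch 2.5.0.

```python
# Minimal sketch of the SDPA workaround described in the commit message.
# Assumptions: openai-whisper installed, PyTorch 2.5.0 (the version coremltools
# expects), and the "tiny" checkpoint -- all illustrative, not from the commit.
import torch
import whisper
import whisper.model

# Force Whisper's fallback manual attention so tracing does not depend on
# version-specific torch.nn.functional.scaled_dot_product_attention behavior.
whisper.model.MultiHeadAttention.use_sdpa = False

model = whisper.load_model("tiny")
model.eval()

# Tracing the audio encoder with a dummy log-mel input (1 x n_mels x 3000)
# now goes through the manual attention implementation.
mels = torch.randn(1, model.dims.n_mels, 3000)
traced_encoder = torch.jit.trace(model.encoder, mels)
```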

File tree

- models/convert-whisper-to-coreml.py
- models/generate-coreml-interface.sh
- src/coreml/whisper-decoder-impl.h
- src/coreml/whisper-decoder-impl.m
- src/coreml/whisper-encoder-impl.h
- src/coreml/whisper-encoder-impl.m

6 files changed: +125 −39 lines changed


models/convert-whisper-to-coreml.py
Lines changed: 12 additions & 2 deletions

@@ -12,6 +12,15 @@
 from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
 from whisper import load_model
 
+# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
+# The Whisper implementation expects a specific behavior from
+# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
+# versions. Setting use_sdpa=False forces Whisper to use its manual attention
+# implementation instead, which is more stable across different PyTorch versions
+# (2.5.0 required by coremltools vs newer versions).
+import whisper.model
+whisper.model.MultiHeadAttention.use_sdpa = False
+
 # Use for changing dim of input in encoder and decoder embeddings
 def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
@@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
     model.eval()
 
     tokens_shape = (1, 1)
-    audio_shape = (1, hparams.n_audio_state, 1, 1500)
+    audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)
 
     audio_data = torch.randn(audio_shape)
-    token_data = torch.randint(50257, tokens_shape).long()
+    token_data = torch.randint(hparams.n_vocab, tokens_shape).long()
+
     traced_model = torch.jit.trace(model, (token_data, audio_data))
 
     model = ct.convert(
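
To make the decoder-shape fix above concrete, here is a hedged sketch of the corrected tracing inputs pulled out of the script. The concrete dimensions in the comments assume the tiny checkpoint (n_audio_ctx = 1500, n_audio_state = 384); the model name and values are illustrative, not taken from the diff.

```python
# Sketch of the corrected decoder tracing inputs (illustrative; tiny checkpoint).
import torch
import whisper

model = whisper.load_model("tiny")
dims = model.dims  # e.g. n_audio_ctx=1500, n_audio_state=384 for tiny

tokens_shape = (1, 1)

# The old shape (1, n_audio_state, 1, 1500) caused the conversion to fail (per
# the commit message); the corrected shape matches TextDecoder.forward(x, xa),
# where xa is the encoder output of shape (1, n_audio_ctx, n_audio_state).
audio_shape = (1, dims.n_audio_ctx, dims.n_audio_state)
audio_data = torch.randn(audio_shape)

# Sample dummy token ids from the model's own vocabulary instead of the
# previously hard-coded 50257 (the GPT-2 vocabulary size).
token_data = torch.randint(dims.n_vocab, tokens_shape).long()

traced_decoder = torch.jit.trace(model.decoder, (token_data, audio_data))
```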

models/generate-coreml-interface.sh
Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,8 @@
 # - src/coreml/whisper-decoder-impl.h and src/coreml/whisper-decoder-impl.m
 #
 
+set -e
+
 wd=$(dirname "$0")
 cd "$wd/../" || exit
 
src/coreml/whisper-decoder-impl.h
Lines changed: 27 additions & 15 deletions

@@ -11,36 +11,33 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
 
-/// token_data as 1 by 1 matrix of 32-bit integers
+/// token_data as 1 by 1 matrix of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * token_data;
 
-/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
+/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
 - (instancetype)init NS_UNAVAILABLE;
 - (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
 
-/// var_1346 as multidimensional array of floats
-@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
+/// cast_76 as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
 - (instancetype)init NS_UNAVAILABLE;
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 @param configuration The model configuration
 @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
 Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 @param configuration The model configuration
 @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
 Make a prediction using the standard interface
@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_decoder_implInput to predict from
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_decoder_implInput to predict from
+ @param options prediction options
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
 Make a prediction using the convenience interface
-@param token_data as 1 by 1 matrix of 32-bit integers:
-@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
+@param token_data 1 by 1 matrix of floats
+@param audio_data 1 × 1500 × 384 3-dimensional array of floats
 @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
 @return the prediction as whisper_decoder_implOutput
 */
src/coreml/whisper-decoder-impl.m
Lines changed: 35 additions & 10 deletions

@@ -39,21 +39,21 @@ - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
 
 @implementation whisper_decoder_implOutput
 
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
     self = [super init];
     if (self) {
-        _var_1346 = var_1346;
+        _cast_76 = cast_76;
     }
     return self;
 }
 
 - (NSSet<NSString *> *)featureNames {
-    return [NSSet setWithArray:@[@"var_1346"]];
+    return [NSSet setWithArray:@[@"cast_76"]];
 }
 
 - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
-    if ([featureName isEqualToString:@"var_1346"]) {
-        return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
+    if ([featureName isEqualToString:@"cast_76"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
     }
     return nil;
 }
@@ -80,10 +80,13 @@ + (nullable NSURL *)URLOfModelInThisBundle {
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -177,7 +180,29 @@ - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
     if (!outFeatures) { return nil; }
-    return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
+    return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
 }
 
 - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@@ -192,7 +217,7 @@ - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray
     NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
     for (NSInteger i = 0; i < outBatch.count; i++) {
         id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
-        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
+        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
         [results addObject:result];
     }
     return results;
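
The Python counterpart of the convenience prediction above (predictionFromToken_data:audio_data:error:) would look roughly like the following on macOS; the model path, dummy token value, and dimensions are illustrative assumptions, not part of the generated code.

```python
# Sketch: single decoder step via coremltools, mirroring the Objective-C
# convenience interface. Requires macOS; all concrete values are illustrative.
import numpy as np
import coremltools as ct

decoder = ct.models.MLModel("models/coreml-decoder-tiny.mlpackage")

token_data = np.zeros((1, 1), dtype=np.float32)               # 1 x 1 matrix of floats
audio_data = np.random.rand(1, 1500, 384).astype(np.float32)  # dummy encoder output

out = decoder.predict({"token_data": token_data, "audio_data": audio_data})
logits = out["cast_76"]   # the renamed output from the regenerated interface
print(logits.shape)
```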

src/coreml/whisper-encoder-impl.h
Lines changed: 21 additions & 9 deletions

@@ -11,9 +11,8 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
 
 /// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
 
 /// output as multidimensional array of floats
@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 @param configuration The model configuration
 @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
 Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 @param configuration The model configuration
 @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
 Make a prediction using the standard interface
@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_encoder_implInput to predict from
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_encoder_implInput to predict from
+ @param options prediction options
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
 Make a prediction using the convenience interface
-@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
+@param logmel_data 1 × 80 × 3000 3-dimensional array of floats
 @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
 @return the prediction as whisper_encoder_implOutput
 */
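
For the encoder side, the convenience call maps to a single predict with a 1 × 80 × 3000 log-mel array. A hedged Python sketch follows; the model path and the random input are illustrative, and the output layout is whatever the converted model produces.

```python
# Sketch: run the converted encoder once via coremltools (macOS only).
# The model path is a placeholder; the input matches the logmel_data feature
# documented in the header above (1 x 80 x 3000 floats).
import numpy as np
import coremltools as ct

encoder = ct.models.MLModel("models/coreml-encoder-tiny.mlpackage")

logmel_data = np.random.rand(1, 80, 3000).astype(np.float32)
out = encoder.predict({"logmel_data": logmel_data})

audio_features = out["output"]   # single output named "output" in the header
print(audio_features.shape)      # encoder features to be fed to the decoder
```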

src/coreml/whisper-encoder-impl.m
Lines changed: 28 additions & 3 deletions

@@ -76,10 +76,13 @@ + (nullable NSURL *)URLOfModelInThisBundle {
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -176,6 +179,28 @@ - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder
     return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
 }
 
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
 - (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
     return [self predictionFromFeatures:input_ error:error];
