coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] #2979


Merged — danbev merged 4 commits into ggml-org:master from danbev:coreml-generate-interfaces on Apr 1, 2025

Conversation

@danbev (Collaborator) commented Apr 1, 2025

This commit disables the use of PyTorch's scaled_dot_product_attention in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions.

By setting MultiHeadAttention.use_sdpa = False, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process.

Refs: #2783
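The fix amounts to flipping a single class-level flag before tracing. The following is a minimal stand-in (not the actual conversion script) illustrating why patching the class attribute works: every attention layer, including ones instantiated after the patch, reads the new value.

```python
# Stand-in for whisper.model.MultiHeadAttention, which keeps `use_sdpa`
# as a class-level attribute. Patching it on the class affects all
# instances, so the model falls back to manual attention during tracing.

class MultiHeadAttention:
    use_sdpa = True  # class-level default

    def attention_path(self):
        # Instances without their own `use_sdpa` resolve it on the class.
        return "sdpa" if self.use_sdpa else "manual"

# Patch once, before the model is built/traced:
MultiHeadAttention.use_sdpa = False

layer = MultiHeadAttention()
print(layer.attention_path())  # -> manual
```

In the real conversion script the equivalent patch is applied to `whisper.model.MultiHeadAttention` before `torch.jit.trace` runs.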


With the changes in this pull request I was able to run models/generate-coreml-interface.sh script using the following steps:

$ python3.11 -m venv venv
$ source venv/bin/activate
(venv) $ pip install ane_transformers openai-whisper coremltools
(venv) $ ./models/generate-coreml-interface.sh

@ggerganov (Member)

Is the newly generated CoreML interface compatible with the existing CoreML models?

@danbev (Collaborator, Author) commented Apr 1, 2025

Is the newly generated CoreML interface compatible with the existing CoreML models?

I've generated the interfaces, compiled with Core ML enabled, and ran whisper-cli successfully. I've also tried whisper.swiftui with a Core ML model (I rebuilt the xcframework) and it worked as well. Is there some other way I can verify that this works correctly?

@ggerganov (Member)

Great, that's good enough.

Btw, what is the diff that you get for the sources in src/coreml after this change?

@ggerganov (Member) left a comment

Verified that base and base.en model convert successfully and that the generated CoreML interface works on my M4 Max.

The diff that I get is:

diff --git a/src/coreml/whisper-decoder-impl.h b/src/coreml/whisper-decoder-impl.h
index c6f2e853..8ec9373d 100644
--- a/src/coreml/whisper-decoder-impl.h
+++ b/src/coreml/whisper-decoder-impl.h
@@ -11,36 +11,33 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
 
-/// token_data as 1 by 1 matrix of 32-bit integers
+/// token_data as 1 by 1 matrix of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * token_data;
 
-/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
+/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
 @property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
 - (instancetype)init NS_UNAVAILABLE;
 - (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
 
-/// var_1346 as multidimensional array of floats
-@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
+/// cast_76 as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
 - (instancetype)init NS_UNAVAILABLE;
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_decoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface
@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_decoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
-    @param token_data as 1 by 1 matrix of 32-bit integers:
-    @param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
+    @param token_data 1 by 1 matrix of floats
+    @param audio_data 1 × 1500 × 384 3-dimensional array of floats
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_decoder_implOutput
 */
diff --git a/src/coreml/whisper-decoder-impl.m b/src/coreml/whisper-decoder-impl.m
index 34060e45..732992e1 100644
--- a/src/coreml/whisper-decoder-impl.m
+++ b/src/coreml/whisper-decoder-impl.m
@@ -39,21 +39,21 @@
 
 @implementation whisper_decoder_implOutput
 
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
     self = [super init];
     if (self) {
-        _var_1346 = var_1346;
+        _cast_76 = cast_76;
     }
     return self;
 }
 
 - (NSSet<NSString *> *)featureNames {
-    return [NSSet setWithArray:@[@"var_1346"]];
+    return [NSSet setWithArray:@[@"cast_76"]];
 }
 
 - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
-    if ([featureName isEqualToString:@"var_1346"]) {
-        return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
+    if ([featureName isEqualToString:@"cast_76"]) {
+        return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
     }
     return nil;
 }
@@ -80,10 +80,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -177,7 +180,29 @@
 - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
     if (!outFeatures) { return nil; }
-    return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
+    return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
 }
 
 - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@@ -192,7 +217,7 @@
     NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
     for (NSInteger i = 0; i < outBatch.count; i++) {
         id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
-        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
+        whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
         [results addObject:result];
     }
     return results;
diff --git a/src/coreml/whisper-encoder-impl.h b/src/coreml/whisper-encoder-impl.h
index 7b83cd90..c4d42248 100644
--- a/src/coreml/whisper-encoder-impl.h
+++ b/src/coreml/whisper-encoder-impl.h
@@ -11,9 +11,8 @@
 
 NS_ASSUME_NONNULL_BEGIN
 
-
 /// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
 
 /// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
 
 /// output as multidimensional array of floats
@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 
 @end
 
-
 /// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
 @interface whisper_encoder_impl : NSObject
 @property (readonly, nonatomic, nullable) MLModel * model;
 
@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
     @param configuration The model configuration
     @param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
 */
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
 
 /**
     Make a prediction using the standard interface
@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
 */
 - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
 
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+    Make an asynchronous prediction using the standard interface
+    @param input an instance of whisper_encoder_implInput to predict from
+    @param options prediction options
+    @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
 /**
     Make a prediction using the convenience interface
-    @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
+    @param logmel_data 1 × 80 × 3000 3-dimensional array of floats
     @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
     @return the prediction as whisper_encoder_implOutput
 */
diff --git a/src/coreml/whisper-encoder-impl.m b/src/coreml/whisper-encoder-impl.m
index ee8e5065..2ed9dc61 100644
--- a/src/coreml/whisper-encoder-impl.m
+++ b/src/coreml/whisper-encoder-impl.m
@@ -76,10 +76,13 @@
     Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
 */
 - (instancetype)initWithMLModel:(MLModel *)model {
+    if (model == nil) {
+        return nil;
+    }
     self = [super init];
-    if (!self) { return nil; }
-    _model = model;
-    if (_model == nil) { return nil; }
+    if (self != nil) {
+        _model = model;
+    }
     return self;
 }
 
@@ -176,6 +179,28 @@
     return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
 }
 
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+    [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+        if (prediction != nil) {
+            whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+            completionHandler(output, predictionError);
+        } else {
+            completionHandler(nil, predictionError);
+        }
+    }];
+}
+
 - (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
     whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
     return [self predictionFromFeatures:input_ error:error];

Wondering if it is a good idea to commit this new interface?

@danbev (Collaborator, Author) commented Apr 1, 2025

Wondering if it is a good idea to commit this new interface?

Yeah, I was not sure if I should have included a commit or not, but I'll do that now that this seems to work 👍

danbev added 4 commits April 1, 2025 17:05
This commit disables the use of PyTorch's
`scaled_dot_product_attention` in the Whisper model to avoid
compatibility issues during CoreML conversion.
The issue occurs because coremltools requires PyTorch 2.5.0, but the
Whisper implementation may expect behavior from newer PyTorch versions.

By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to
use its fallback manual attention implementation, which works correctly
with PyTorch 2.5.0 during the tracing process.

Refs: ggml-org#2783
This commit fixes the audio shape in the whisper decoder conversion
script.

The motivation for this change is that the audio shape was incorrect,
which was causing the conversion to fail.
The commit sets the -e flag in the generate-coreml-interface.sh script
to make sure the script fails if any command fails.
This commit updates the generated encoder/decoder interfaces for the
whisper model which is the result of running the
generate-coreml-interface.sh script.
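The effect of the `set -e` change in the third commit can be sketched as follows (an illustrative one-liner, not taken from the script itself): with `-e`, the shell exits at the first failing command instead of silently continuing with a partial conversion.

```shell
# The inner script aborts at `false` because of `set -e`,
# so "after failure" is never printed; the outer `||` branch runs instead.
bash -c 'set -e; false; echo "after failure"' || echo "script aborted early"
```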
@danbev danbev force-pushed the coreml-generate-interfaces branch from 95b2438 to 456dfcb Compare April 1, 2025 15:06
@danbev danbev merged commit 11688b2 into ggml-org:master Apr 1, 2025
51 checks passed