coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] #2979
Is the newly generated CoreML interface compatible with the existing CoreML models?
I've generated the interfaces, compiled with Core ML enabled, and ran whisper-cli successfully. I've also tried whisper.swiftui with a coreml module (I rebuilt the xcframework), and it worked as well. Is there some other way I can verify that this works correctly?
Great, that's good enough. Btw, what is the
Verified that the `base` and `base.en` models convert successfully and that the generated CoreML interface works on my M4 Max.
The diff that I get is:
diff --git a/src/coreml/whisper-decoder-impl.h b/src/coreml/whisper-decoder-impl.h
index c6f2e853..8ec9373d 100644
--- a/src/coreml/whisper-decoder-impl.h
+++ b/src/coreml/whisper-decoder-impl.h
@@ -11,36 +11,33 @@
NS_ASSUME_NONNULL_BEGIN
-
/// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>
-/// token_data as 1 by 1 matrix of 32-bit integers
+/// token_data as 1 by 1 matrix of floats
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;
-/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
+/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;
@end
-
/// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>
-/// var_1346 as multidimensional array of floats
-@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
+/// cast_76 as multidimensional array of floats
+@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
- (instancetype)init NS_UNAVAILABLE;
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;
@end
-
/// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the standard interface
@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_decoder_implInput to predict from
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_decoder_implInput to predict from
+ @param options prediction options
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
/**
Make a prediction using the convenience interface
- @param token_data as 1 by 1 matrix of 32-bit integers:
- @param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
+ @param token_data 1 by 1 matrix of floats
+ @param audio_data 1 × 1500 × 384 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
diff --git a/src/coreml/whisper-decoder-impl.m b/src/coreml/whisper-decoder-impl.m
index 34060e45..732992e1 100644
--- a/src/coreml/whisper-decoder-impl.m
+++ b/src/coreml/whisper-decoder-impl.m
@@ -39,21 +39,21 @@
@implementation whisper_decoder_implOutput
-- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
+- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
self = [super init];
if (self) {
- _var_1346 = var_1346;
+ _cast_76 = cast_76;
}
return self;
}
- (NSSet<NSString *> *)featureNames {
- return [NSSet setWithArray:@[@"var_1346"]];
+ return [NSSet setWithArray:@[@"cast_76"]];
}
- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
- if ([featureName isEqualToString:@"var_1346"]) {
- return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
+ if ([featureName isEqualToString:@"cast_76"]) {
+ return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
}
return nil;
}
@@ -80,10 +80,13 @@
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
+ if (model == nil) {
+ return nil;
+ }
self = [super init];
- if (!self) { return nil; }
- _model = model;
- if (_model == nil) { return nil; }
+ if (self != nil) {
+ _model = model;
+ }
return self;
}
@@ -177,7 +180,29 @@
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
- return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
+ return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+ [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+ if (prediction != nil) {
+ whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+ completionHandler(output, predictionError);
+ } else {
+ completionHandler(nil, predictionError);
+ }
+ }];
+}
+
+- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+ [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+ if (prediction != nil) {
+ whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
+ completionHandler(output, predictionError);
+ } else {
+ completionHandler(nil, predictionError);
+ }
+ }];
}
- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@@ -192,7 +217,7 @@
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
- whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
+ whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
[results addObject:result];
}
return results;
diff --git a/src/coreml/whisper-encoder-impl.h b/src/coreml/whisper-encoder-impl.h
index 7b83cd90..c4d42248 100644
--- a/src/coreml/whisper-encoder-impl.h
+++ b/src/coreml/whisper-encoder-impl.h
@@ -11,9 +11,8 @@
NS_ASSUME_NONNULL_BEGIN
-
/// Model Prediction Input Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>
/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@end
-
/// Model Prediction Output Type
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>
/// output as multidimensional array of floats
@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@end
-
/// Class for model loading and prediction
-API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
+API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;
@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
-+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
-+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
++ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
/**
Make a prediction using the standard interface
@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_encoder_implInput to predict from
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
+/**
+ Make an asynchronous prediction using the standard interface
+ @param input an instance of whisper_encoder_implInput to predict from
+ @param options prediction options
+ @param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
+*/
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));
+
/**
Make a prediction using the convenience interface
- @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
+ @param logmel_data 1 × 80 × 3000 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
diff --git a/src/coreml/whisper-encoder-impl.m b/src/coreml/whisper-encoder-impl.m
index ee8e5065..2ed9dc61 100644
--- a/src/coreml/whisper-encoder-impl.m
+++ b/src/coreml/whisper-encoder-impl.m
@@ -76,10 +76,13 @@
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
+ if (model == nil) {
+ return nil;
+ }
self = [super init];
- if (!self) { return nil; }
- _model = model;
- if (_model == nil) { return nil; }
+ if (self != nil) {
+ _model = model;
+ }
return self;
}
@@ -176,6 +179,28 @@
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+ [self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+ if (prediction != nil) {
+ whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+ completionHandler(output, predictionError);
+ } else {
+ completionHandler(nil, predictionError);
+ }
+ }];
+}
+
+- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
+ [self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
+ if (prediction != nil) {
+ whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
+ completionHandler(output, predictionError);
+ } else {
+ completionHandler(nil, predictionError);
+ }
+ }];
+}
+
- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];
Wondering if it is a good idea to commit this new interface?
Yeah, I was not sure if I should have included a commit or not, but I'll do that now that this seems to work 👍
This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process. Refs: ggml-org#2783
This commit fixes the audio shape in the whisper decoder conversion script. The motivation for this is that the audio shape was incorrect and was causing the conversion to fail.
The commit sets the -e flag in the generate-coreml-interface.sh script to make sure the script fails if any command fails.
This commit updates the generated encoder/decoder interfaces for the whisper model which is the result of running the generate-coreml-interface.sh script.
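The audio-shape fix mentioned in the commits above is visible in the regenerated decoder header, where the `audio_data` comment changes from a 4-dimensional to a 3-dimensional layout. A minimal sketch of that difference follows. The two shape tuples are copied from the header comments in the diff; the helper name `decoder_audio_shape` and its parameter names are hypothetical, and reading the axes as (batch, audio context, state width) is an assumption, not something the conversion script is quoted as doing.

```python
from math import prod

# Shapes copied from the generated header comments in the diff.
OLD_AUDIO_SHAPE = (1, 384, 1, 1500)  # 4-D layout in the previous interface
NEW_AUDIO_SHAPE = (1, 1500, 384)     # 3-D layout in the regenerated interface


def decoder_audio_shape(n_audio_ctx: int, n_audio_state: int) -> tuple:
    """Hypothetical helper: build a (batch, context, state) shape tuple."""
    return (1, n_audio_ctx, n_audio_state)


# The helper reproduces the new layout for the values seen in the header.
shape = decoder_audio_shape(1500, 384)

# Both layouts hold the same number of elements; only the axis order
# and rank differ.
same_size = prod(OLD_AUDIO_SHAPE) == prod(NEW_AUDIO_SHAPE)
```

Because the element count is unchanged, a mismatch like this shows up as a shape or rank error rather than as a size error.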
95b2438 to 456dfcb
This commit disables the use of PyTorch's `scaled_dot_product_attention` in the Whisper model to avoid compatibility issues during CoreML conversion. The issue occurs because coremltools requires PyTorch 2.5.0, but the Whisper implementation may expect behavior from newer PyTorch versions. By setting `MultiHeadAttention.use_sdpa = False`, we force Whisper to use its fallback manual attention implementation, which works correctly with PyTorch 2.5.0 during the tracing process.
Refs: #2783
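To illustrate the idea of a class-level switch that selects a manual attention path, here is a small pure-Python sketch. It is not Whisper's code: `ToyAttention` and `manual_attention` are hypothetical names, the inputs are plain nested lists instead of tensors, and the "fused" branch only records which path was taken rather than calling a real fused kernel.

```python
import math


def manual_attention(q, k, v):
    """Compute softmax(q @ k^T / sqrt(d)) @ v on nested lists of floats."""
    d = len(q[0])
    out = []
    for q_row in q:
        # Scaled dot product of this query against every key.
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d)
                  for k_row in k]
        # Numerically stable softmax over the scores.
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


class ToyAttention:
    """Hypothetical stand-in for an attention module; not Whisper's code."""

    use_sdpa = True  # class-level switch, shared by every instance

    def __call__(self, q, k, v):
        if ToyAttention.use_sdpa:
            # A real module would call a fused kernel here; this toy only
            # records which path was selected.
            return "fused", manual_attention(q, k, v)
        return "manual", manual_attention(q, k, v)


# Flip the switch once, before any tracing or export would happen.
ToyAttention.use_sdpa = False
path, out = ToyAttention()(
    [[1.0, 0.0]],               # one query
    [[1.0, 0.0], [1.0, 0.0]],   # two identical keys
    [[2.0, 0.0], [4.0, 0.0]],   # two values
)
```

Since the flag lives on the class rather than on an instance, setting it once affects every attention layer in the model, which is why a single assignment before tracing is enough in this pattern.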
With the changes in this pull request I was able to run the `models/generate-coreml-interface.sh` script using the following steps: