
coreml: fix Whisper to CoreML conversion by disabling SDPA [no ci] #2979

Merged — 4 commits, Apr 1, 2025

models/convert-whisper-to-coreml.py — 12 additions & 2 deletions
@@ -12,6 +12,15 @@
from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions
from whisper import load_model

# Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues.
# The Whisper implementation expects a specific behavior from
# torch.nn.functional.scaled_dot_product_attention that differs between PyTorch
# versions. Setting use_sdpa=False forces Whisper to use its manual attention
# implementation instead, which is more stable across different PyTorch versions
# (2.5.0 required by coremltools vs newer versions).
import whisper.model
whisper.model.MultiHeadAttention.use_sdpa = False

# Use for changing dim of input in encoder and decoder embeddings
def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
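
For reference, a simplified sketch of how openai-whisper's MultiHeadAttention consults this class-level flag (paraphrased, not the upstream source verbatim — the real implementation splits q/k/v into heads and scales by n_state // n_head). With use_sdpa = False the explicit matmul/softmax path is taken, which torch.jit.trace handles consistently across PyTorch versions.

import torch
import torch.nn.functional as F

def qkv_attention(q, k, v, mask=None, use_sdpa=False):
    # q, k, v: (n_batch, n_ctx, n_state) tensors
    n_batch, n_ctx, n_state = q.shape
    scale = n_state ** -0.25
    if use_sdpa:
        # Fused kernel whose behavior differs between the PyTorch 2.5.0 required
        # by coremltools and newer releases, which is what breaks the trace.
        return F.scaled_dot_product_attention(
            q, k, v, is_causal=mask is not None and n_ctx > 1)
    # Manual attention: explicit matmul + softmax, stable under torch.jit.trace.
    qk = (q * scale) @ (k * scale).transpose(-1, -2)
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]
    w = torch.softmax(qk.float(), dim=-1).to(q.dtype)
    return w @ v
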
@@ -260,10 +269,11 @@ def convert_decoder(hparams, model, quantize=False):
model.eval()

tokens_shape = (1, 1)
audio_shape = (1, hparams.n_audio_state, 1, 1500)
audio_shape = (1, hparams.n_audio_ctx, hparams.n_audio_state)

audio_data = torch.randn(audio_shape)
token_data = torch.randint(50257, tokens_shape).long()
token_data = torch.randint(hparams.n_vocab, tokens_shape).long()

traced_model = torch.jit.trace(model, (token_data, audio_data))

model = ct.convert(
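One way to sanity-check the new decoder input layout after running the conversion script (a sketch, not part of this PR; the .mlpackage path is hypothetical, and the dimensions shown assume the tiny model, where n_audio_ctx = 1500 and n_audio_state = 384):

import coremltools as ct

decoder = ct.models.MLModel("models/coreml-decoder-tiny.mlpackage")  # hypothetical path
for inp in decoder.get_spec().description.input:
    print(inp.name, list(inp.type.multiArrayType.shape))
# audio_data should now report (1, 1500, 384) instead of the old (1, 384, 1, 1500)
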
models/generate-coreml-interface.sh — 2 additions & 0 deletions
@@ -5,6 +5,8 @@
# - src/coreml/whisper-decoder-impl.h and src/coreml/whisper-decoder-impl.m
#

set -e

wd=$(dirname "$0")
cd "$wd/../" || exit

src/coreml/whisper-decoder-impl.h — 27 additions & 15 deletions
@@ -11,36 +11,33 @@

NS_ASSUME_NONNULL_BEGIN


/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implInput : NSObject<MLFeatureProvider>

/// token_data as 1 by 1 matrix of 32-bit integers
/// token_data as 1 by 1 matrix of floats
@property (readwrite, nonatomic, strong) MLMultiArray * token_data;

/// audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats
/// audio_data as 1 × 1500 × 384 3-dimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * audio_data;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data NS_DESIGNATED_INITIALIZER;

@end


/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_implOutput : NSObject<MLFeatureProvider>

/// var_1346 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * var_1346;
/// cast_76 as multidimensional array of floats
@property (readwrite, nonatomic, strong) MLMultiArray * cast_76;
- (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 NS_DESIGNATED_INITIALIZER;
- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 NS_DESIGNATED_INITIALIZER;

@end


/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_decoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;

@@ -94,7 +91,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Construct whisper_decoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -105,7 +102,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_decoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_decoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Make a prediction using the standard interface
@@ -124,10 +121,25 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));

/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_decoder_implInput to predict from
@param options prediction options
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));

/**
Make a prediction using the convenience interface
@param token_data as 1 by 1 matrix of 32-bit integers:
@param audio_data as 1 × 384 × 1 × 1500 4-dimensional array of floats:
@param token_data 1 by 1 matrix of floats
@param audio_data 1 × 1500 × 384 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_decoder_implOutput
*/
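Note that the output feature name hard-coded in this generated glue must match whatever name coremltools assigns during conversion (cast_76 here, previously var_1346). A quick check against a freshly converted model (sketch; the path is hypothetical):

import coremltools as ct

decoder = ct.models.MLModel("models/coreml-decoder-tiny.mlpackage")  # hypothetical path
print([out.name for out in decoder.get_spec().description.output])
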
src/coreml/whisper-decoder-impl.m — 35 additions & 10 deletions
@@ -39,21 +39,21 @@ - (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {

@implementation whisper_decoder_implOutput

- (instancetype)initWithVar_1346:(MLMultiArray *)var_1346 {
- (instancetype)initWithCast_76:(MLMultiArray *)cast_76 {
self = [super init];
if (self) {
_var_1346 = var_1346;
_cast_76 = cast_76;
}
return self;
}

- (NSSet<NSString *> *)featureNames {
return [NSSet setWithArray:@[@"var_1346"]];
return [NSSet setWithArray:@[@"cast_76"]];
}

- (nullable MLFeatureValue *)featureValueForName:(NSString *)featureName {
if ([featureName isEqualToString:@"var_1346"]) {
return [MLFeatureValue featureValueWithMultiArray:self.var_1346];
if ([featureName isEqualToString:@"cast_76"]) {
return [MLFeatureValue featureValueWithMultiArray:self.cast_76];
}
return nil;
}
@@ -80,10 +80,13 @@ + (nullable NSURL *)URLOfModelInThisBundle {
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
if (model == nil) {
return nil;
}
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
if (self != nil) {
_model = model;
}
return self;
}

@@ -177,7 +180,29 @@ - (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder
- (nullable whisper_decoder_implOutput *)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error {
id<MLFeatureProvider> outFeatures = [self.model predictionFromFeatures:input options:options error:error];
if (!outFeatures) { return nil; }
return [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[outFeatures featureValueForName:@"var_1346"].multiArrayValue];
return [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[outFeatures featureValueForName:@"cast_76"].multiArrayValue];
}

- (void)predictionFromFeatures:(whisper_decoder_implInput *)input completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}

- (void)predictionFromFeatures:(whisper_decoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_decoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_decoder_implOutput *output = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[prediction featureValueForName:@"cast_76"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}

- (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray *)token_data audio_data:(MLMultiArray *)audio_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
@@ -192,7 +217,7 @@ - (nullable whisper_decoder_implOutput *)predictionFromToken_data:(MLMultiArray
NSMutableArray<whisper_decoder_implOutput*> *results = [NSMutableArray arrayWithCapacity:(NSUInteger)outBatch.count];
for (NSInteger i = 0; i < outBatch.count; i++) {
id<MLFeatureProvider> resultProvider = [outBatch featuresAtIndex:i];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithVar_1346:(MLMultiArray *)[resultProvider featureValueForName:@"var_1346"].multiArrayValue];
whisper_decoder_implOutput * result = [[whisper_decoder_implOutput alloc] initWithCast_76:(MLMultiArray *)[resultProvider featureValueForName:@"cast_76"].multiArrayValue];
[results addObject:result];
}
return results;
src/coreml/whisper-encoder-impl.h — 21 additions & 9 deletions
@@ -11,9 +11,8 @@

NS_ASSUME_NONNULL_BEGIN


/// Model Prediction Input Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implInput : NSObject<MLFeatureProvider>

/// logmel_data as 1 × 80 × 3000 3-dimensional array of floats
@@ -23,9 +22,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v

@end


/// Model Prediction Output Type
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_implOutput : NSObject<MLFeatureProvider>

/// output as multidimensional array of floats
@@ -35,9 +33,8 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v

@end


/// Class for model loading and prediction
API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((visibility("hidden")))
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
@interface whisper_encoder_impl : NSObject
@property (readonly, nonatomic, nullable) MLModel * model;

@@ -91,7 +88,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Construct whisper_encoder_impl instance asynchronously with URL of .mlmodelc directory and optional configuration.
@@ -102,7 +99,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
@param configuration The model configuration
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid whisper_encoder_impl instance or NSError object.
*/
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler;
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(whisper_encoder_impl * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));

/**
Make a prediction using the standard interface
@@ -121,9 +118,24 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
*/
- (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;

/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));

/**
Make an asynchronous prediction using the standard interface
@param input an instance of whisper_encoder_implInput to predict from
@param options prediction options
@param completionHandler a block that will be called upon completion of the prediction. error will be nil if no error occurred.
*/
- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler API_AVAILABLE(macos(14.0), ios(17.0), watchos(10.0), tvos(17.0)) __attribute__((visibility("hidden")));

/**
Make a prediction using the convenience interface
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
@param logmel_data 1 × 80 × 3000 3-dimensional array of floats
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
@return the prediction as whisper_encoder_implOutput
*/
src/coreml/whisper-encoder-impl.m — 28 additions & 3 deletions
@@ -76,10 +76,13 @@ + (nullable NSURL *)URLOfModelInThisBundle {
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
*/
- (instancetype)initWithMLModel:(MLModel *)model {
if (model == nil) {
return nil;
}
self = [super init];
if (!self) { return nil; }
_model = model;
if (_model == nil) { return nil; }
if (self != nil) {
_model = model;
}
return self;
}

@@ -176,6 +179,28 @@ - (nullable whisper_encoder_implOutput *)predictionFromFeatures:(whisper_encoder
return [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[outFeatures featureValueForName:@"output"].multiArrayValue];
}

- (void)predictionFromFeatures:(whisper_encoder_implInput *)input completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}

- (void)predictionFromFeatures:(whisper_encoder_implInput *)input options:(MLPredictionOptions *)options completionHandler:(void (^)(whisper_encoder_implOutput * _Nullable output, NSError * _Nullable error))completionHandler {
[self.model predictionFromFeatures:input options:options completionHandler:^(id<MLFeatureProvider> prediction, NSError *predictionError) {
if (prediction != nil) {
whisper_encoder_implOutput *output = [[whisper_encoder_implOutput alloc] initWithOutput:(MLMultiArray *)[prediction featureValueForName:@"output"].multiArrayValue];
completionHandler(output, predictionError);
} else {
completionHandler(nil, predictionError);
}
}];
}

- (nullable whisper_encoder_implOutput *)predictionFromLogmel_data:(MLMultiArray *)logmel_data error:(NSError * _Nullable __autoreleasing * _Nullable)error {
whisper_encoder_implInput *input_ = [[whisper_encoder_implInput alloc] initWithLogmel_data:logmel_data];
return [self predictionFromFeatures:input_ error:error];