@@ -205,9 +205,13 @@ def replace_kv_cache_with_quantized_kv_cache(module):
205
205
# This is needed to ensure that custom ops are registered
206
206
from executorch .extension .llm .custom_ops import custom_ops # noqa: F401
207
207
208
- logging .warning (
208
+ logging .info (
209
209
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
210
210
)
211
+ return _replace_kv_cache_with_quantized_kv_cache (module )
212
+
213
+
214
+ def _replace_kv_cache_with_quantized_kv_cache (module ):
211
215
for name , child in module .named_children ():
212
216
if isinstance (child , KVCache ) or isinstance (child , CustomKVCache ):
213
217
setattr (
@@ -220,7 +224,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
220
224
),
221
225
)
222
226
else :
223
- replace_kv_cache_with_quantized_kv_cache (child )
227
+ _replace_kv_cache_with_quantized_kv_cache (child )
224
228
return module
225
229
226
230
@@ -263,16 +267,20 @@ def update(
263
267
264
268
265
269
def replace_kv_cache_with_custom_kv_cache (module ):
266
- r """
270
+ """
267
271
Replace KVCache with CustomKVCache. This modifies the model in place.
268
272
At the moment custom kv cache only supports cache with shape
269
273
[B, S, H, D] as opposed to [B, H, S, D]
270
274
This is because the custom op treats second dim as sequence dim.
271
275
Future work: support [B, H, S, D]
272
276
"""
273
- logging .warning (
277
+ logging .info (
274
278
"Replacing KVCache with CustomKVCache. This modifies the model in place."
275
279
)
280
+ return _replace_kv_cache_with_custom_kv_cache (module )
281
+
282
+
283
+ def _replace_kv_cache_with_custom_kv_cache (module ):
276
284
for name , child in module .named_children ():
277
285
if isinstance (child , KVCache ):
278
286
cache_shape = child .k_cache .shape
@@ -290,5 +298,5 @@ def replace_kv_cache_with_custom_kv_cache(module):
290
298
),
291
299
)
292
300
else :
293
- replace_kv_cache_with_custom_kv_cache (child )
301
+ _replace_kv_cache_with_custom_kv_cache (child )
294
302
return module
0 commit comments