Skip to content

Commit 9b3eb06

Browse files
authored
Make cache replacement export log less noisy (#8936)
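This avoids logging `WARNING:root:Replacing KVCache with CustomKVCache. This modifies the model in place.` many times per export: the message is now emitted once, at info level, by the public entry point, and the recursion over child modules happens in a private helper that does not log.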
1 parent a82b823 commit 9b3eb06

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

examples/models/llama/source_transformation/quantized_kv_cache.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -205,9 +205,13 @@ def replace_kv_cache_with_quantized_kv_cache(module):
205205
# This is needed to ensure that custom ops are registered
206206
from executorch.extension.llm.custom_ops import custom_ops # noqa: F401
207207

208-
logging.warning(
208+
logging.info(
209209
"Replacing KVCache with QuantizedKVCache. This modifies the model in place."
210210
)
211+
return _replace_kv_cache_with_quantized_kv_cache(module)
212+
213+
214+
def _replace_kv_cache_with_quantized_kv_cache(module):
211215
for name, child in module.named_children():
212216
if isinstance(child, KVCache) or isinstance(child, CustomKVCache):
213217
setattr(
@@ -220,7 +224,7 @@ def replace_kv_cache_with_quantized_kv_cache(module):
220224
),
221225
)
222226
else:
223-
replace_kv_cache_with_quantized_kv_cache(child)
227+
_replace_kv_cache_with_quantized_kv_cache(child)
224228
return module
225229

226230

@@ -263,16 +267,20 @@ def update(
263267

264268

265269
def replace_kv_cache_with_custom_kv_cache(module):
266-
r"""
270+
"""
267271
Replace KVCache with CustomKVCache. This modifies the model in place.
268272
At the moment custom kv cache only supports cache with shape
269273
[B, S, H, D] as opposed to [B, H, S, D]
270274
This is because the custom op treats second dim as sequence dim.
271275
Future work: support [B, H, S, D]
272276
"""
273-
logging.warning(
277+
logging.info(
274278
"Replacing KVCache with CustomKVCache. This modifies the model in place."
275279
)
280+
return _replace_kv_cache_with_custom_kv_cache(module)
281+
282+
283+
def _replace_kv_cache_with_custom_kv_cache(module):
276284
for name, child in module.named_children():
277285
if isinstance(child, KVCache):
278286
cache_shape = child.k_cache.shape
@@ -290,5 +298,5 @@ def replace_kv_cache_with_custom_kv_cache(module):
290298
),
291299
)
292300
else:
293-
replace_kv_cache_with_custom_kv_cache(child)
301+
_replace_kv_cache_with_custom_kv_cache(child)
294302
return module

0 commit comments

Comments
 (0)