
Solves #2598 #2599


Closed
wants to merge 10 commits into from
136 changes: 136 additions & 0 deletions recipes/configs/qwen2_5/14B_to_7B_KD_lora_single_device.yaml
@@ -0,0 +1,136 @@
# Config for single device knowledge distillation (KD) in knowledge_distillation_single_device.py
# using a Qwen2.5 14B teacher and 7B student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the teacher and student models
# tune download Qwen/Qwen2.5-14B-Instruct
# tune download Qwen/Qwen2.5-7B-Instruct
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run lora_finetune_single_device --config qwen2_5/14B_lora_single_device.yaml
#
# To launch on a single device, run the following command from root:
# tune run knowledge_distillation_single_device --config qwen2_5/14B_to_7B_KD_lora_single_device.yaml
#
# This config works only for training on single device.


output_dir: /tmp/kd_7B # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.qwen2_5.lora_qwen2_5_7b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 64  # higher increases accuracy and memory
  lora_alpha: 128  # usually alpha=2*rank
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.qwen2_5.qwen2_5_14b_instruct

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
  path: /tmp/Qwen2.5-7B-Instruct/vocab.json
  merges_file: /tmp/Qwen2.5-7B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2.5-7B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00004"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2.5-14B-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00008"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False  # True increases speed
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4

lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
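# The KD recipe blends the two objectives roughly as
#   loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
# so a kd_ratio of 0.5 weights the cross-entropy and distillation terms equally.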

# Training
epochs: 5
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 2
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
26 changes: 26 additions & 0 deletions torchtune/modules/loss/kd_losses.py
@@ -25,6 +25,21 @@ def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def pad_logits(
        self, student_logits: torch.Tensor, teacher_logits: torch.Tensor
    ):
        """Zero-pad the smaller of the two tensors along the vocab dimension so
        the student and teacher logits have the same last-dimension size."""
        student_size, teacher_size = student_logits.size(-1), teacher_logits.size(-1)
        if student_size != teacher_size:
            pad_size = abs(student_size - teacher_size)
            pad_tensor = torch.zeros(
                (*teacher_logits.shape[:-1], pad_size),
                dtype=teacher_logits.dtype,
                device=teacher_logits.device,
            )
            return (
                (torch.cat([student_logits, pad_tensor], dim=-1), teacher_logits)
                if student_size < teacher_size
                else (student_logits, torch.cat([teacher_logits, pad_tensor], dim=-1))
            )
        return student_logits, teacher_logits
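    # Example (illustrative shapes): student logits of shape (B, T, V_s) and
    # teacher logits of shape (B, T, V_t) with V_s < V_t come back as two tensors
    # that both end in V_t; the extra student columns are zero-filled.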

    def forward(
        self,
        student_logits: torch.Tensor,
@@ -49,6 +64,16 @@ def forward(
        teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        inf_mask = torch.isinf(student_logits)
        student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
        student_logprob, teacher_prob = self.pad_logits(
            student_logprob,
            teacher_prob,
        )
        # Recompute the inf mask on the padded tensor so it matches the new vocab size.
        inf_mask = torch.isinf(student_logprob)

        prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (labels != self.ignore_index).int()
@@ -61,6 +86,7 @@ def forward(
        return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)



class ReverseKLLoss(torch.nn.Module):
    """
    The Kullback-Leibler divergence loss for valid indexes.
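For a quick end-to-end check of the padding path, here is a minimal sketch, assuming `ForwardKLLoss` is exported from `torchtune.modules.loss`; the batch and vocab sizes below are made up for illustration:

```python
import torch

from torchtune.modules.loss import ForwardKLLoss

# Toy batch: 2 sequences of 4 tokens; the student and teacher vocab sizes
# (8 vs. 12) are deliberately mismatched to exercise pad_logits.
student_logits = torch.randn(2, 4, 8)
teacher_logits = torch.randn(2, 4, 12)
labels = torch.randint(0, 8, (2, 4))

kd_loss = ForwardKLLoss(ignore_index=-100)
loss = kd_loss(student_logits, teacher_logits, labels)
print(loss)  # scalar loss; pad_logits zero-pads the student logits up to the teacher vocab
```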