Distillation: create a student model from a different base model than the teacher

Hi,

The current implementation of distillation in examples/seq2seq/distillation.py creates the student model by copying selected layers from the teacher model. However, I am interested in creating the student from a different base model, e.g., a t5-large teacher with a t5-small student. I have made changes here to support this. Roughly, the idea is the following (a simplified sketch, not my exact diff):
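
from transformers import T5ForConditionalGeneration

# Instead of copying a subset of the teacher's layers into the student,
# teacher and student are loaded from two different pretrained checkpoints.
teacher = T5ForConditionalGeneration.from_pretrained("t5-large")
student = T5ForConditionalGeneration.from_pretrained("t5-small")

I think I am missing something, though, because when I try to run this: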

python distillation.py --teacher t5-large --data_dir $NQOPEN_DIR \
--student_base_model t5-small --tokenizer_name t5-small \
--learning_rate=3e-4 --freeze_encoder --freeze_embeds \
--do_train --train_batch_size 32 \
--do_predict --n_train 10 \
--model_name_or_path t5-small --eval_beams 2 --eval_max_gen_length 142 \
--val_check_interval 0.25 --n_val 10 \
--output_dir distilt5 --gpus 1 --logger_name wandb

I get the following error. Could you please let me know what I am missing?

Traceback (most recent call last):
  File "distillation.py", line 361, in <module>
    distill_main(args)
  File "distillation.py", line 352, in distill_main
    return ft_main(args, model=model)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 407, in main
    trainer: pl.Trainer = generic_train(
  File "/home/sumithrab/transformers/examples/lightning_base.py", line 382, in generic_train
    trainer.fit(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in fit
    results = self.single_gpu_train(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 186, in single_gpu_train
    results = self.run_pretrain_routine(model)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in run_pretrain_routine
    eval_results = self._evaluate(model,
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 293, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 470, in evaluation_forward
    output = model.validation_step(*args)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 181, in validation_step
    return self._generative_step(batch)
  File "/home/sumithrab/transformers/examples/seq2seq/finetune.py", line 225, in _generative_step
    loss_tensors = self._step(batch)
  File "distillation.py", line 211, in _step
    outputs = self.teacher(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 1201, in forward
    decoder_outputs = self.decoder(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 757, in forward
    layer_outputs = layer_module(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 547, in forward
    cross_attention_outputs = self.layer[1](
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 469, in forward
    attention_output = self.EncDecAttention(
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/transformers/src/transformers/modeling_t5.py", line 356, in forward
    k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 91, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/sumithrab/miniconda3/envs/t5/lib/python3.8/site-packages/torch/nn/functional.py", line 1676, in linear
    output = input.matmul(weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
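
My guess is that this is a hidden-size mismatch between the two models: t5-large uses d_model=1024 while t5-small uses d_model=512, and the failure happens inside the teacher's cross-attention during _step, which would make sense if activations from one model are being fed into a linear layer of the other. A quick way to see the size difference:

from transformers import T5Config

# The two checkpoints have different hidden sizes, so a tensor produced by
# one model cannot be passed directly through a projection of the other.
for name in ("t5-large", "t5-small"):
    print(name, "d_model =", T5Config.from_pretrained(name).d_model)
# t5-large d_model = 1024
# t5-small d_model = 512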

Also, I am not sure whether the --model_name_or_path and --tokenizer_name arguments are correct: should they point to t5-large or t5-small?
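
(For what it's worth, my understanding is that all of the original T5 checkpoints share the same SentencePiece vocabulary, so the tokenizer choice may not matter much; a quick check, assuming that is true:)

from transformers import T5Tokenizer

# If the two checkpoints really do share one vocabulary, this prints True;
# I have not verified this for every checkpoint.
small = T5Tokenizer.from_pretrained("t5-small")
large = T5Tokenizer.from_pretrained("t5-large")
print(small.get_vocab() == large.get_vocab())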

Thanks,
Sumithra