Transformers

Exploring the transformers library in depth.

import gc
import torch

def clean_memory():
    """Release cached GPU memory and run Python garbage collection."""
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
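As a quick sanity check that the model and tokenizer are wired up, here is a minimal generation call (standard chat-template usage; the prompt is arbitrary):

messages = [{"role": "user", "content": "In one sentence, what is a transformer?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=40)
# decode only the newly generated tokens
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))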
type(model)
transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM

Let's look at the source code for this model.
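One way to find the file that defines the runtime class on disk (plain standard library, nothing transformers-specific):

import inspect
# points into the installed transformers package, e.g. .../transformers/models/qwen2/modeling_qwen2.py
print(inspect.getsourcefile(type(model)))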

From here we have:

class Qwen2ForCausalLM(LlamaForCausalLM):
    pass
type(model.model)
transformers.models.qwen2.modeling_qwen2.Qwen2Model

And this one is defined here as:

class Qwen2Model(MistralModel):
    pass

And MistralModel is defined here like so:

class MistralModel(LlamaModel):
    def __init__(self, config: MistralConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
    ...

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: MistralConfig,
        past_key_values: Cache,
    ):
    ...

Here is LlamaModel:

class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
    ...
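Stripped of caching, masking, and the output bookkeeping, the forward pass of this base model reduces to roughly the sketch below. It is a paraphrase, not the library's code: it assumes a single unpadded sequence (so the causal mask can be left to SDPA) and lets each attention module's own rotary_emb handle positions, whereas recent transformers versions compute the rotary embeddings once in the model and hand them to every layer.

def base_forward_sketch(base_model, input_ids):
    hidden_states = base_model.embed_tokens(input_ids)        # token ids -> embeddings
    position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)
    for layer in base_model.layers:                           # the stack of decoder layers
        # each layer returns a tuple whose first element is the updated hidden states
        hidden_states = layer(hidden_states, position_ids=position_ids)[0]
    return base_model.norm(hidden_states)                     # final RMSNorm

# e.g. base_forward_sketch(model.model, inputs["input_ids"]) gives the final hidden states,
# which model.lm_head then maps to vocabulary logits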

So, MistralModel reuses LlamaModel but overrides the decoder-layer stack (self.layers) and the causal-mask helpers with its own versions.

Here is MistralDecoderLayer:

class MistralDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: MistralConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
        self.mlp = MistralMLP(config)
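The pattern in miniature: the subclass lets the parent __init__ build everything, then swaps out the submodules it wants to specialize. A toy sketch (the class names below are made up for illustration):

import torch.nn as nn

class BaseBlock(nn.Module):              # stand-in for LlamaDecoderLayer
    def __init__(self, dim):
        super().__init__()
        self.self_attn = nn.Linear(dim, dim)   # placeholder modules
        self.mlp = nn.Linear(dim, dim)

class SpecializedBlock(BaseBlock):       # stand-in for MistralDecoderLayer
    def __init__(self, dim):
        super().__init__(dim)
        # the parent already built self_attn and mlp; replace mlp with a specialized version
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, dim))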
len(model.model.layers)
24
model.model.layers[0]
Qwen2DecoderLayer(
  (self_attn): Qwen2SdpaAttention(
    (q_proj): Linear(in_features=896, out_features=896, bias=True)
    (k_proj): Linear(in_features=896, out_features=128, bias=True)
    (v_proj): Linear(in_features=896, out_features=128, bias=True)
    (o_proj): Linear(in_features=896, out_features=896, bias=False)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (mlp): Qwen2MLP(
    (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
    (up_proj): Linear(in_features=896, out_features=4864, bias=False)
    (down_proj): Linear(in_features=4864, out_features=896, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
  (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
)

By the way, a quick parameter count gives a rough sense of which model sizes will fit on a Google Colab GPU.
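A minimal check (standard PyTorch, nothing Qwen-specific; it counts only the weights and ignores activation and KV-cache memory):

n_params = sum(p.numel() for p in model.parameters())
bytes_per_param = next(model.parameters()).element_size()   # 2 bytes per parameter for bfloat16
print(f"{n_params / 1e6:.0f}M parameters, ~{n_params * bytes_per_param / 1e9:.2f} GB of weights")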

l = model.model.layers[0]
l.self_attn
Qwen2SdpaAttention(
  (q_proj): Linear(in_features=896, out_features=896, bias=True)
  (k_proj): Linear(in_features=896, out_features=128, bias=True)
  (v_proj): Linear(in_features=896, out_features=128, bias=True)
  (o_proj): Linear(in_features=896, out_features=896, bias=False)
  (rotary_emb): Qwen2RotaryEmbedding()
)
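The asymmetry between q_proj (896 → 896) and k_proj/v_proj (896 → 128) is grouped-query attention: there are fewer key/value heads than query heads, and several query heads share each KV head. The numbers can be recovered from the config (attribute names as in Qwen2's config):

cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads
print(cfg.num_attention_heads, cfg.num_key_value_heads, head_dim)
# q_proj:        hidden_size -> num_attention_heads * head_dim   (here 896 -> 896)
# k_proj/v_proj: hidden_size -> num_key_value_heads * head_dim   (here 896 -> 128)
print(cfg.num_attention_heads * head_dim, cfg.num_key_value_heads * head_dim)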

Here is how we can extract the query projection weight:

l.self_attn.q_proj.weight
Parameter containing:
tensor([[-0.0019, -0.0052,  0.0188,  ..., -0.0061, -0.0153,  0.0038],
        [ 0.0084,  0.0018,  0.0435,  ...,  0.0066, -0.0422, -0.0181],
        [-0.0168, -0.0248,  0.0422,  ...,  0.0089, -0.0008, -0.0094],
        ...,
        [-0.1040,  0.0791,  0.0132,  ..., -0.0161, -0.0221, -0.0588],
        [-0.0140,  0.0654,  0.0591,  ...,  0.0410, -0.0046,  0.0025],
        [ 0.0215,  0.0625,  0.0635,  ..., -0.0036, -0.0354, -0.0957]],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
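The tensor lives on the GPU in bfloat16 and still tracks gradients, so to analyze it outside the model (norms, SVD, plotting) you would typically detach it first. Plain PyTorch, nothing model-specific:

w_q = l.self_attn.q_proj.weight.detach().float().cpu()
print(w_q.shape)   # torch.Size([896, 896]); nn.Linear stores weights as (out_features, in_features)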

Question: What is o_proj?

l.mlp
Qwen2MLP(
  (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
  (up_proj): Linear(in_features=896, out_features=4864, bias=False)
  (down_proj): Linear(in_features=4864, out_features=896, bias=False)
  (act_fn): SiLU()
)
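The gate/up/down naming is the usual gated (SwiGLU-style) feed-forward block. From memory, Qwen2MLP's forward is essentially down_proj(act_fn(gate_proj(x)) * up_proj(x)); a sketch that reproduces it on a dummy hidden state:

def gated_mlp(mlp, x):
    # SiLU(gate_proj(x)) gates up_proj(x) elementwise, then down_proj maps back to hidden_size
    return mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))

x = torch.randn(1, 4, 896, dtype=model.dtype, device=model.device)
print(gated_mlp(l.mlp, x).shape)   # torch.Size([1, 4, 896])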