--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -403,7 +403,7 @@
         # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
         self.is_aiter_triton_fp4_bmm_enabled = (
             rocm_aiter_ops.is_fp4bmm_enabled()
-            and self.kv_b_proj.weight.dtype == torch.bfloat16
+            and (self.kv_b_proj.weight.dtype if hasattr(self.kv_b_proj, "weight") else torch.bfloat16) == torch.bfloat16
         )

         # Attributes for forward_impl method
@@ -2358,9 +2358,9 @@
             # model dtype input and will quantize internally.
             if (
                 use_fp8_prefill
-                or self.kv_b_proj.weight.dtype != current_platform.fp8_dtype()
+                or (self.kv_b_proj.weight.dtype if hasattr(self.kv_b_proj, "weight") else torch.bfloat16) != current_platform.fp8_dtype()
             ):
-                kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype)
+                kv_c_normed = kv_c_normed.to(self.kv_b_proj.weight.dtype if hasattr(self.kv_b_proj, "weight") else torch.bfloat16)

             k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
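
The guard introduced above reads `self.kv_b_proj.weight.dtype` only when the layer actually exposes a `weight` attribute and otherwise falls back to `torch.bfloat16` (for example when a quantization method has repacked the projection). The sketch below illustrates that pattern in isolation; the helper name `weight_dtype_or` and the toy `PackedProj` module are hypothetical and not part of this diff or of vLLM.

```python
import torch
import torch.nn as nn


def weight_dtype_or(
    module: nn.Module, fallback: torch.dtype = torch.bfloat16
) -> torch.dtype:
    # Mirror the hasattr guard in the diff: use the real weight dtype when a
    # plain `.weight` tensor exists, otherwise fall back to bfloat16.
    weight = getattr(module, "weight", None)
    return weight.dtype if isinstance(weight, torch.Tensor) else fallback


# Unquantized projection: the dtype comes from the real weight tensor.
proj = nn.Linear(16, 32, bias=False, dtype=torch.bfloat16)
assert weight_dtype_or(proj) == torch.bfloat16


# Module without a `.weight` attribute: the fallback dtype is used instead.
class PackedProj(nn.Module):
    def forward(self, x):
        return x


assert weight_dtype_or(PackedProj()) == torch.bfloat16
```

With a helper of this shape, each of the three repeated conditional expressions in the diff could be written as, e.g., `weight_dtype_or(self.kv_b_proj) == torch.bfloat16`; whether to factor it out that way is a style choice, not something the change above requires.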