# patches/mpt_patch_rotary_cache.py
"""
Patch for MPT model:
- Fix rotary embedding cache when sequence length changes between forward passes.
- Correct attention mask broadcasting for cross-attention layers.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple


# 1. Patch Rotary Embedding Cache
# ----------------------------------------------------------------------

def patched_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split and rotate half the hidden dims (fixed for fp16 stability)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
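

# Illustrative sketch, not part of the original patch: it shows the standard way
# a rotate-half helper is consumed (q * cos + rotate_half(q) * sin). The name
# apply_rotary_pos_emb and its exact call site inside MPT's attention are
# assumptions made for this example only.
def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query/key tensors (illustrative)."""
    q_rot = (q * cos) + (patched_rotate_half(q) * sin)
    k_rot = (k * cos) + (patched_rotate_half(k) * sin)
    return q_rot, k_rot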


class PatchedRotaryEmbedding(nn.Module):
    """Rotary embedding with cache reset on seqlen change."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = None
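
    # NOTE: the original cache/forward code is not reproduced here; the two
    # methods below are a minimal sketch assuming a standard RoPE cache:
    # rebuild cos/sin whenever the requested seq_len (or device/dtype) differs
    # from what was cached, which is the reset behaviour the docstring describes.
    def _build_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        inv_freq = 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)         # [seq_len, dim // 2]
        emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, dim]
        self._cached_cos = emb.cos().to(dtype)
        self._cached_sin = emb.sin().to(dtype)
        self._cached_seq_len = seq_len

    def forward(
        self, x: torch.Tensor, seq_len: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seq_len = seq_len if seq_len is not None else x.shape[1]
        # Rebuild the cache when the sequence length, device, or dtype changes.
        if (
            self._cached_cos is None
            or seq_len != self._cached_seq_len
            or self._cached_cos.device != x.device
            or self._cached_cos.dtype != x.dtype
        ):
            self._build_cache(seq_len, x.device, x.dtype)
        return self._cached_cos[:seq_len], self._cached_sin[:seq_len]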


# 2. Patch Attention Mask Broadcasting
# ----------------------------------------------------------------------

def patch_attention_mask(mask, dtype, query_length, key_length):
    """Expand a [batch, key_len] padding mask (1 = masked) to an additive bias."""
    batch = mask.shape[0]
    mask = mask[:, None, None, :]  # [batch, key_len] -> [batch, 1, 1, key_len]
    # Broadcast to query_len
    mask = mask.expand(batch, 1, query_length, key_length)
    # Convert to additive mask (0 = keep, -inf = mask)
    return mask.to(dtype).masked_fill(mask == 0, 0.0).masked_fill(mask == 1, float("-inf"))
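

# Illustrative self-check, not part of the original patch: it assumes the
# 1 = "mask out" convention implied by the masked_fill calls above.
def _demo_patch_attention_mask() -> None:
    pad = torch.tensor([[0, 0, 1]])  # last of three key positions is padded
    bias = patch_attention_mask(pad, torch.float32, query_length=2, key_length=3)
    assert bias.shape == (1, 1, 2, 3)
    assert torch.isinf(bias[..., -1]).all()   # padded column becomes -inf
    assert (bias[..., :2] == 0).all()         # kept columns stay 0.0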


# 3. Monkey-patch into existing MPT model (example)
# ----------------------------------------------------------------------

def apply_mpt_patches(model: nn.Module):
    """Replace rotary and mask functions in an existing MPT model."""
    # Patch rotary class if found
    for name, module in model.named_modules():
        if "rotary" in name.lower() and hasattr(module, "cos_cached"):
            module.__class__ = PatchedRotaryEmbedding
            print(f"[PATCH] Replaced rotary in {name}")

    # Monkey-patch attention mask expansion function if model has it
    if hasattr(model, "_expand_attention_mask"):
        model._expand_attention_mask = patch_attention_mask
        print("[PATCH] Replaced _expand_attention_mask")
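

# Illustrative smoke test, not part of the original file: DummyRotary and
# DummyModel are stand-ins for real MPT modules, used only to show that the
# class swap above takes effect without loading a checkpoint.
def _demo_apply_mpt_patches() -> None:
    class DummyRotary(nn.Module):
        def __init__(self):
            super().__init__()
            self.cos_cached = torch.zeros(1)  # attribute apply_mpt_patches looks for

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.rotary_emb = DummyRotary()  # module name contains "rotary"

    model = DummyModel()
    apply_mpt_patches(model)
    assert isinstance(model.rotary_emb, PatchedRotaryEmbedding)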


# Usage example
# ----------------------------------------------------------------------

if __name__ == "__main__":
    # Assume you have an MPT model loaded:
    # from transformers import AutoModel
    # model = AutoModel.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
    # apply_mpt_patches(model)
    pass