monkeypatch.transformers_fa_utils
See https://github.com/huggingface/transformers/pull/35834.
Functions
| Name | Description |
|---|---|
| fixed_fa_peft_integration_check | Casts Flash Attention inputs back to float16 / bfloat16 when PEFT has silently upcast them to float32. |
fixed_fa_peft_integration_check
monkeypatch.transformers_fa_utils.fixed_fa_peft_integration_check(
    query,
    key,
    value,
    target_dtype=None,
    preferred_dtype=None,
)

PEFT usually casts the layer norms to float32 for training stability, so the input hidden states are silently upcast to float32. They therefore need to be cast back to float16 / bfloat16 before being passed to the Flash Attention API to make sure everything works as expected. This can slow down training and inference, so it is recommended not to cast the LayerNorms in the first place.
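The idea can be illustrated with a minimal sketch (not the actual implementation, which lives in `transformers` and is monkeypatched here): if the query/key/value tensors arrive in float32, they are cast back to `preferred_dtype` if given, otherwise to `target_dtype`; with neither set, nothing is converted.

```python
import torch


def _sketch_fa_peft_integration_check(
    query, key, value, target_dtype=None, preferred_dtype=None
):
    """Illustrative sketch only; the real check in transformers differs in detail."""
    # preferred_dtype takes precedence over target_dtype.
    dtype = preferred_dtype or target_dtype
    if dtype is None or query.dtype != torch.float32:
        # Inputs are already in a dtype Flash Attention accepts,
        # or no conversion was requested.
        return query, key, value
    # PEFT silently upcast the hidden states to float32; cast q/k/v back.
    return query.to(dtype), key.to(dtype), value.to(dtype)
```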
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| query | torch.Tensor | Input query states to be passed to the Flash Attention API | required |
| key | torch.Tensor | Input key states to be passed to the Flash Attention API | required |
| value | torch.Tensor | Input value states to be passed to the Flash Attention API | required |
| target_dtype | torch.dtype, optional | The dtype to convert the attention tensors to. Conversion is skipped if no target dtype is provided. | None |
| preferred_dtype | torch.dtype, optional | The preferred dtype to convert the attention tensors to, regardless of the target dtype. | None |