kernels.geglu
Module defining GEGLU Triton kernels.
See “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).
This implementation was inspired by unsloth (https://unsloth.ai/).
Functions
| Name | Description |
|---|---|
| geglu_backward | GEGLU backward pass using in-place operations. |
| geglu_forward | GEGLU forward pass. |
geglu_backward
kernels.geglu.geglu_backward(grad_output, gate, up)

GEGLU backward pass using in-place operations.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| grad_output | torch.Tensor | Gradient of loss with respect to output, shape [batch, seq_len, hidden_dim]. | required |
| gate | torch.Tensor | Gate tensor from forward pass, shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor from forward pass, shape [batch, seq_len, hidden_dim]. | required |
Returns
| Name | Type | Description |
|---|---|---|
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing the GEGLU activation output (h), the gradient with respect to gate (grad_gate), and the gradient with respect to up (grad_up). |
Note
This function modifies its input tensors in-place to store results.
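The gradients follow from differentiating h = GELU(gate) * up. A minimal pure-Python sketch of the per-element math (using the exact erf-based GELU; the actual Triton kernel applies this elementwise over the tensors, storing results in place, and `geglu_backward_ref` is a hypothetical name for illustration):

```python
import math


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x: float) -> float:
    # d/dx GELU(x) = Phi(x) + x * phi(x), where phi is the standard normal PDF.
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    return cdf + x * pdf


def geglu_backward_ref(grad_output: float, gate: float, up: float):
    # Reference semantics for a single element; the Triton kernel performs
    # this for every element of the [batch, seq_len, hidden_dim] tensors.
    h = gelu(gate) * up                            # recomputed activation output
    grad_gate = grad_output * up * gelu_grad(gate)  # chain rule through GELU
    grad_up = grad_output * gelu(gate)              # linear in up
    return h, grad_gate, grad_up
```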
geglu_forward
kernels.geglu.geglu_forward(gate, up)

GEGLU forward pass.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| gate | torch.Tensor | Input gate tensor of shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor of shape [batch, seq_len, hidden_dim]. | required |
Returns
| Name | Type | Description |
|---|---|---|
| torch.Tensor | Output tensor of shape [batch, seq_len, hidden_dim]. |
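Per the paper cited above, the output is GELU(gate) elementwise-multiplied with up. A minimal pure-Python sketch of that per-element semantics (exact erf-based GELU assumed; the real kernel runs this elementwise on the GPU via Triton, and `geglu_forward_ref` is a hypothetical name for illustration):

```python
import math


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu_forward_ref(gate: float, up: float) -> float:
    # One element of GEGLU: GELU(gate) * up. The Triton kernel applies
    # this across every element of the [batch, seq_len, hidden_dim] tensors.
    return gelu(gate) * up
```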