kernels.swiglu
Module defining SwiGLU Triton kernels.
See “GLU Variants Improve Transformer” (https://arxiv.org/abs/2002.05202).
Credit to unsloth (https://unsloth.ai/) for inspiring this implementation.
Functions
| Name | Description |
|---|---|
| swiglu_backward | SwiGLU backward pass using in-place operations. |
| swiglu_forward | SwiGLU forward pass: computes x * sigmoid(x) * up, where x is the gate tensor. |
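Taken together, the two kernels can be wired into PyTorch autograd. The sketch below is illustrative only, assuming the kernels accept contiguous CUDA tensors of matching shape; the `SwiGLUFunction` wrapper name is hypothetical and not part of this module.

```python
import torch
from kernels.swiglu import swiglu_forward, swiglu_backward

class SwiGLUFunction(torch.autograd.Function):
    """Hypothetical autograd wrapper around the Triton kernels (sketch only)."""

    @staticmethod
    def forward(ctx, gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # Save the inputs for the backward pass; swiglu_backward mutates them
        # in place, but only inside backward, after they have been unpacked.
        ctx.save_for_backward(gate, up)
        return swiglu_forward(gate, up)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        gate, up = ctx.saved_tensors
        # swiglu_backward returns (h, df, de); only the gradients are needed here.
        _, df, de = swiglu_backward(grad_output, gate, up)
        return df, de
```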
swiglu_backward
kernels.swiglu.swiglu_backward(grad_output, gate, up)

SwiGLU backward pass using in-place operations.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| grad_output | torch.Tensor | Gradient of the loss with respect to the output, shape [batch, seq_len, hidden_dim]. | required |
| gate | torch.Tensor | Gate tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor from the forward pass, shape [batch, seq_len, hidden_dim]. | required |
Returns
| Name | Type | Description |
|---|---|---|
| tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Tuple containing the forward pass output (h), the gradient with respect to gate (df), and the gradient with respect to the up-projection (de). |
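For reference, the gradients follow directly from h = silu(gate) * up with silu(x) = x * sigmoid(x). A minimal eager-mode PyTorch sketch of the same math (not the Triton kernel, and without its in-place optimizations) might look like:

```python
import torch

def swiglu_backward_reference(grad_output, gate, up):
    # Eager-mode sketch of the documented backward math; not the Triton kernel.
    s = torch.sigmoid(gate)
    f = gate * s                           # silu(gate)
    h = f * up                             # forward pass output
    de = grad_output * f                   # gradient w.r.t. up-projection
    # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    df = grad_output * up * s * (1.0 + gate * (1.0 - s))
    return h, df, de
```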
swiglu_forward
kernels.swiglu.swiglu_forward(gate, up)

SwiGLU forward pass. Computes the SwiGLU activation x * sigmoid(x) * up, where x is the gate tensor.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| gate | torch.Tensor | Input gate tensor of shape [batch, seq_len, hidden_dim]. | required |
| up | torch.Tensor | Up-projection tensor of shape [batch, seq_len, hidden_dim]. | required |
Returns
| Name | Type | Description |
|---|---|---|
| torch.Tensor | Output tensor of shape [batch, seq_len, hidden_dim]. |
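As a point of comparison, the documented activation can be written in a few lines of eager PyTorch; this is a reference sketch, not the Triton kernel:

```python
import torch

def swiglu_forward_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU: silu(gate) * up, with silu(x) = x * sigmoid(x).
    return gate * torch.sigmoid(gate) * up
```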