utils.ctx_managers.sequence_parallel
Module for Axolotl trainer sequence parallelism manager and utilities
Classes
| Name | Description |
|---|---|
| AllGatherWithGrad | Custom autograd function for all-gather to preserve gradients. |
| SequenceParallelContextManager | Context manager for sequence parallelism operations. |
AllGatherWithGrad
utils.ctx_managers.sequence_parallel.AllGatherWithGrad()

Custom autograd function for all-gather to preserve gradients.
Methods
| Name | Description |
|---|---|
| backward | Backward pass for all-gather operation. |
| forward | Forward pass of all-gather of data along the sequence dimension. |
backward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.backward(
ctx,
grad_output,
)

Backward pass for all-gather operation.
Extracts the gradient slice corresponding to this rank’s original input from the full gradient tensor.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| grad_output | torch.Tensor | Gradient from subsequent layers with respect to the concatenated output tensor. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | tuple[torch.Tensor, None] | Tuple containing the gradient slice for this rank’s input tensor and None for the process group parameter, which doesn’t require gradients. |
forward
utils.ctx_managers.sequence_parallel.AllGatherWithGrad.forward(
ctx,
input_tensor,
group,
)

Forward pass of all-gather of data along the sequence dimension.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| ctx | torch.autograd.function.FunctionCtx | torch.autograd function context. | required |
| input_tensor | torch.Tensor | Tensor from model output with sequence dimension. | required |
| group | dist.ProcessGroup | torch.distributed process group. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Tensor formed by gathering input_tensor from all ranks in the process group and concatenating along the sequence dimension. |
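As a rough illustration of how such an autograd function can be structured, here is a minimal sketch (not the Axolotl implementation; it assumes equal-size sequence shards on every rank and a `[batch, seq, ...]` layout):

```python
import torch
import torch.distributed as dist


class AllGatherWithGradSketch(torch.autograd.Function):
    """Sketch only: gathers along the sequence dim, routes gradient slices per rank."""

    @staticmethod
    def forward(ctx, input_tensor, group):
        ctx.rank = dist.get_rank(group)
        ctx.seq_len = input_tensor.size(1)  # local sequence length (assumes equal shards)

        world_size = dist.get_world_size(group)
        gathered = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered, input_tensor.contiguous(), group=group)

        # Concatenate along the sequence dimension (dim=1 for [batch, seq, ...] tensors)
        return torch.cat(gathered, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only the gradient slice corresponding to this rank's original chunk
        start = ctx.rank * ctx.seq_len
        grad_slice = grad_output[:, start : start + ctx.seq_len].contiguous()
        # Return None for the `group` argument, which takes no gradient
        return grad_slice, None
```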
SequenceParallelContextManager
utils.ctx_managers.sequence_parallel.SequenceParallelContextManager(
models,
context_parallel_size,
gradient_accumulation_steps,
ring_attn_func,
heads_k_stride,
gather_outputs,
device_mesh=None,
)

Context manager for sequence parallelism operations.
This class provides a context that will automatically apply sequence parallelism during model forward passes using a pre-forward hook, and gather outputs from across the sequence parallelism group using a post-forward hook.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| models | list[nn.Module] | Models to attach the sequence parallelism pre- and post-forward hooks to. | required |
| context_parallel_size | int | Number of processes to split sequences over. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused. | required |
| heads_k_stride | int \| None | Sequence parallelism K head stride size. Passed through to the varlen_llama3 ring_flash_attn implementation. | required |
| gather_outputs | bool | Whether to gather outputs after model forward pass across the sequence parallel group. | required |
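A hypothetical trainer-side usage sketch follows; `model`, `batch`, and `ring_attn_func` are assumed to exist, and the import path is inferred from the module path shown above:

```python
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager

with SequenceParallelContextManager(
    models=[model],
    context_parallel_size=4,
    gradient_accumulation_steps=1,
    ring_attn_func=ring_attn_func,
    heads_k_stride=None,
    gather_outputs=True,
):
    # Inside the context, pre-forward hooks shard each batch along the sequence
    # dimension across the sequence parallel group; post-forward hooks gather
    # the outputs back together when gather_outputs=True.
    outputs = model(**batch)
```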
Functions
| Name | Description |
|---|---|
| apply_sequence_parallelism | Apply sequence parallelism slicing to a batch. |
apply_sequence_parallelism
utils.ctx_managers.sequence_parallel.apply_sequence_parallelism(
batch,
local_rank,
local_world_size,
gradient_accumulation_steps,
ring_attn_func,
)

Apply sequence parallelism slicing to a batch.
Special handling is implemented for an integer logits_to_keep, which indicates that only the last N tokens in the sequence should be kept during generation.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| batch | dict[str, torch.Tensor] | Batch dictionary (e.g., input_ids, attention_mask, etc.). | required |
| local_rank | int | Local rank in the sequence parallel group. | required |
| local_world_size | int | World size of the sequence parallel group. | required |
| gradient_accumulation_steps | int | Number of steps to accumulate gradients over. | required |
| ring_attn_func | RingAttnFunc | Which ring attention function to use. Currently unused. | required |
Returns
| Name | Type | Description |
|---|---|---|
| | tuple[dict[str, torch.Tensor], int, int] | Tuple of: (1) the batch dictionary with sliced tensors, (2) the original sequence length before padding, and (3) the number of padding tokens added. |
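The core slicing idea can be illustrated with a short sketch (illustrative only; `slice_batch_for_rank` is a hypothetical helper, and the real implementation's padding values, label handling, and logits_to_keep adjustment differ):

```python
import torch


def slice_batch_for_rank(batch, local_rank, local_world_size, pad_token_id=0):
    """Give each rank an equal contiguous chunk of every sequence-length tensor."""
    seq_len = batch["input_ids"].size(1)
    pad_len = (-seq_len) % local_world_size          # pad to a multiple of the group size
    total_len = seq_len + pad_len
    chunk = total_len // local_world_size
    start, end = local_rank * chunk, (local_rank + 1) * chunk

    sliced = {}
    for key, value in batch.items():
        if torch.is_tensor(value) and value.dim() >= 2 and value.size(1) == seq_len:
            if pad_len:
                fill = pad_token_id if key == "input_ids" else 0
                pad = value.new_full((value.size(0), pad_len, *value.shape[2:]), fill)
                value = torch.cat([value, pad], dim=1)
            sliced[key] = value[:, start:end]
        else:
            # Non-sequence entries (e.g. an integer logits_to_keep) pass through here;
            # the actual implementation handles logits_to_keep specially as noted above.
            sliced[key] = value
    return sliced, seq_len, pad_len
```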