utils.trainer
Module containing the Trainer class and related functions
Functions
| Name | Description |
|---|---|
| add_pose_position_ids | Use the PoSE technique to extend the context length by randomly skipping positions. |
| add_position_ids | Handle both single-example and batched data. |
| drop_long_seq | Drop samples whose sequence length is either too long (> sequence_len) |
| setup_trainer | Helper method for instantiating and building a (causal or RLHF) trainer. |
add_pose_position_ids
utils.trainer.add_pose_position_ids(
sample,
max_context_len=32768,
split_on_token_ids=None,
chunks=2,
)

Use the PoSE technique to extend the context length by randomly skipping positions in the context. We only want to skip right before tokens in the split_on_token_ids list. The skips should be randomly distributed, but the final position_ids need not span the full context_len. There may be multiple turns in the context, so we take into account the maximum possible number of skips remaining in each sample.
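As a rough sketch (not the library's implementation), the core PoSE idea can be illustrated as follows: split the sequence into `chunks` contiguous segments and insert a random positional jump between them, so the position ids span a context up to `max_context_len` while the token count stays unchanged. For brevity this sketch picks segment boundaries at random rather than before `split_on_token_ids` tokens.

```python
import random

def add_pose_position_ids(sample, max_context_len=32768, chunks=2):
    # Sketch only: assign position ids that skip ahead between segments,
    # simulating a longer context than the actual sequence length.
    input_ids = sample["input_ids"]
    seq_len = len(input_ids)
    # Pick chunks-1 random split points (the real code splits before
    # tokens in split_on_token_ids instead).
    bounds = sorted(random.sample(range(1, seq_len), chunks - 1))
    segments = [0] + bounds + [seq_len]
    skip_budget = max_context_len - seq_len  # max total positions we may skip
    position_ids = []
    offset = 0
    for start, end in zip(segments, segments[1:]):
        # Insert a random skip before every segment except the first.
        if start > 0 and skip_budget > 0:
            skip = random.randint(0, skip_budget)
            offset += skip
            skip_budget -= skip
        position_ids.extend(range(start + offset, end + offset))
    sample["position_ids"] = position_ids
    return sample
```

The resulting position_ids are strictly increasing, have the same length as input_ids, and never exceed max_context_len.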
add_position_ids
utils.trainer.add_position_ids(sample)

Handle both single-example and batched data:
- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
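A minimal sketch of the single-vs-batched dispatch described above, under the assumption that position ids are simply sequential per sequence (the actual implementation may differ):

```python
def add_position_ids(sample):
    # Batched data: sample["input_ids"] is a list of token-id lists.
    if sample["input_ids"] and isinstance(sample["input_ids"][0], list):
        sample["position_ids"] = [list(range(len(ids))) for ids in sample["input_ids"]]
    else:
        # Single example: sample["input_ids"] is a flat list[int].
        sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample
```

The `isinstance` check on the first element is what distinguishes the two input shapes.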
drop_long_seq
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)

Drop samples whose sequence length is either too long (> sequence_len) or too short (< min_sequence_len).
Works for both single-example (list[int]) and batched (list[list[int]]) data.
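A sketch of the length check, assuming the function is used as a dataset filter that returns True for samples to keep (one boolean per sequence in the batched case):

```python
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    # Keep a sequence only if min_sequence_len <= len <= sequence_len.
    def keep(ids):
        return min_sequence_len <= len(ids) <= sequence_len

    # Batched input: return one keep/drop decision per sequence.
    if sample["input_ids"] and isinstance(sample["input_ids"][0], list):
        return [keep(ids) for ids in sample["input_ids"]]
    return keep(sample["input_ids"])
```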
setup_trainer
utils.trainer.setup_trainer(
cfg,
train_dataset,
eval_dataset,
model,
tokenizer,
processor,
total_num_steps,
model_ref=None,
peft_config=None,
)

Helper method for instantiating and building a (causal or RLHF) trainer.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | | Axolotl config object containing training parameters. | required |
| train_dataset | | Dataset to use for training. | required |
| eval_dataset | | Dataset to use for evaluation. | required |
| model | | The model to train. | required |
| tokenizer | | Tokenizer for processing text input. | required |
| processor | | Processor for data preparation. | required |
| total_num_steps | | The total number of training steps. | required |
| model_ref | | Optional reference model for RLHF training. | None |
| peft_config | | Optional PEFT (Parameter-Efficient Fine-Tuning) configuration. | None |
Returns
| Name | Type | Description |
|---|---|---|
| | | A trainer instance (either HFRLTrainer or HFCausalTrainer) configured based on the provided parameters. |