
Trainer APIs

openmind.TrainingArguments Class

The TrainingArguments class configures the parameters of a training task, including hyperparameters, model saving paths, logging options, and learning rates used during training.
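As a purely illustrative sketch (hypothetical field subset, not the openmind source), such an arguments object simply groups these settings in one place with sensible defaults:

```python
from dataclasses import dataclass

# Hypothetical, minimal stand-in for a training-arguments container;
# the real class accepts many more options (see the tables below).
@dataclass
class TrainingArgumentsSketch:
    output_dir: str = "./output"
    num_train_epochs: float = 3.0
    learning_rate: float = 5e-5
    per_device_train_batch_size: int = 8
    seed: int = 42

# Override only the settings that differ from the defaults.
args = TrainingArgumentsSketch(output_dir="./saves", learning_rate=1e-4)
```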

Parameters

  • Parameters supported by the TrainingArguments class of both PyTorch and MindSpore

| Name | PyTorch Type | MindSpore Type | Description | Default Value for PyTorch | Default Value for MindSpore |
| --- | --- | --- | --- | --- | --- |
| output_dir | str | str | Output directory | None | "./output" |
| overwrite_output_dir | bool | bool | Whether to overwrite the output directory | False | False |
| seed | int | int | Random seed | 42 | 42 |
| use_cpu | bool | bool | Whether to use a CPU | False | False |
| do_train | bool | bool | Whether to perform training | False | False |
| do_eval | bool | bool | Whether to perform evaluation | False | False |
| do_predict | bool | bool | Whether to perform inference | False | False |
| num_train_epochs | float | float | Total number of training epochs | 3.0 | 3.0 |
| resume_from_checkpoint | str | str | Preloaded weights | None | None |
| evaluation_strategy | Union[IntervalStrategy, str] | Union[IntervalStrategy, str] | Evaluation strategy | "no" | "no" |
| per_device_train_batch_size | int | int | Training batch size of each device | 8 | 8 |
| per_device_eval_batch_size | int | int | Evaluation batch size of each device | 8 | 8 |
| per_gpu_train_batch_size | int | int | Training batch size of each GPU (not recommended) | None | None |
| per_gpu_eval_batch_size | int | int | Evaluation batch size of each GPU (not recommended) | None | None |
| gradient_accumulation_steps | int | int | Gradient accumulation steps | 1 | 1 |
| ignore_data_skip | bool | bool | Whether to ignore data skipping during resumable training | False | False |
| dataloader_drop_last | bool | bool | Whether the data loader drops the last batch | False | True |
| dataloader_num_workers | int | int | Number of processes in the data loader | 0 | 8 |
| optim | Union[OptimizerNames, str] | Union[OptimizerType, str] | Optimizer | "adamw_torch" | "fp32_adamw" |
| adam_beta1 | float | float | Adam optimizer beta1 | 0.9 | 0.9 |
| adam_beta2 | float | float | Adam optimizer beta2 | 0.999 | 0.999 |
| adam_epsilon | float | float | Adam optimizer epsilon | 1e-8 | 1e-8 |
| weight_decay | float | float | Weight decay | 0.0 | 0.0 |
| lr_scheduler_type | Union[SchedulerType, str] | Union[LrSchedulerType, str] | Type of the learning rate scheduler | "linear" | "cosine" |
| learning_rate | float | float | Learning rate | 5e-5 | 5e-5 |
| warmup_ratio | float | float | Warmup ratio | 0.0 | None |
| warmup_steps | int | int | Warmup steps | 0 | 0 |
| max_grad_norm | float | float | Maximum norm of gradient clipping | 1.0 | 1.0 |
| logging_strategy | Union[IntervalStrategy, str] | Union[LoggingIntervalStrategy, str] | Logging strategy | "steps" | "steps" |
| logging_steps | float | float | Number of logging steps | 500 | 1 |
| save_steps | float | float | Weight saving steps | 500 | 500 |
| save_strategy | str | Union[SaveIntervalStrategy, str] | Weight saving strategy | "steps" | "steps" |
| save_total_limit | int | int | Maximum number of weights that can be saved | None | 5 |
| save_on_each_node | bool | bool | Whether to save weights on each node | False | True |
| hub_model_id | str | str | Hub model ID | None | None |
| hub_strategy | Union[HubStrategy, str] | Union[HubStrategy, str] | Hub push strategy | "every_save" | "every_save" |
| hub_token | str | str | Hub token | None | None |
| hub_private_repo | bool | bool | Hub private repository | False | False |
| hub_always_push | bool | bool | Whether to always push a model to the Hub | False | False |
| data_seed | int | int | Random seed for the data sampler | None | None |
| eval_steps | float | float | Number of steps between evaluations | None | None |
| push_to_hub | bool | bool | Whether to push a model to the Hub | False | False |
  • Parameters supported only by PyTorch

| Name | Type | Description | Default Value |
| --- | --- | --- | --- |
| optim_args | str | Optimizer parameters | None |
| label_names | List[str] | Label names | None |
| load_best_model_at_end | bool | Whether to load the optimal model at the end | False |
| metric_for_best_model | str | Metric for the optimal model | None |
| greater_is_better | bool | Whether a greater metric is better | None |
| label_smoothing_factor | float | Label smoothing factor | 0.0 |
| include_inputs_for_metrics | bool | Whether the metric includes inputs | False |
| prediction_loss_only | bool | Whether to return only loss when evaluation is conducted and prediction is generated | False |
| eval_accumulation_steps | int | Number of prediction steps to accumulate output tensors | None |
| eval_delay | float | Number of steps to wait before the first evaluation | None |
| max_steps | int | Maximum number of training steps | -1 |
| lr_scheduler_kwargs | dict | Additional arguments of the scheduler | {} |
| log_level | str | Log level | "passive" |
| log_level_replica | str | Log level used on replicas | "warning" |
| log_on_each_node | bool | Whether to record logs on each node during distributed training, instead of only on the primary node | True |
| logging_dir | str | Directory that saves logs | None |
| logging_first_step | bool | Whether to record the first global_step | False |
| logging_nan_inf_filter | bool | Whether to filter nan and inf loss for log recording | True |
| save_safetensors | bool | Whether to save the weights in safetensors format | True |
| save_only_model | bool | Whether to save only the model state during checkpointing | False |
| jit_mode_eval | bool | Whether to use PyTorch JIT for inference | False |
| use_ipex | bool | Whether to use Intel extensions (Not supported) | False |
| bf16 | bool | Whether to use the bf16 format | False |
| fp16 | bool | Whether to use the fp16 format | False |
| tf32 | bool | Whether to use the tf32 format (Not supported) | None |
| fp16_opt_level | str | Apex AMP optimization level for fp16 training ("O0" to "O3") | "O1" |
| fp16_backend | str | Backend used by fp16 | "auto" |
| half_precision_backend | str | Backend used for mixed-precision training | "auto" |
| bf16_full_eval | bool | Whether to use bf16 during evaluation | False |
| fp16_full_eval | bool | Whether to use fp16 during evaluation | False |
| disable_tqdm | bool | Whether to disable the progress bar | None |
| remove_unused_columns | bool | Whether to automatically delete columns that are not used by the model's forward method | True |
| fsdp | Union[List[FSDPOption], str] | Whether to use FSDP | None |
| fsdp_config | Union[dict, str] | FSDP configuration | None |
| local_rank | int | Process ID for distributed training | -1 |
| tpu_num_cores | int | Number of cores used for TPU training (Not supported) | None |
| past_index | int | Index when hidden states are used for prediction | -1 |
| ddp_backend | str | Backend for DDP distributed training | None |
| run_name | str | Running descriptor | None |
| deepspeed | str | DeepSpeed configuration | None |
| accelerator_config | str | Accelerate configuration | None |
| debug | Union[str, List[DebugOption]] | Enables one or more debugging features | None |
| length_column_name | str | Column name for precomputed length | "length" |
| group_by_length | bool | Whether to group together samples with roughly the same length in the training dataset | False |
| ddp_find_unused_parameters | bool | Whether to pass find_unused_parameters to DistributedDataParallel | None |
| report_to | List[str] | List of integrations to report the results and logs to | "all" |
| ddp_bucket_cap_mb | int | Value of bucket_cap_mb passed to DistributedDataParallel | None |
| ddp_broadcast_buffers | bool | Value of broadcast_buffers passed to DistributedDataParallel | None |
| ddp_timeout | int | Timeout for DDP calls | 1800 |
| dataloader_pin_memory | bool | Whether to pin memory in the data loader | True |
| dataloader_persistent_workers | bool | Whether to keep worker dataset instances alive | False |
| dataloader_prefetch_factor | int | Number of batches preloaded by each worker | None |
| skip_memory_metrics | bool | Whether to skip adding the memory profiler report to metrics | True |
| gradient_checkpointing | bool | Whether to use gradient checkpointing to save memory | False |
| gradient_checkpointing_kwargs | dict | Arguments related to gradient checkpointing | None |
| auto_find_batch_size | bool | Whether to automatically find a batch size that fits into memory through exponential decay | False |
| full_determinism | bool | Whether to call enable_full_determinism instead of set_seed to ensure reproducible results in distributed training | False |
| torchdynamo | str | Backend compiler of TorchDynamo (Not supported) | None |
| ray_scope | str | Scope to use when doing hyperparameter search with Ray | "last" |
| use_mps_device | bool | Whether to use the mps device (Not supported) | False |
| torch_compile | bool | Whether to use PyTorch 2.0 to compile the model (Not supported) | False |
| torch_compile_backend | str | Backend used by torch.compile (Not supported) | None |
| torch_compile_mode | str | torch.compile mode (Not supported) | None |
| split_batches | bool | Whether to split batches generated by the data loader across devices | None |
| include_tokens_per_second | bool | Whether to calculate tokens per second for each device | None |
| include_num_input_tokens_seen | bool | Whether to track the number of input tokens seen throughout training | None |
| neftune_noise_alpha | float | Whether to activate NEFTune noise embeddings | None |
| optim_target_modules | Union[str, List[str]] | Target modules to be optimized | None |
  • Parameters supported only by MindSpore

| Name | Type | Description | Default Value |
| --- | --- | --- | --- |
| only_save_strategy | bool | Whether a task exits directly after saving the strategy file | False |
| auto_trans_ckpt | bool | Whether to enable automatic weight transformation | False |
| src_strategy | str | Distributed strategy file for preloaded weights | None |
| batch_size | int | Training batch size of each device; overrides per_device_train_batch_size | None |
| sink_mode | bool | Whether to sink data directly to devices through channels | True |
| sink_size | int | Size of sunk data for training or evaluation in each step | 2 |
| mode | int | Running mode: GRAPH_MODE (0) or PYNATIVE_MODE (1) | 0 |
| resume_training | bool | Whether to enable resumable training | False |
| remote_save_url | str | OBS saving path | None |
| device_id | int | Device ID | 0 |
| device_target | str | Target device where the task is executed. Value options: "Ascend", "GPU", or "CPU" | "Ascend" |
| enable_graph_kernel | bool | Whether to use graph fusion | False |
| graph_kernel_flags | str | Level of graph fusion | "--opt_level=0" |
| save_graphs | bool | Whether to save computational graphs | False |
| save_graphs_path | str | Path for saving computational graphs | "./graph" |
| max_call_depth | int | Maximum depth of a function call | 10000 |
| max_device_memory | str | Maximum available memory of the device | "1024GB" |
| use_parallel | bool | Whether to enable the parallel mode | False |
| parallel_mode | int | Parallel mode: data parallel (0), semi-automatic parallel (1), automatic parallel (2), or hybrid parallel (3) | 1 |
| gradients_mean | bool | Whether to execute the mean operator after gradient AllReduce | False |
| loss_repeated_mean | bool | Whether to execute the mean operator backwards during repeated computation | False |
| enable_alltoall | bool | Whether to generate the AllToAll communication operator during communication | False |
| enable_parallel_optimizer | bool | Whether to enable optimizer parallelism | False |
| full_batch | bool | Set to True if the entire batch dataset is loaded in automatic parallel mode. Not recommended; use dataset_strategy instead | True |
| dataset_strategy | Union[str, tuple] | Dataset sharding strategy | "full_batch" |
| search_mode | str | Strategy search mode, valid only in automatic parallel mode. Experimental; exercise caution when using it | "sharding_propagation" |
| data_parallel | int | Data parallelism | 1 |
| gradient_accumulation_shard | bool | Whether the gradient accumulation variable is sharded along the data parallelism dimension | False |
| parallel_optimizer_threshold | int | Threshold for optimizer parameter sharding | 64 |
| optimizer_weight_shard_size | int | Size of the communicator for optimizer weight sharding | -1 |
| strategy_ckpt_save_file | str | Path for saving distributed strategy files | "./ckpt_strategy.ckpt" |
| model_parallel | int | Model parallelism | 1 |
| expert_parallel | int | Expert parallelism | 1 |
| pipeline_stage | int | Pipeline parallelism | 1 |
| gradient_aggregation_group | int | Size of a gradient communication operator fusion group | 4 |
| micro_batch_num | int | Minimum number of batches for pipeline computation | 1 |
| micro_batch_interleave_num | int | Number of concurrent copies | 1 |
| use_seq_parallel | bool | Whether to enable sequence parallelism | False |
| vocab_emb_dp | bool | Whether to split the vocabulary only along the data parallelism dimension | True |
| expert_num | int | Number of experts | 1 |
| capacity_factor | float | Expert capacity factor | 1.05 |
| aux_loss_factor | float | Loss contribution factor | 0.05 |
| num_experts_chosen | int | Number of experts selected for each token | 1 |
| recompute | bool | Recomputation | False |
| select_recompute | bool | Selective recomputation | False |
| parallel_optimizer_comm_recompute | bool | Whether to recompute the AllGather communication introduced by optimizer parallelism | False |
| mp_comm_recompute | bool | Whether to recompute the communication operations introduced by model parallelism | True |
| recompute_slice_activation | bool | Whether to slice the cell output stored in memory | False |
| layer_scale | bool | Whether to enable layer decay | False |
| layer_decay | float | Layer decay coefficient | 0.65 |
| lr_end | float | End learning rate | 1e-6 |
| warmup_lr_init | float | Initial learning rate in the warmup phase | 0.0 |
| warmup_epochs | int | Number of epochs over which linear warmup is performed | None |
| lr_scale | bool | Whether to enable learning rate scaling | False |
| lr_scale_factor | int | Learning rate scaling factor | 256 |
| python_multiprocessing | bool | Whether to enable the Python multiprocess mode | False |
| numa_enable | bool | Whether NUMA is enabled by default | False |
| prefetch_size | int | Queue capacity of threads in a pipeline | 1 |
| wrapper_type | str | Class name of the wrapper | "MFTrainOneStepCell" |
| scale_sense | Union[str, float] | Value or class name of the loss scale sense | "DynamicLossScaleUpdateCell" |
| loss_scale_value | int | Initial loss scaling factor | 65536 |
| loss_scale_factor | int | Increment and decrement factor of the loss scaling coefficient | 2 |
| loss_scale_window | int | Maximum number of consecutive training steps for increasing the loss scaling coefficient | 1000 |
| use_clip_grad | bool | Whether to enable gradient clipping | True |
| train_dataset | str | Training dataset path | None |
| eval_dataset | str | Evaluation dataset path | None |
| dataset_task | str | Task type corresponding to a dataset | None |
| dataset_type | str | Dataset type | None |
| train_dataset_in_columns | list[str] | Training dataset input column names | None |
| train_dataset_out_columns | list[str] | Training dataset output column names | None |
| eval_dataset_in_columns | list[str] | Evaluation dataset input column names | None |
| eval_dataset_out_columns | list[str] | Evaluation dataset output column names | None |
| shuffle | bool | Whether to shuffle the training dataset | True |
| repeat | int | Number of repetitions of the training dataset | 1 |
| metric_type | Union[List[str], str] | Metric type | None |
| save_seconds | int | Interval, in seconds, at which checkpoints are saved | None |
| integrated_save | bool | Whether to merge and save the split tensors in the automatic parallelism scenario | None |
| eval_epochs | int | Number of epochs between evaluations; 1 means evaluation is performed at the end of each epoch | None |
| profile | bool | Whether to enable profiling | False |
| profile_start_step | int | Start step of profiling | 1 |
| profile_end_step | int | End step of profiling | 10 |
| init_start_profile | bool | Whether to enable data collection during profiler initialization | False |
| profile_communication | bool | Whether to collect communication performance data during multi-device training | False |
| profile_memory | bool | Whether to collect tensor memory data | True |
| auto_tune | bool | Whether to enable automatic data acceleration | False |
| filepath_prefix | str | Save path and file prefix of the optimized global configuration | "./autotune" |
| autotune_per_step | int | Step interval for adjusting the automatic data acceleration configuration | 10 |

train_batch_size

Obtains the training batch size.

Prototype

```python
def train_batch_size()
```
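The effective training batch size is typically derived from the per-device batch size and the number of devices. A minimal sketch (hypothetical `n_devices` field; the assumption is that the behavior mirrors the transformers implementation, not the actual openmind source):

```python
from dataclasses import dataclass

@dataclass
class BatchArgsSketch:
    # Hypothetical stand-in; n_devices models the detected device count.
    per_device_train_batch_size: int = 8
    n_devices: int = 1

    @property
    def train_batch_size(self) -> int:
        # Samples consumed per optimizer step across all devices.
        return self.per_device_train_batch_size * max(1, self.n_devices)

args = BatchArgsSketch(per_device_train_batch_size=8, n_devices=4)
print(args.train_batch_size)  # 32
```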

eval_batch_size

Obtains the evaluation batch size.

Prototype

```python
def eval_batch_size()
```

world_size

Obtains the number of parallel processes.

Prototype

```python
def world_size()
```

process_index

Obtains the index of the current process.

Prototype

```python
def process_index()
```

local_process_index

Obtains the index of the current local process.

Prototype

```python
def local_process_index()
```

should_log

Determines whether the current process should generate logs. Currently, this API is supported only by PyTorch.

Prototype

```python
def should_log()
```

should_save

Determines whether the current process should write to disk (for example, to save models and checkpoints). Currently, this API is supported only by PyTorch.

Prototype

```python
def should_save()
```

_setup_devices

Sets the device. Currently, this API is supported only by PyTorch.

Prototype

```python
def _setup_devices()
```

device

Obtains the device used by the current process. Currently, this API is supported only by PyTorch.

Prototype

```python
def device()
```

get_process_log_level

Obtains the process log level. Currently, this API is supported only by PyTorch.

Prototype

```python
def get_process_log_level()
```

main_process_first

A context manager under which the main process runs first while the other processes wait. Currently, this API is supported only by PyTorch.

Prototype

```python
def main_process_first(local: bool = True, desc: str = "work")
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| local | Whether it is local | bool | Not supported. |
| desc | Work description | str | Not supported. |

get_warmup_steps

Obtains the number of warmup iteration steps.

Prototype

```python
def get_warmup_steps(num_training_steps: int)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| num_training_steps | Number of training iteration steps | int | int |
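The usual derivation (a sketch assuming the logic mirrors the transformers implementation, not the actual openmind source): an explicitly set warmup_steps takes precedence, otherwise the count is derived from warmup_ratio.

```python
import math

def get_warmup_steps_sketch(num_training_steps: int,
                            warmup_steps: int = 0,
                            warmup_ratio: float = 0.0) -> int:
    # An explicit warmup_steps wins; otherwise derive the count from
    # warmup_ratio (assumption: mirrors the transformers logic).
    if warmup_steps > 0:
        return warmup_steps
    return math.ceil(num_training_steps * warmup_ratio)

print(get_warmup_steps_sketch(1000, warmup_ratio=0.1))  # 100
```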

to_dict

Serializes instances into a dictionary.

Prototype

```python
def to_dict()
```

to_json_string

Serializes instances into a JSON string.

Prototype

```python
def to_json_string()
```
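For a dataclass-style arguments container, both serialization helpers can be sketched with the standard library (illustrative only; hypothetical field subset, not the openmind implementation):

```python
import dataclasses
import json

@dataclasses.dataclass
class SerializableArgsSketch:
    output_dir: str = "./output"
    seed: int = 42

    def to_dict(self) -> dict:
        # Serialize all fields into a plain dictionary.
        return dataclasses.asdict(self)

    def to_json_string(self) -> str:
        # JSON rendering of the same dictionary.
        return json.dumps(self.to_dict(), indent=2)

print(SerializableArgsSketch().to_json_string())
```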

to_sanitized_dict

Serializes instances into a parameter dictionary that can be used for TensorBoard. Currently, this API is supported only by PyTorch.

Prototype

```python
def to_sanitized_dict()
```

set_training

Sets training parameters.

Prototype

```python
def set_training(
    learning_rate: float = 5e-5,
    batch_size: int = 8,
    weight_decay: float = 0,
    num_epochs: float = 3,
    max_steps: int = -1,
    gradient_accumulation_steps: int = 1,
    seed: int = 42,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| learning_rate | Initial learning rate | float | float |
| batch_size | Training batch size of each device | int | int |
| weight_decay | Weight decay | float | float |
| num_epochs | Total number of training epochs | float | float |
| max_steps | Maximum number of training steps | int | Not supported. |
| gradient_accumulation_steps | Number of update steps of gradient accumulation | int | int |
| seed | Random seed set at the beginning of training | int | int |
| kwargs["gradient_checkpointing"] | If True, gradient checkpointing is used to save memory at the cost of a slower backward pass | bool | Not supported. |

set_evaluate

Sets evaluation parameters.

Prototype

```python
def set_evaluate(
    strategy: Union[str, IntervalStrategy] = "no",
    steps: int = 500,
    batch_size: int = 8,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| strategy | Evaluation strategy used during training: "no" (no evaluation is performed during training), "steps" (an evaluation is performed and logged every `steps` update steps), "epoch" (an evaluation is performed at the end of each epoch) | Union[str, IntervalStrategy] | Union[str, IntervalStrategy] |
| steps | Number of update steps between two evaluations if strategy="steps" | int | int |
| batch_size | Batch size of each device used for evaluation | int | int |
| kwargs["accumulation_steps"] | Number of prediction steps to accumulate the output tensors before moving them to the CPU | int | Not supported. |
| kwargs["delay"] | Number of epochs or steps to wait before the first evaluation, depending on the evaluation strategy | float | Not supported. |
| kwargs["loss_only"] | Ignores all outputs except losses | bool | Not supported. |
| kwargs["jit_mode"] | Whether to use PyTorch JIT in inference | bool | Not supported. |

set_testing

Sets test parameters.

Prototype

```python
def set_testing(
    batch_size: int = 8,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| batch_size | Batch size of each device used for testing | int | int |
| kwargs["loss_only"] | Ignores all outputs except losses | bool | bool |
| kwargs["jit_mode"] | Whether to use PyTorch JIT in inference | bool | Not supported. |

set_save

Sets and saves all relevant parameters.

Prototype

```python
def set_save(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    total_limit: Optional[int] = None,
    on_each_node: bool = False,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| strategy | Weight saving strategy: "no" (checkpoints are not saved during training), "epoch" (checkpoints are saved at the end of each epoch), "steps" (checkpoints are saved every `steps` update steps) | Union[str, IntervalStrategy] | Union[str, IntervalStrategy] |
| steps | Number of update steps between two checkpoint saves if strategy="steps" | int | int |
| total_limit | Limits the total number of checkpoints; older checkpoints in output_dir are deleted | int | int |
| on_each_node | Whether to save models and checkpoints on each node or only on the primary node during multi-node distributed training | bool | bool |

set_logging

Sets all parameters related to log recording.

Prototype

```python
def set_logging(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    report_to: Union[str, List[str]] = "none",
    level: str = "passive",
    first_step: bool = False,
    nan_inf_filter: bool = False,
    on_each_node: bool = False,
    replica_level: str = "passive",
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| strategy | Training log saving strategy: "no" (no log is recorded during training), "epoch" (logs are recorded at the end of each epoch), "steps" (logs are recorded every `steps` update steps) | Union[str, IntervalStrategy] | Union[str, IntervalStrategy] |
| steps | Number of update steps between two log records if strategy="steps" | int | int |
| report_to | List of integrations to report the results and logs to | Union[str, List[str]] | Not supported. |
| level | Logger log level to be used on the main process. The options include "debug", "info", "warning", "error", "critical", and "passive". | str | Not supported. |
| first_step | Whether to record and evaluate the first global_step | bool | Not supported. |
| nan_inf_filter | Whether to filter the nan and inf losses in logs | bool | Not supported. |
| on_each_node | Whether log_level is used for logging on each node or only on the primary node during distributed training | bool | Not supported. |
| replica_level | Logger log level used on replicas | str | Not supported. |

set_push_to_hub

Sets all parameters for synchronizing checkpoints with the Hub.

Prototype

```python
def set_push_to_hub(
    model_id: str,
    strategy: Union[str, HubStrategy] = "every_save",
    token: Optional[str] = None,
    private_repo: bool = False,
    always_push: bool = False,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| model_id | Name of the repository synchronized with the local output_dir. It can be a simple model ID, in which case the model is pushed to your namespace, or a full repository name such as "user_name/model". | str | str |
| strategy | Strategy for pushing to the Hub: "end" (push the model with its configuration, tokenizer, and model card when the Trainer.save_model method is invoked), "every_save" (push the model with its configuration, tokenizer, and model card each time the model is saved; the push is asynchronous and does not block training, and if saving is very frequent a new push is attempted only after the previous one completes, with one last push of the final model at the end of training), "checkpoint" (like "every_save", but the latest checkpoint is also pushed to a subfolder named last-checkpoint, making it easy to resume training with trainer.train(resume_from_checkpoint="last-checkpoint")), "all_checkpoints" (like "checkpoint", but all checkpoints are pushed, so the final repository contains one checkpoint folder per save) | Union[str, HubStrategy] | Union[str, HubStrategy] |
| token | Token used to push a model to the Hub | str | str |
| private_repo | If True, the Hub repository is set to private | bool | bool |
| always_push | If False, the Trainer skips pushing a checkpoint when the previous push has not completed | bool | bool |

set_optimizer

Sets all parameters related to the optimizer and its hyperparameters.

Prototype

```python
def set_optimizer(
    name: Union[str, OptimizerNames],
    learning_rate: float = 5e-5,
    weight_decay: float = 0,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| name | Optimizer type | Union[str, OptimizerNames] | Union[str, OptimizerType] |
| learning_rate | Initial learning rate | float | float |
| lr_end | End learning rate | Not supported. | float |
| weight_decay | Weight decay | float | float |
| beta1 | beta1 hyperparameter of the Adam optimizer or its variants | float | float |
| beta2 | beta2 hyperparameter of the Adam optimizer or its variants | float | float |
| epsilon | epsilon hyperparameter of the Adam optimizer or its variants | float | float |
| kwargs["args"] | Optional arguments passed to AnyPrecisionAdamW (valid only when optim="adamw_anyprecision"); defaults to None | str | Not supported. |

set_lr_scheduler

Sets all parameters related to the learning rate scheduler and its hyperparameters.

Prototype

```python
def set_lr_scheduler(
    name: Union[str, SchedulerType] = "linear",
    num_epochs: float = 3.0,
    max_steps: int = -1,
    warmup_ratio: float = 0,
    warmup_steps: int = 0,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| name | Type of the learning rate scheduler | Union[str, SchedulerType] | Union[str, LrSchedulerType] |
| num_epochs | Total number of training epochs | float | float |
| max_steps | Maximum number of training steps | int | Not supported. |
| warmup_ratio | Ratio of total training steps used for a linear warmup from 0 to learning_rate | float | float |
| warmup_steps | Number of steps used for a linear warmup from 0 to learning_rate | int | int |
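The shape of the default "linear" scheduler can be sketched as a multiplier applied to learning_rate (a conceptual illustration, not the openmind implementation):

```python
def linear_schedule_sketch(step: int, total_steps: int, warmup_steps: int) -> float:
    # Multiplier applied to learning_rate: ramps 0 -> 1 over warmup_steps,
    # then decays linearly back to 0 over the remaining steps.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

print(linear_schedule_sketch(50, 1000, 100))   # 0.5
print(linear_schedule_sketch(1000, 1000, 100)) # 0.0
```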

set_dataloader

Sets a data loader.

Prototype

```python
def set_dataloader(
    train_batch_size: int = 8,
    eval_batch_size: int = 8,
    drop_last: bool = False,
    num_workers: int = 0,
    ignore_data_skip: bool = False,
    sampler_seed: Optional[int] = None,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| train_batch_size | Training batch size | int | int |
| eval_batch_size | Evaluation batch size | int | int |
| drop_last | Whether to drop the last incomplete batch | bool | bool |
| num_workers | Number of subprocesses used for data loading | int | int |
| ignore_data_skip | Whether to skip the batches and epochs during training resumption to get the data loading at the same stage as in the previous training | bool | bool |
| sampler_seed | Random seed for the data sampler | int | int |
| kwargs["pin_memory"] | Whether to pin memory in the data loader; defaults to True | bool | Not supported. |
| kwargs["persistent_workers"] | If True, the data loader does not close the worker processes after a dataset has been consumed once, keeping the worker dataset instances alive. This may accelerate training but increases RAM usage. Defaults to False. | bool | Not supported. |
| kwargs["auto_find_batch_size"] | Automatically finds a batch size that fits into memory; requires accelerate to be installed. Defaults to False. | bool | Not supported. |
| kwargs["prefetch_factor"] | Number of batches that each worker loads in advance | int | Not supported. |

openmind.Trainer Class

The Trainer class is used to implement functions such as model training, evaluation, and inference. It is the core component of training and provides many methods and functions to manage the entire training process, including data loading, model forward propagation, loss calculation, and gradient update.
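The responsibilities listed above can be illustrated with a framework-free toy loop (purely conceptual; not the openmind implementation):

```python
def train_sketch(data, epochs=3, lr=0.1):
    w = 0.0  # one-parameter "model": y_hat = w * x
    for _ in range(epochs):
        for x, y in data:               # data loading
            y_hat = w * x               # forward propagation
            grad = 2 * (y_hat - y) * x  # gradient of the squared-error loss
            w -= lr * grad              # gradient update
    return w

# Fit y = 2x from a toy dataset.
w = train_sketch([(1.0, 2.0), (2.0, 4.0)], epochs=20, lr=0.05)
```

A real Trainer wraps the same cycle with batching, distributed execution, checkpointing, and callbacks.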

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| args | Arguments used to configure datasets, hyperparameters, and optimizers | TrainingArguments | TrainingArguments |
| task | Enumerated task type | Not supported. | str |
| model | Model instance used for training, evaluation, or prediction | Union[PreTrainedModel, torch.nn.Module] | Union[mindformers.models.PreTrainedModel, str] |
| model_name | Model name | Not supported. | str |
| pet_method | PET method name | Not supported. | str |
| tokenizer | Tokenizer | PreTrainedTokenizerBase | mindformers.models.PreTrainedTokenizerBase |
| train_dataset | Training dataset | Dataset | Union[str, mindspore.dataset.BaseDataset] |
| eval_dataset | Evaluation dataset | Union[Dataset, Dict[str, Dataset]] | Union[str, mindspore.dataset.BaseDataset] |
| data_collator | Function for batch data processing | DataCollator | Not supported. |
| image_processor | Processor for image pre-processing | Not supported. | mindformers.models.BaseImageProcessor |
| audio_processor | Processor for audio pre-processing | Not supported. | mindformers.models.BaseAudioProcessor |
| optimizers | Optimizer | Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] | mindspore.nn.Optimizer |
| compute_metrics | Function for computing metrics during evaluation | Callable[[EvalPrediction], Dict] | Union[dict, set] |
| callbacks | Callback function list | List[TrainerCallback] | Union[List[mindspore.train.Callback], mindspore.train.Callback] |
| eval_callbacks | Evaluation callback function list | Not supported. | Union[List[mindspore.train.Callback], mindspore.train.Callback] |
| model_init | Function that instantiates the model to be used | Callable[[], PreTrainedModel] | Not supported. |
| preprocess_logits_for_metrics | Function that preprocesses the output results before computing metrics | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | Not supported. |
| save_config | Whether to save the configuration of the current task | Not supported. | bool |

train

Performs training steps.

Prototype

```python
def train(*args, **kwargs)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| train_checkpoint | Restores the training weights of the network | Not supported. | Union[str, bool] |
| resume_from_checkpoint | Preloaded weights | Union[str, bool] | Union[str, bool] |
| trial | Trial run or hyperparameter dictionary for hyperparameter search | Union["optuna.Trial", Dict[str, Any]] | Not supported. |
| ignore_keys_for_eval | List of keys in the model output to ignore when collecting evaluation predictions during training | List[str] | Not supported. |
| resume_training | Resumable training switch | Not supported. | bool |
| auto_trans_ckpt | Automatic weight transformation switch | Not supported. | bool |
| src_strategy | Distributed strategy file for preloaded weights | Not supported. | str |
| do_eval | Whether to perform evaluation during training | Not supported. | bool |

evaluate

Performs evaluation.

Prototype

```python
def evaluate(*args, **kwargs)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| eval_dataset | Evaluation dataset | Union[Dataset, Dict[str, Dataset]] | Union[str, mindspore.dataset.BaseDataset, mindspore.dataset.Dataset, Iterable] |
| eval_checkpoint | Weights of the evaluation network | Not supported. | Union[str, bool] |
| ignore_keys | List of keys in the model output to ignore when collecting predictions | List[str] | Not supported. |
| metric_key_prefix | Metric name prefix | str | Not supported. |
| auto_trans_ckpt | Automatic weight transformation switch | Not supported. | bool |
| src_strategy | Distributed strategy file for preloaded weights | Not supported. | str |

predict

Executes inference.

Prototype

```python
def predict(*args, **kwargs)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| predict_checkpoint | Weights of the inference network | Not supported. | Union[str, bool] |
| test_dataset | Inference dataset | Dataset | Not supported. |
| ignore_keys | List of keys in the model output to ignore when collecting predictions | List[str] | Not supported. |
| metric_key_prefix | Metric name prefix | str | Not supported. |
| input_data | Input data for inference | Not supported. | Union[GeneratorDataset, Tensor, np.ndarray, Image, str, list] |
| batch_size | Batch size | Not supported. | int |
| auto_trans_ckpt | Automatic weight transformation switch | Not supported. | bool |
| src_strategy | Distributed strategy file for preloaded weights | Not supported. | str |

add_callback

Adds a callback function to the current callback list.

Prototype

```python
def add_callback(callback)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| callback | Callback to add | Union[type, TrainerCallback] | Union[type, mindspore.train.Callback] |

pop_callback

Deletes a callback from the current callback list and returns it. If the callback cannot be found, None is returned (no error is thrown).

Prototype

```python
def pop_callback(callback)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| callback | Callback to remove and return | Union[type, TrainerCallback] | Union[type, mindspore.train.Callback] |
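Since the parameter accepts either a class or an instance, the pop semantics can be sketched as follows (an assumption modeled on the transformers callback handler; LoggingCallbackSketch is a hypothetical callback type):

```python
class LoggingCallbackSketch:  # hypothetical callback type
    pass

def pop_callback_sketch(callbacks: list, callback):
    # Accept either a class or an instance; remove the first match and
    # return it, or return None when nothing matches (no error raised).
    for cb in callbacks:
        matches = isinstance(cb, callback) if isinstance(callback, type) else cb is callback
        if matches:
            callbacks.remove(cb)
            return cb
    return None

cbs = [LoggingCallbackSketch()]
first = pop_callback_sketch(cbs, LoggingCallbackSketch)   # removed and returned
second = pop_callback_sketch(cbs, LoggingCallbackSketch)  # None: list is now empty
```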

remove_callback

Deletes a callback function from the current callback list.

Prototype

```python
def remove_callback(callback)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| callback | Callback to remove | Union[type, TrainerCallback] | Union[type, mindspore.train.Callback] |

save_model

Saves a model so that you can reload it using the from_pretrained() method.

Prototype

```python
def save_model(*args, **kwargs)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| output_dir | Model saving path | str | str |
| _internal_call | Whether the call to save_model is made internally. When args.push_to_hub is True and _internal_call is False (the default), the model is pushed to the repository hub. | bool | bool |

init_hf_repo

Creates and initializes a git repository in args.hub_model_id.

Prototype

```python
def init_hf_repo()
```

push_to_hub

Uploads the model and tokenizer to the args.hub_model_id repository on the Hub.

Prototype

```python
def push_to_hub(
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
)
```

Parameters

| Name | Description | PyTorch Supported Type | MindSpore Supported Type |
| --- | --- | --- | --- |
| commit_message | Message committed during the push; defaults to "End of training" | str | str |
| blocking | Whether the function returns only after the git push completes; defaults to True | bool | bool |