中文
本页内容

Trainer 模块接口

openmind.TrainingArguments类

TrainingArguments 类用于配置训练任务的参数,包括训练过程中所需的超参数、模型保存路径、日志记录选项、学习率等。

参数列表

  • PyTorch和MindSpore的TrainingArguments类共同支持的参数
参数名PyTorch类型MindSpore类型描述PyTorch默认值MindSpore默认值
output_dirstrstr输出目录。"./output"
overwrite_output_dirboolbool是否覆盖输出目录。FalseFalse
seedintint随机种子。4242
use_cpuboolbool是否使用CPU。FalseFalse
do_trainboolbool是否进行训练。FalseFalse
do_evalboolbool是否进行评估。FalseFalse
do_predictboolbool是否进行推理。FalseFalse
num_train_epochsfloatfloat训练的总轮数。3.03.0
resume_from_checkpointstrstr预加载权重。NoneNone
evaluation_strategyUnion[IntervalStrategy, str]Union[IntervalStrategy, str]评估策略。"no""no"
per_device_train_batch_sizeintint每个设备的训练批大小。88
per_device_eval_batch_sizeintint每个设备的评估批大小。88
per_gpu_train_batch_sizeintint每个GPU的训练批大小。(不推荐使用)NoneNone
per_gpu_eval_batch_sizeintint每个GPU的评估批大小。(不推荐使用)NoneNone
gradient_accumulation_stepsintint梯度累积步数。11
ignore_data_skipboolbool断点续训是否忽略数据跳过。FalseFalse
dataloader_drop_lastboolbool数据加载器是否丢弃最后一批。FalseTrue
dataloader_num_workersintint数据加载器进程数。08
optimUnion[OptimizerNames, str]Union[OptimizerType, str]优化器。"adamw_torch""fp32_adamw"
adam_beta1floatfloatAdam优化器的beta1。0.90.9
adam_beta2floatfloatAdam优化器的beta2。0.9990.999
adam_epsilonfloatfloatAdam优化器的epsilon。1e-81e-8
weight_decayfloatfloat权重衰减。0.00.0
lr_scheduler_typeUnion[SchedulerType, str]Union[LrSchedulerType, str]学习率调度器类型。"linear""cosine"
learning_ratefloatfloat学习率。5e-55e-5
warmup_ratiofloatfloat预热比率。0.0None
warmup_stepsintint预热步数。00
max_grad_normfloatfloat梯度裁剪的最大范数。1.01.0
logging_strategyUnion[IntervalStrategy, str]Union[LoggingIntervalStrategy, str]日志保存策略。"steps""steps"
logging_stepsfloatfloat日志保存步数。5001
save_stepsfloatfloat权重保存步数。500500
save_strategystrUnion[SaveIntervalStrategy, str]权重保存策略。"steps""steps"
save_total_limitintint权重最大保存数量限制。None5
save_on_each_nodeboolbool是否分片保存权重。FalseTrue
hub_model_idstrstrHub模型ID。NoneNone
hub_strategyUnion[HubStrategy, str]Union[HubStrategy, str]Hub推送策略。"every_save""every_save"
hub_tokenstrstrHub令牌。NoneNone
hub_private_repoboolboolHub私有仓库。FalseFalse
hub_always_pushboolbool是否始终推送到Hub。FalseFalse
data_seedintint数据采样器随机种子数。NoneNone
eval_stepsfloatfloat评估阶段的步骤数。NoneNone
push_to_hubboolbool是否推送到Hub。FalseFalse
  • PyTorch独立支持的参数
参数名类型描述默认值
optim_argsstr优化器参数。None
label_namesList[str]标签名称。None
load_best_model_at_endbool是否在最后加载最佳模型。False
metric_for_best_modelstr用于最佳模型的指标。None
greater_is_betterbool指标是否越大越好。None
label_smoothing_factorfloat标签平滑因子。0.0
include_inputs_for_metricsbool指标中是否包含输入。False
prediction_loss_onlybool是否在执行评估和生成预测时,仅返回损失。False
eval_accumulation_stepsint需要累积输出张量的预测步骤数。None
eval_delayfloat第一次评估需要等待的步骤数。None
max_stepsint最大训练步数。-1
lr_scheduler_kwargsdict调度器的额外参数。{}
log_levelstr日志等级。"passive"
log_level_replicastr在副本上使用的日志等级。"warning"
log_on_each_nodebool分布式训练是否只在主节点记录日志。True
logging_dirstr日志保存目录。None
logging_first_stepbool是否记录第一个“global_step”。False
logging_nan_inf_filterbool是否过滤'nan'和'inf'损失以进行日志记录。True
save_safetensorsbool是否以safetensor格式保存权重。True
save_only_modelbool在checkpointing时是否只保存模型状态。False
jit_mode_evalbool是否使用PyTorch jit跟踪进行推理。False
use_ipexbool是否使用Intel扩展。(不支持)False
bf16bool是否使用bf16格式。False
fp16bool是否使用fp16格式。False
tf32bool是否使用tf32格式。(不支持)None
fp16_opt_levelstr权重保存策略。"O1"
fp16_backendstr指定fp16所使用的后端。"auto"
half_precision_backendstr定义混精训练所使用的设备。"auto"
bf16_full_evalbool是否在评估阶段使用bf16。False
fp16_full_evalbool是否在评估阶段使用fp16。False
disable_tqdmbool是否禁用进度条工具。None
remove_unused_columnsbool是否自动删除模型forward方法未使用的列。True
fsdpUnion[List[FSDPOption, str]]是否使用fsdp。None
fsdp_configUnion[dict, str]fsdp配置。None
local_rankint分布式训练的进程号。-1
tpu_num_coresint使用TPU训练时使用的内核数。(不支持)None
past_indexint使用hidden states用作预测时的index。-1
ddp_backendstrddp分布式训练所使用的后端。None
run_namestr运行描述符。None
deepspeedstrdeepspeed配置。None
accelerator_configstraccelerate配置。None
debugUnion[str, List[DebugOption]]启用一个或多个调试功能。None
length_column_namestr预先计算长度的列名。"length"
group_by_lengthbool是否将训练数据集中长度大致相同的样本组合在一起。False
ddp_find_unused_parametersbool'find_unused_parameters'是否传递给'DistributedDataParallel'。None
report_toList[str]要报告结果和日志的集成列表。"all"
ddp_bucket_cap_mbint'bucket_cap_mb'传递给'DistributedDataParallel'的值。None
ddp_broadcast_buffersbool'ddp_broadcast_buffers'的值是否传递给'DistributedDataParallel'。None
ddp_timeoutintddp调用的超时。1800
dataloader_pin_memorybool是否要在数据加载器中固定内存。True
dataloader_persistent_workersbool是否保持工作线程数据集实例处于活动状态。False
dataloader_prefetch_factorint每个线程提前装载的Batch数。None
skip_memory_metricsbool是否跳过将内存探查器报告添加到指标。True
gradient_checkpointingbool是否使用梯度检查点来节省内存。False
gradient_checkpointing_kwargsdict梯度检查点相关参数。None
auto_find_batch_sizebool是否通过指数衰减自动找到适合内存的批处理大小。False
full_determinismbool调用'enable_full_determinism'而不是'set_seed',以确保分布式训练中的可重复结果。False
torchdynamostr指定TorchDynamo的后端编译器。(不支持)None
ray_scopestr使用Ray进行超参数搜索时要使用的范围。"last"
use_mps_devicebool是否使用mps设备。(不支持)False
torch_compilebool是否使用PyTorch 2.0编译模型。(不支持)False
torch_compile_backendstrtorch.compile所使用的后端。(不支持)None
torch_compile_modestrtorch.compile模式。(不支持)None
split_batchesbool是否将数据加载器生成的批次拆分到设备之间。None
include_tokens_per_secondbool是否计算每个设备每秒的tokens。None
include_num_input_tokens_seenbool是否跟踪在整个训练过程中看到的输入tokens。None
neftune_noise_alphafloat是否激活NEFTune噪声嵌入。None
optim_target_modulesUnion[str, List[str]]要优化的目标模块。None
  • MindSpore独立支持的参数
参数名类型描述默认值
only_save_strategybool任务是否保存策略文件后直接退出。False
auto_trans_ckptbool是否开启权重自动转换。False
src_strategystr预加载权重的分布式策略文件。None
batch_sizeint每个设备的训练批大小。会覆盖per_device_train_batch_sizeNone
sink_modebool是否通过通道将数据直接下沉到设备True
sink_sizeint每步训练或评估的数据下沉数量2
modeint指示运行在 GRAPH_MODE(0) 或 PYNATIVE_MODE(1)0
resume_trainingbool是否开启断点续训。False
remote_save_urlstrOBS保存路径。None
device_idint设备号。0
device_targetstr执行的目标设备,支持 'Ascend'、'GPU' 和 'CPU'。"Ascend"
enable_graph_kernelbool是否启用图融合。False
graph_kernel_flagsstr图融合级别。"--opt_level=0"
save_graphsbool是否保存计算图。False
save_graphs_pathstr保存计算图路径。"./graph"
max_call_depthint函数调用的最大深度。10000
max_device_memorystr设备的最大可用内存。"1024GB"
use_parallelbool是否开启并行模式。False
parallel_modeint指示是否运行于数据并行(0)、半自动并行(1)、自动并行(2)或混合并行(3)模式。1
gradients_meanbool是否在梯度AllReduce后执行平均算子。False
loss_repeated_meanbool在重复计算时,是否向后执行均值操作符。False
enable_alltoallbool是否允许在通信过程中生成 AllToAll 通信操作符。False
enable_parallel_optimizerbool是否开启优化器并行。False
full_batchbool如果在自动并行模式下加载整个批处理数据集,则应将 full_batch 设置为 True。当前不建议使用此接口,请将其替换为 dataset_strategy。True
dataset_strategyUnion[str, tuple]数据集分片策略。"full_batch"
search_modestr策略搜索模式,仅在自动并行模式下有效,实验性接口,请谨慎使用。"sharding_propagation"
data_parallelint数据并行。1
gradient_accumulation_shardbool累积梯度变量是否沿着数据并行维度进行分割。False
parallel_optimizer_thresholdint设置参数分割的阈值。64
optimizer_weight_shard_sizeint设置指定优化器权重分割的通信域大小。-1
strategy_ckpt_save_filestr保存分布式策略文件的路径。"./ckpt_strategy.ckpt"
model_parallelint模型并行。1
expert_parallelint专家并行。1
pipeline_stageint流水线并行。1
gradient_aggregation_groupint梯度通信操作融合组的大小。4
micro_batch_numint流水线计算最小批次数量。1
micro_batch_interleave_numint多副本并行数量。1
use_seq_parallelbool是否启用序列并行。False
vocab_emb_dpbool是否仅沿着数据并行维度分割词汇表。True
expert_numint专家的数量。1
capacity_factorfloat专家因子。1.05
aux_loss_factorfloat损失贡献因子。0.05
num_experts_chosenint每个标记选择的专家数量。1
recomputebool重计算。False
select_recomputebool选择重计算。False
parallel_optimizer_comm_recomputebool是否重新计算由优化器并行引入的 AllGather 通信。False
mp_comm_recomputebool是否重新计算模型并行引入的通信操作。True
recompute_slice_activationbool是否对保留在内存中的 Cell 输出进行切片。False
layer_scalebool是否启用层衰减。False
layer_decayfloat层衰减系数。0.65
lr_endfloat最终学习率。1e-6
warmup_lr_initfloat预热阶段的初始学习率。0.0
warmup_epochsint在总步数的 warmup_epochs 部分进行线性预热。None
lr_scalebool是否启用学习率缩放。False
lr_scale_factorint学习率缩放因子。256
python_multiprocessingbool是否启动 Python 多进程模式。False
numa_enablebool将 NUMA 的默认状态设置为启用状态。False
prefetch_sizeint设置管道中线程的队列容量。1
wrapper_typestr包装器的类名。"MFTrainOneStepCell"
scale_senseUnion[str, float]scale sense的值或类名。"DynamicLossScaleUpdateCell"
loss_scale_valueint初始损失缩放因子。65536
loss_scale_factorint损失缩放系数的增量和减量因子。2
loss_scale_windowint增加损失缩放系数的最大连续训练步数。1000
use_clip_gradbool是否启用梯度裁剪。True
train_datasetstr训练集路径。None
eval_datasetstr评估集路径。None
dataset_taskstr数据集对应的任务类型。None
dataset_typestr数据集类型。None
train_dataset_in_columnslist[str]训练集输入标签名称。None
train_dataset_out_columnslist[str]训练集输出标签名称。None
eval_dataset_in_columnslist[str]评估集输入标签名称。None
eval_dataset_out_columnslist[str]评估集输出标签名称。None
shufflebool训练集是否乱序。True
repeatint训练集重复次数。1
metric_typeUnion[List[str], str]矩阵类型。None
save_secondsint每隔 X 秒保存一次检查点。None
integrated_savebool在自动并行场景中是否合并并保存分割的张量。None
eval_epochsint每次评估之间的纪元间隔数,1 表示每个纪元结束时进行评估。None
profilebool是否开启性能分析收集。False
profile_start_stepint性能分析起始step。1
profile_end_stepint性能分析结束step。10
init_start_profilebool是否在 Profiler 初始化时启用数据收集。False
profile_communicationbool是否在多设备训练中收集通信性能数据。False
profile_memorybool是否收集张量内存数据。True
auto_tunebool是否启用自动数据加速。False
filepath_prefixstr优化后的全局配置的保存路径和文件前缀。"./autotune"
autotune_per_stepint设置调整自动数据加速配置的步长间隔。10

train_batch_size

获取训练批大小

接口原型

python
def train_batch_size()

eval_batch_size

获取评估批大小

接口原型

python
def eval_batch_size()

world_size

获取并行的进程数量

接口原型

python
def world_size()

process_index

获取当前进程的索引

接口原型

python
def process_index()

local_process_index

获取当前本地进程的索引

接口原型

python
def local_process_index()

should_log

获取当前进程是否应生成日志,当前仅支持PyTorch

接口原型

python
def should_log()

should_save

获取当前进程是否应写入磁盘,当前仅支持PyTorch

接口原型

python
def should_save()

_setup_devices

设置设备,当前仅支持PyTorch

接口原型

python
def _setup_devices()

device

获取当前进程使用的设备,当前仅支持PyTorch

接口原型

python
def device()

get_process_log_level

获取进程日志级别,当前仅支持PyTorch

接口原型

python
def get_process_log_level()

main_process_first

主进程优先,当前仅支持PyTorch

接口原型

python
def main_process_first(local: bool = True, desc: str = "work")

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
local是否本地。bool不支持
desc工作描述。str不支持

get_warmup_steps

获取预热迭代步数

接口原型

python
def get_warmup_steps(num_training_steps: int)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
num_training_steps训练迭代步数。intint

to_dict

将实例序列化为字典

接口原型

python
def to_dict()

to_json_string

将实例序列化为JSON字符串

接口原型

python
def to_json_string()

to_sanitized_dict

将实例序列化为可用于TensorBoard的参数字典,当前仅支持PyTorch

接口原型

python
def to_sanitized_dict()

set_training

设置训练参数

接口原型

python
def set_training(
    learning_rate: float = 5e-5,
    batch_size: int = 8,
    weight_decay: float = 0,
    num_epochs: float = 3,
    max_steps: int = -1,
    gradient_accumulation_steps: int = 1,
    seed: int = 42,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
learning_rate初始学习率。floatfloat
batch_size每个设备训练的批量大小。intint
weight_decay权重衰减。floatfloat
num_epochs执行的总训练周期数。floatfloat
max_steps最大训练步数。int不支持
gradient_accumulation_steps累积梯度的更新步数。intint
seed在训练开始时设置的随机种子。intint
kwargs["gradient_checkpointing"]如果为True,则使用梯度检查点来节省内存,但反向传播速度会减慢。bool不支持

set_evaluate

设置评估参数

接口原型

python
def set_evaluate(
    strategy: Union[str, IntervalStrategy] = "no",
    steps: int = 500,
    batch_size: int = 8,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
strategy训练过程中采用的评估策略,支持以下取值:
- "no": 在训练过程中不进行评估。
- "steps": 每steps步进行一次评估(并记录日志)。
- "epoch": 在每个周期结束时进行一次评估。
Union[str, IntervalStrategy]Union[str, IntervalStrategy]
steps如果strategy="steps",则在两次评估之间的更新步数。intint
batch_size用于评估的每个设备的批大小。intint
kwargs["accumulation_steps"]在将输出张量移动到CPU之前,累积预测步骤的输出张量的数量。int不支持
kwargs["delay"]在进行第一次评估之前等待的周期数或步数,具体取决于evaluation_strategy。float不支持
kwargs["loss_only"]仅忽略除损失之外的所有输出。bool不支持
kwargs["jit_mode"]是否在推理中使用PyTorch jit。bool不支持

set_testing

设置测试参数

接口原型

python
def set_testing(
    batch_size: int = 8,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
batch_size用于测试的每个设备的批大小。intint
kwargs["loss_only"]仅忽略除损失之外的所有输出。boolbool
kwargs["jit_mode"]是否在推理中使用PyTorch jit。bool不支持

set_save

设置与保存相关的所有参数

接口原型

python
def set_save(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    total_limit: Optional[int] = None,
    on_each_node: bool = False,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
strategy权重保存策略,支持以下取值:
- "no": 训练期间不保存检查点。
- "epoch": 每个周期结束时保存检查点。
- "steps": 每save_steps步保存检查点。
Union[str, IntervalStrategy]Union[str, IntervalStrategy]
steps如果strategy="steps",该参数表示两次检查点保存之间的更新步数。intint
total_limit限制检查点的总数量。删除output_dir中较旧的检查点。intint
on_each_node在进行多节点分布式训练时,是否在每个节点上保存模型和检查点,还是只在主节点上保存。boolbool

set_logging

设置与日志记录相关的所有参数

接口原型

python
def set_logging(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    report_to: Union[str, List[str]] = "none",
    level: str = "passive",
    first_step: bool = False,
    nan_inf_filter: bool = False,
    on_each_node: bool = False,
    replica_level: str = "passive",
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
strategy训练日志保存策略,支持以下取值:
- "no": 训练期间不进行日志记录。
- "epoch": 每个周期结束时进行日志记录。
- "steps": 每save_steps步进行日志记录。
Union[str, IntervalStrategy]Union[str, IntervalStrategy]
steps如果strategy="steps",则在两次日志记录之间的更新步数。intint
report_to用于将模型推送到Hub的令牌。str不支持
level主进程上要使用的记录器日志级别。包括:"debug""info""warning""error""critical",以及"passive"str不支持
first_step是否记录和评估第一个global_stepbool不支持
nan_inf_filter是否过滤日志中的naninf损失。bool不支持
on_each_node在分布式训练中,是否在每个节点上使用log_level记录,还是只在主节点上记录。bool不支持
replica_level在副本上使用的记录器日志级别。str不支持

set_push_to_hub

设置与Hub同步检查点相关的所有参数

接口原型

python
def set_push_to_hub(
    model_id: str,
    strategy: Union[str, HubStrategy] = "every_save",
    token: Optional[str] = None,
    private_repo: bool = False,
    always_push: bool = False,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
model_id与本地output_dir同步的存储库的名称。它可以是一个简单的模型ID,在这种情况下,模型将被推送到您的命名空间。也可以是整个存储库名称,例如"user_name/model"strstr
strategy定义推送到Hub的策略,支持以下取值:
- "end": 当调用Trainer.save_model方法时,推送模型、其配置、分词器和模型卡片。
- "every_save": 每次保存模型时,推送模型、配置、分词器和模型卡片。推送是异步的,不会阻塞训练,如果保存非常频繁,只有在前一个推送完成后才会尝试新的推送。在训练结束时,最终模型会进行最后一次推送。
- "checkpoint": 类似于 "every_save",但最新的检查点也被推送到名为last-checkpoint的子文件夹中,使您可以轻松地使用trainer.train(resume_from_checkpoint="last-checkpoint")恢复训练。
- "all_checkpoints": 类似于"checkpoint",但所有检查点都被推送(因此您将在最终存储库中获得每个文件夹的一个检查点文件夹。)
Union[str, HubStrategy]Union[str, HubStrategy]
token用于将模型推送到Hub的令牌。strstr
private_repo如果为True,则Hub存储库将设置为私有。boolbool
always_push如果为False,当前一个推送未完成时,Trainer将跳过推送检查点。boolbool

set_optimizer

设置与优化器及其超参数相关的所有参数

接口原型

python
def set_optimizer(
    name: Union[str, OptimizerNames],
    learning_rate: float = 5e-5,
    weight_decay: float = 0,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
name优化器类型。Union[str, OptimizerNames]Union[str, OptimizerType]
learning_rate初始学习率。floatfloat
lr_end最终学习率。不支持tfloat
weight_decay权重衰减。floatfloat
beta1Adam优化器或其变体的beta1超参数。floatfloat
beta2Adam优化器或其变体的beta2超参数。floatfloat
epsilonAdam优化器或其变体的epsilon超参数。floatfloat
kwargs["args"]传递给AnyPrecisionAdamW的可选参数(仅当optim="adamw_anyprecision"时有用),默认None。str不支持

set_lr_scheduler

设置与学习率调度器及其超参数相关的所有参数

接口原型

python
def set_lr_scheduler(
    name: Union[str, SchedulerType] = "linear",
    num_epochs: float = 3.0,
    max_steps: int = -1,
    warmup_ratio: float = 0,
    warmup_steps: int = 0,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
name学习率调度器类型。Union[str, SchedulerType]Union[str, LrSchedulerType]
num_epochs训练的总轮数。floatfloat
max_steps最大训练步数。int不支持
warmup_ratio用于从0到 learning_rate 进行线性预热的总训练步数的比率。floatfloat
warmup_steps用于从0到learning_rate进行线性预热的步数。intint

set_dataloader

设置数据加载器

接口原型

python
def set_dataloader(
    train_batch_size: int = 8,
    eval_batch_size: int = 8,
    drop_last: bool = False,
    num_workers: int = 0,
    ignore_data_skip: bool = False,
    sampler_seed: Optional[int] = None,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
train_batch_size训练的批大小。intint
eval_batch_size评估的批大小。intint
drop_last是否丢弃最后一个不完整的批次。boolbool
num_workers用于数据加载的子进程数量。intint
ignore_data_skip在恢复训练时,是否跳过批次和周期以使数据加载处于与上一次训练相同的阶段。boolbool
sampler_seed用于数据采样器的随机种子。intint
kwargs["pin_memory"]是否要在数据加载器中固定内存,默认True。bool不支持
kwargs["persistent_workers"]如果为True,则数据加载器在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程的数据集实例处于活动状态。可能会加速训练,但会增加RAM使用量,默认False。bool不支持
kwargs["auto_find_batch_size"]自动找到适合内存的批大小,需要安装accelerate,默认False。bool不支持
kwargs["prefetch_factor"]每个进程预先会加载的批次数量。int不支持

openmind.Trainer类

Trainer类用于实现模型的训练、评估和推理等功能。它是训练的核心组件,提供了许多方法和功能来管理整个训练过程,包括数据加载、模型前向传播、损失计算、梯度更新等。

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
args用于配置数据集、超参数、优化器等的任务配置。TrainingArgumentsTrainingArguments
task任务类型。不支持str
model用于训练、评估或进行预测的模型实例。Union[PreTrainedModel, torch.nn.Module]Union[mindformers.models.PreTrainedModel, str]
model_name模型名称。不支持str
pet_methodPET方法名称。不支持str
tokenizer分词器。PreTrainedTokenizerBasemindformers.models.PreTrainedTokenizerBase
train_dataset训练数据集。DatasetUnion[str, mindspore.dataset.BaseDataset]
eval_dataset评估数据集。Union[Dataset, Dict[str, Dataset]]Union[str, mindspore.dataset.BaseDataset]
data_collator数据批处理的函数。DataCollator不支持
image_processor图像预处理的处理器。不支持mindformers.models.BaseImageProcessor
audio_processor音频预处理的处理器。不支持mindformers.models.BaseAudioProcessor
optimizers优化器。Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]mindspore.nn.Optimizer
compute_metrics评估时计算指标的函数。Callable[[EvalPrediction], Dict]Union[dict, set]
callbacks回调函数列表。List[TrainerCallback]Union[List[mindspore.train.Callback], mindspore.train.Callback]
eval_callbacks评估回调函数列表。不支持Union[List[mindspore.train.Callback], mindspore.train.Callback]
model_init实例化要使用的模型的函数。Callable[[], PreTrainedModel]不支持
preprocess_logits_for_metrics计算指标前对输出结果预处理函数。Callable[[torch.Tensor, torch.Tensor], torch.Tensor]不支持
save_config保存当前任务的配置。不支持bool

train

执行训练步骤

接口原型

python
def train(*args, **kwargs)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
train_checkpoint恢复网络的训练权重。不支持Union[str, bool]
resume_from_checkpoint预加载权重。Union[str, bool]Union[str, bool]
trial运行的试验或用于超参数搜索的超参数字典。Union["optuna.Trial", Dict[str, Any]]不支持
ignore_keys_for_eval在训练期间用于收集评估预测时,应该忽略的模型输出中的键列表。List[str]不支持
resume_training断点续训开关。不支持bool
auto_trans_ckpt权重自动转换开关。不支持bool
src_strategy预加载权重的分布式策略文件。不支持str
do_eval是否在训练过程中进行评估。不支持bool

evaluate

运行评估

接口原型

python
def evaluate(*args, **kwargs)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
eval_dataset评估数据集。Union[Dataset, Dict[str, Dataset]]Union[str, mindspore.dataset.BaseDataset, mindspore.dataset.Dataset, Iterable]
eval_checkpoint评估网络的权重。不支持Union[str, bool]
ignore_keys在收集预测时,应该忽略模型输出中的键列表。List[str]不支持
metric_key_prefix指标名前缀。str不支持
auto_trans_ckpt权重自动转换开关。不支持bool
src_strategy预加载权重的分布式策略文件。不支持str

predict

运行推理

接口原型

python
def predict(*args, **kwargs)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
predict_checkpoint推理网络的权重。不支持Union[str, bool]
test_dataset推理数据集。Dataset不支持
ignore_keys在收集预测时,应该忽略模型输出中的键列表。List[str]不支持
metric_key_prefix指标名前缀。str不支持
input_data推理输入数据。不支持Union[GeneratorDataset,Tensor, np.ndarray, Image, str, list]
batch_size批处理大小。不支持int
auto_trans_ckpt权重自动转换开关。不支持bool
src_strategy预加载权重的分布式策略文件。不支持str

add_callback

向当前回调列表中添加一个回调函数

接口原型

python
def add_callback(callback)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
callback回调函数。Union[type, TrainerCallback]Union[type, mindspore.train.Callback]

pop_callback

从当前回调列表中删除一个回调并将其返回。如果找不到回调,则返回None(不会引发错误)

接口原型

python
def pop_callback(callback)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
callback回调函数。Union[type, TrainerCallback]Union[type, mindspore.train.Callback]

remove_callback

从当前回调列表中删除一个回调函数

接口原型

python
def remove_callback(callback)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
callback回调函数。Union[type, TrainerCallback]Union[type, mindspore.train.Callback]

save_model

保存模型,以便您可以使用from_pretrained()方法重新加载它

接口原型

python
def save_model(*args, **kwargs)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
output_dir模型保存路径。strstr
_internal_callargs.push_to_hub为True的情况下,当用户调用save_model方法时,是否将模型上传到存储库Hub中。默认值为False,表示进行推送。boolbool

init_hf_repo

创建并初始化args.hub_model_id中的git repo。

接口原型

python
def init_hf_repo()

push_to_hub

modeltokenizer上传到模型存储库Hub中的仓库args.hub_model_id

接口原型

python
def push_to_hub(
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
)

参数列表

参数名描述PyTorch支持类型MindSpore支持类型
commit_message推送时的提交消息,默认为"End of training"。strstr
blocking函数是否应该仅在git push完成时返回,默认为True。boolbool