跳转至

task.torch

common

DataLoaderProvider

geyser_lava.task.torch.common.DataLoaderProvider

构建数据加载器

参数类型 定义
provides ('loader',)
requires ('dataset_params', 'loader_params')
revert_requires ()

dataset_params定义构建 PyTorch 风格数据集对象的参数,其中reference为数据集类的路径,数据集类必须为DatasetIterableDataset的子类,其他参数会直接输入数据集类的构造函数中。

loader_params定义输入 PyTorch 风格数据加载器,即DataLoader的参数。

loader为任务构建的数据加载器,即DataLoader对象。

task

ModelProvider

geyser_lava.task.torch.common.ModelProvider

构建模型

参数类型 定义
provides ('model',)
requires ('model_params',)
revert_requires ()

model_params定义构建 PyTorch 模型对象的参数,其中reference为模型类或构建模型函数的路径,模型类必须为Module的子类,其他参数会直接输入模型类的构造函数中。

model为任务构建的模型,即Module对象。

task

LossProvider

geyser_lava.task.torch.common.LossProvider

构建损失函数

参数类型 定义
provides ('loss',)
requires ('loss_params',)
revert_requires ()

loss_params定义构建 PyTorch 损失函数对象的参数,其中reference为损失函数类或损失函数的构建函数的路径,损失函数类必须为Module的子类,其他参数会直接输入损失函数类的构造函数中。

loss为任务构建的损失函数,即Module对象。

task

OptimizerProvider

geyser_lava.task.torch.common.OptimizerProvider

构建优化器

参数类型 定义
provides ('optimizer',)
requires ('optimizer_params', 'model')
revert_requires ()

optimizer_params定义构建 PyTorch 优化器对象的参数,其中reference为优化器类的路径,优化器类必须为Optimizer的子类,其他参数会直接输入优化器类的构造函数中。

model定义该优化器优化的模型对象。

optimizer为任务构建的损失函数,即Optimizer对象。

task

trainer

SupervisedTrainer

geyser_lava.task.torch.trainer.SupervisedTrainer

有监督训练模型

参数类型 定义
provides ('model',)
requires ('model', 'train_loader', 'validate_loader', 'optimizer', 'device', 'loss', 'metrics_params', 'max_epochs', 'non_blocking')
revert_requires ()

model定义需要训练的模型对象,即Module对象。

train_loadervalidate_loader定义训练过程中训练数据集与验证数据集的数据加载器,即DataLoader对象。

optimizer定义训练使用的优化器对象,即Optimizer对象。

device定义模型训练过程中使用的设备,详见Device

loss定义训练过程中优化的损失函数对象,即Module对象。

metrics_params定义训练过程中显示的指标对象参数列表,每个元素参数中的reference为指标类的路径,关于指标类详见ignite.metrics

max_epochs定义训练过程的最大轮数。

non_blocking定义训练过程中数据加载是否为非阻塞方式。

model为训练完成时的模型对象,即Module对象。

task