task.torch
common
DataLoaderProvider
geyser_lava.task.torch.common.DataLoaderProvider
构建数据加载器
参数类型 | 定义 |
---|---|
provides | ('loader',) |
requires | ('dataset_params', 'loader_params') |
revert_requires | () |
dataset_params
定义构建 PyTorch 风格数据集对象的参数,其中reference
为数据集类的路径,数据集类必须为Dataset或IterableDataset的子类,其他参数会直接输入数据集类的构造函数中。
loader_params
定义输入 PyTorch 风格数据加载器,即DataLoader的参数。
loader
为任务构建的数据加载器,即DataLoader对象。
task
ModelProvider
geyser_lava.task.torch.common.ModelProvider
构建模型
参数类型 | 定义 |
---|---|
provides | ('model',) |
requires | ('model_params',) |
revert_requires | () |
model_params
定义构建 PyTorch 模型对象的参数,其中reference
为模型类或构建模型函数的路径,模型类必须为Module的子类,其他参数会直接输入模型类的构造函数中。
model
为任务构建的模型,即Module对象。
task
LossProvider
geyser_lava.task.torch.common.LossProvider
构建损失函数
参数类型 | 定义 |
---|---|
provides | ('loss',) |
requires | ('loss_params',) |
revert_requires | () |
loss_params
定义构建 PyTorch 损失函数对象的参数,其中reference
为损失函数类或损失函数的构建函数的路径,损失函数类必须为Module的子类,其他参数会直接输入损失函数类的构造函数中。
loss
为任务构建的损失函数,即Module对象。
task
OptimizerProvider
geyser_lava.task.torch.common.OptimizerProvider
构建优化器
参数类型 | 定义 |
---|---|
provides | ('optimizer',) |
requires | ('optimizer_params', 'model') |
revert_requires | () |
optimizer_params
定义构建 PyTorch 优化器对象的参数,其中reference
为优化器类的路径,优化器类必须为Optimizer的子类,其他参数会直接输入优化器类的构造函数中。
model
定义该优化器优化的模型对象。
optimizer
为任务构建的损失函数,即Optimizer对象。
task
trainer
SupervisedTrainer
geyser_lava.task.torch.trainer.SupervisedTrainer
有监督训练模型
参数类型 | 定义 |
---|---|
provides | ('model',) |
requires | ('model', 'train_loader', 'validate_loader', 'optimizer', 'device', 'loss', 'metrics_params', 'max_epochs', 'non_blocking') |
revert_requires | () |
model
定义需要训练的模型对象,即Module对象。
train_loader
与validate_loader
定义训练过程中训练数据集与验证数据集的数据加载器,即DataLoader对象。
optimizer
定义训练使用的优化器对象,即Optimizer对象。
device
定义模型训练过程中使用的设备,详见Device。
loss
定义训练过程中优化的损失函数对象,即Module对象。
metrics_params
定义训练过程中显示的指标对象参数列表,每个元素参数中的reference
为指标类的路径,关于指标类详见ignite.metrics。
max_epochs
定义训练过程的最大轮数。
non_blocking
定义训练过程中数据加载是否为非阻塞方式。
model
为训练完成时的模型对象,即Module对象。
task