To use torch.optim you have to construct an optimizer object that will hold the current state and will update the parameters based on the computed gradients.
Constructing it
To construct an Optimizer you have to give it an iterable containing the parameters to optimize (all should be torch.Tensor objects, typically the nn.Parameter instances yielded by model.parameters()). Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc.
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam([var1, var2], lr=0.0001)
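To make the lr and momentum options above concrete, here is a minimal pure-Python sketch of the SGD-with-momentum update rule (velocity = momentum * velocity + gradient, then parameter -= lr * velocity). The function name sgd_momentum_step and the toy objective are assumptions made for illustration; this is not PyTorch's implementation, only the same arithmetic on plain floats.

```python
def sgd_momentum_step(params, grads, bufs, lr=0.01, momentum=0.9):
    """Apply one SGD-with-momentum update in place.

    params, grads and bufs are equal-length lists of floats. bufs holds the
    per-parameter momentum buffers, i.e. the state an optimizer keeps
    between steps.
    """
    for i, g in enumerate(grads):
        bufs[i] = momentum * bufs[i] + g   # accumulate the velocity
        params[i] -= lr * bufs[i]          # move against the velocity


# Toy objective (hypothetical): f(x, y) = x**2 + 10 * y**2,
# whose gradient is (2x, 20y) and whose minimum is at (0, 0).
params = [3.0, -2.0]
bufs = [0.0, 0.0]
for _ in range(500):
    grads = [2 * params[0], 20 * params[1]]
    sgd_momentum_step(params, grads, bufs)

print(params)  # both values end up very close to 0
```

The momentum buffers are the reason an optimizer object has to "hold the current state": they must survive from one step to the next.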
Taking an optimization step
All optimizers implement a step() method that updates the parameters. It can be used in two ways:
optimizer.step()
This is a simplified version supported by most optimizers. The function can be called once the gradients are computed using e.g. backward().
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
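The loop above calls zero_grad() on every iteration because backward() adds to the stored gradients instead of overwriting them. The following pure-Python sketch mimics that accumulation; ToyParam and the backward helper are hypothetical stand-ins, not PyTorch classes.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter: a value plus an accumulated grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(param, x, target):
    """Add d/dw of (w * x - target) ** 2 to param.grad (accumulating, not assigning)."""
    param.grad += 2 * (param.value * x - target) * x


w = ToyParam(1.0)

backward(w, x=2.0, target=0.0)   # gradient is 2 * (1 * 2 - 0) * 2 = 8
first = w.grad

backward(w, x=2.0, target=0.0)   # no reset: the new gradient is added on top
stale = w.grad

w.grad = 0.0                     # the reset that zero_grad() performs
backward(w, x=2.0, target=0.0)
fresh = w.grad

print(first, stale, fresh)       # 8.0 16.0 8.0
```

Skipping the reset would make each step use the sum of all gradients computed so far, which is rarely what a plain training loop intends.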
optimizer.step(closure)
Some optimization algorithms such as Conjugate Gradient and LBFGS need to reevaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.
    for input, target in dataset:
        def closure():
            optimizer.zero_grad()  # reset the gradients from the previous step
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            return loss
        optimizer.step(closure)
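To illustrate why some algorithms need a closure, the sketch below uses a toy gradient-descent optimizer with a backtracking line search, written in pure Python. It calls the closure once at the current point and again for every trial step size. BacktrackingDescent, its parameters, and the objective are hypothetical and far simpler than L-BFGS; only the calling pattern (clear the gradient, compute the loss, return it) mirrors the text above.

```python
class BacktrackingDescent:
    """Hypothetical optimizer: gradient descent with a backtracking line search."""

    def __init__(self, params, lr=1.0, shrink=0.5, max_tries=20):
        self.params = params          # dict holding the value "x" and its "grad"
        self.lr = lr
        self.shrink = shrink
        self.max_tries = max_tries
        self.evaluations = 0          # how many times the closure was called

    def _evaluate(self, closure):
        self.evaluations += 1
        return closure()

    def step(self, closure):
        loss = self._evaluate(closure)        # loss and gradient at the current point
        start, grad = self.params["x"], self.params["grad"]
        if grad == 0.0:
            return loss                       # already at a stationary point
        step_size = self.lr
        for _ in range(self.max_tries):
            self.params["x"] = start - step_size * grad
            trial = self._evaluate(closure)   # re-evaluate at the trial point
            if trial < loss:
                return trial                  # accept the first improving step
            step_size *= self.shrink          # otherwise try a shorter step
        self.params["x"] = start              # nothing improved: undo the move
        return loss


params = {"x": 5.0, "grad": 0.0}


def closure():
    params["grad"] = 0.0                            # clear the gradient
    loss = (params["x"] - 1.0) ** 4                 # compute the loss
    params["grad"] += 4 * (params["x"] - 1.0) ** 3  # stand-in for loss.backward()
    return loss                                     # hand the loss to the optimizer


optimizer = BacktrackingDescent(params)
for _ in range(3):
    final_loss = optimizer.step(closure)

print(params["x"], final_loss, optimizer.evaluations)
```

The evaluation counter ends up larger than the number of step() calls, which is exactly the situation a single precomputed gradient cannot serve.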
CLASS torch.optim.Optimizer(params, defaults)
Base class for all optimizers.
Method | Description |
---|---|
Optimizer.add_param_group | Add a param group to the Optimizer's param_groups. |
Optimizer.load_state_dict | Loads the optimizer state. |
Optimizer.state_dict | Returns the state of the optimizer as a dict. |
Optimizer.step | Performs a single optimization step (parameter update). |
Optimizer.zero_grad | Resets the gradients of all optimized torch.Tensor objects. |
Algorithm | Description |
---|---|
Adadelta | Implements Adadelta algorithm. |
Adagrad | Implements Adagrad algorithm. |
Adam | Implements Adam algorithm. |
AdamW | Implements AdamW algorithm. |
SparseAdam | SparseAdam implements a masked version of the Adam algorithm suitable for sparse gradients. |
Adamax | Implements Adamax algorithm (a variant of Adam based on infinity norm). |
ASGD | Implements Averaged Stochastic Gradient Descent. |
LBFGS | Implements L-BFGS algorithm, heavily inspired by minFunc. |
NAdam | Implements NAdam algorithm. |
RAdam | Implements RAdam algorithm. |
RMSprop | Implements RMSprop algorithm. |
Rprop | Implements the resilient backpropagation algorithm. |
SGD | Implements stochastic gradient descent (optionally with momentum). |
ADADELTA
CLASS torch.optim.Adadelta(params, lr=1.0, rho=0.9, eps=1e-06, weight_decay=0, foreach=None, *, maximize=False, differentiable=False)
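If you are just starting out, setting only params and lr is enough. The other parameters can keep their default values, and you can study them when you need them.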
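For readers who want to see what lr, rho and eps in the signature above do, here is a minimal pure-Python sketch of the Adadelta update on plain floats, with weight_decay left at 0. The helper name adadelta_step and the toy objective are assumptions for illustration; this is not the library's code.

```python
import math


def adadelta_step(params, grads, square_avg, acc_delta, lr=1.0, rho=0.9, eps=1e-6):
    """Apply one Adadelta update in place to lists of floats.

    square_avg is the running average of squared gradients and acc_delta the
    running average of squared updates; both are per-parameter state.
    """
    for i, g in enumerate(grads):
        square_avg[i] = rho * square_avg[i] + (1 - rho) * g * g
        # Ratio of the two RMS values rescales the raw gradient.
        delta = math.sqrt(acc_delta[i] + eps) / math.sqrt(square_avg[i] + eps) * g
        acc_delta[i] = rho * acc_delta[i] + (1 - rho) * delta * delta
        params[i] -= lr * delta            # lr scales the delta before it is applied


# Toy objective (hypothetical): f(x) = (x - 3) ** 2, gradient 2 * (x - 3).
params = [0.0]
square_avg = [0.0]
acc_delta = [0.0]
for _ in range(50):
    grads = [2 * (params[0] - 3.0)]
    adadelta_step(params, grads, square_avg, acc_delta)

print(params[0])  # has moved from 0 towards 3, but only by a small amount
```

Because acc_delta starts at zero, the first updates are on the order of sqrt(eps), so in this sketch the early steps are small and the step size only builds up gradually.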
    import torch
    import torchvision
    from torch import nn
    from torch.nn import Sequential
    from torch.utils.data import DataLoader

    # Load the CIFAR10 training set from a local directory (it must already be downloaded)
    dataset = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=False,
                                           transform=torchvision.transforms.ToTensor())
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


    class XiaoMo(nn.Module):
        def __init__(self):
            super(XiaoMo, self).__init__()
            self.model1 = Sequential(
                nn.Conv2d(3, 32, 5, 1, 2),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 32, 5, 1, 2),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 5, 1, 2),
                nn.MaxPool2d(2),
                nn.Flatten(),
                nn.Linear(64 * 4 * 4, 64),
                nn.Linear(64, 10)
            )

        def forward(self, x):
            x = self.model1(x)
            return x


    loss = nn.CrossEntropyLoss()
    xiaomo = XiaoMo()
    optim = torch.optim.SGD(xiaomo.parameters(), lr=0.01)  # create the optimizer

    for epoch in range(20):  # repeat for 20 epochs
        loss_running = 0.0
        for imgs, target in dataloader:  # one pass over the data
            outputs = xiaomo(imgs)
            loss_res = loss(outputs, target)

            optim.zero_grad()    # clear the gradients from the previous batch
            loss_res.backward()  # compute the gradients
            optim.step()         # update the parameters
            # .item() extracts a plain float so the computation graph is not kept alive
            loss_running += loss_res.item()

        print(loss_running)  # total loss for this epoch