NewtonCG
An implementation of the Newton-CG algorithm with backtracking line search for PyTorch, provided as a torch.optim.Optimizer. CG refers to the conjugate gradient method, which the optimizer uses to solve its sub-problem.
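To make the role of the sub-problem solver concrete, here is a minimal pure-Python sketch of conjugate gradient. In Newton-CG the sub-problem is typically the linear system H p = -g, where H is the Hessian and g the gradient, and CG only needs Hessian-vector products rather than H itself. The names `conjugate_gradient` and `hessian_vector_product` and the toy 2x2 matrix are assumptions made for illustration; this is not the code used inside NewtonCG, which may add details such as early truncation or negative-curvature handling.

```python
def conjugate_gradient(matvec, b, tol=1e-10, max_iter=None):
    """Approximately solve A x = b for a symmetric positive-definite A.

    `matvec(v)` returns the product A v as a list; A is never formed.
    """
    n = len(b)
    max_iter = n if max_iter is None else max_iter
    x = [0.0] * n
    r = list(b)   # residual b - A x, starting from x = 0
    p = list(r)   # first search direction is the residual
    rs_old = sum(ri * ri for ri in r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < tol:
            break
        Ap = matvec(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        # New direction is conjugate to the previous ones with respect to A.
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy problem (hypothetical values): Hessian H and gradient g.
H = [[4.0, 1.0], [1.0, 3.0]]
g = [1.0, 2.0]


def hessian_vector_product(v):
    # Stand-in for an autograd-based Hessian-vector product.
    return [sum(h * vj for h, vj in zip(row, v)) for row in H]


# The Newton direction solves H p = -g.
newton_direction = conjugate_gradient(hessian_vector_product, [-gi for gi in g])
```

For this 2x2 system CG reaches the exact solution (up to rounding) in two iterations, which is why the default iteration cap is the problem dimension.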
An example use case of the NewtonCG optimizer
import torch
import torch.nn as nn
import torchvision
from newton_cg import NewtonCG
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Load data
We use MNIST
transform = transforms.ToTensor()
train_set = datasets.MNIST("~/Downloads/", transform=transform)
test_set = datasets.MNIST("~/Downloads/", transform=transform, train=False)
train_loader = DataLoader(train_set, batch_size=len(train_set))
test_loader = DataLoader(test_set, batch_size=len(test_set))
NewtonCG assumes each optimization step is performed over the full dataset, so each loader above yields its entire set as a single batch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_inputs, train_targets = next(iter(train_loader))
train_inputs = train_inputs.reshape(-1, 784).to(device)
train_targets = train_targets.to(device)
test_inputs, test_targets = next(iter(test_loader))
test_inputs = test_inputs.reshape(-1, 784).to(device)
test_targets = test_targets.to(device)
Define model, loss function and optimizer
We use softmax regression with cross-entropy loss, since the resulting objective is convex
weights = torch.zeros(784, 10, requires_grad=True, device=device)
criterion = nn.CrossEntropyLoss()
optimizer = NewtonCG([weights])
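As a quick sanity check on this setup, the sketch below computes the cross-entropy of a single sample by hand in pure Python. With all-zero weights every logit is zero, so the model predicts a uniform distribution over the 10 classes and the loss is ln(10), about 2.30, which appears consistent with the loss reported for epoch 0 in the training output. The helper `cross_entropy` is a hypothetical stand-in written for illustration, not the nn.CrossEntropyLoss implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample from raw logits, in log-sum-exp form."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Zero weights give all-zero logits for any input, hence a uniform prediction.
zero_logits = [0.0] * 10
initial_loss = cross_entropy(zero_logits, target=3)
print(f"{initial_loss:.2e}")  # prints 2.30e+00
```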
Train the model
num_epochs = 10
for epoch in range(num_epochs):
    # compute test accuracy
    correct = 0
    total = 0
    outputs = torch.mm(test_inputs, weights)
    _, predicted = torch.max(outputs.data, 1)
    total += test_targets.size(0)
    correct += (predicted == test_targets).sum().item()
    accuracy = 100 * correct / total

    # optimizer step
    optimizer.zero_grad()
    outputs = torch.mm(train_inputs, weights)
    loss = criterion(outputs, train_targets)
    loss.backward()
    closure = lambda : criterion(torch.mm(train_inputs, weights), train_targets)
    loss = optimizer.step(closure)

    print("epoch: {}, loss: {:.2e}, "
          "test accuracy: {:.2f}".format(epoch, loss, accuracy))
epoch: 0, loss: 2.30e+00, test accuracy: 9.80
epoch: 1, loss: 5.14e-01, test accuracy: 85.49
epoch: 2, loss: 3.65e-01, test accuracy: 90.25
epoch: 3, loss: 3.15e-01, test accuracy: 91.45
epoch: 4, loss: 2.88e-01, test accuracy: 92.03
epoch: 5, loss: 2.78e-01, test accuracy: 92.18
epoch: 6, loss: 2.73e-01, test accuracy: 92.31
epoch: 7, loss: 2.68e-01, test accuracy: 92.31
epoch: 8, loss: 2.65e-01, test accuracy: 92.37
epoch: 9, loss: 2.62e-01, test accuracy: 92.38
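The closure passed to optimizer.step in the training loop lets the optimizer re-evaluate the loss at trial points, which is what a backtracking line search needs. The sketch below shows one common variant, backtracking with the Armijo sufficient-decrease condition, in pure Python. It is an illustration under assumptions, not NewtonCG's actual code: the function name, the parameters `shrink` and `c1`, and the toy loss are all hypothetical, and the closure here takes the parameters explicitly, whereas a PyTorch closure takes no arguments and reads the current parameter values.

```python
def backtracking_line_search(closure, params, direction, grad,
                             step=1.0, shrink=0.5, c1=1e-4, max_trials=30):
    """Shrink `step` until the Armijo sufficient-decrease condition holds.

    `closure(params)` returns the loss at `params` (hypothetical signature).
    Returns the accepted step size and the corresponding new parameters.
    """
    loss0 = closure(params)
    # Directional derivative of the loss along `direction`.
    slope = sum(g * d for g, d in zip(grad, direction))
    for _ in range(max_trials):
        trial = [p + step * d for p, d in zip(params, direction)]
        if closure(trial) <= loss0 + c1 * step * slope:
            return step, trial
        step *= shrink
    return 0.0, list(params)  # no acceptable step found


# Toy loss f(w) = w0^4 + w1^2 (hypothetical), started at w = (1, 1).
def toy_loss(w):
    return w[0] ** 4 + w[1] ** 2


w = [1.0, 1.0]
grad = [4.0, 2.0]          # gradient of toy_loss at w
direction = [-4.0, -2.0]   # steepest-descent direction, too long at step 1
step, w_new = backtracking_line_search(toy_loss, w, direction, grad)
```

Here the full step overshoots and increases the loss, so the search halves the step once before accepting it.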