NewtonCG

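An implementation of the Newton-CG algorithm with backtracking line search for PyTorch, provided as a torch.optim.Optimizer. CG refers to the conjugate gradient method, which the optimizer uses as its sub-problem solver: at each step it approximately solves the Newton system (Hessian × direction = -gradient) to obtain a search direction.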
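For intuition, a Newton-CG step finds a search direction p by approximately solving H p = -g with conjugate gradient, where g is the gradient and H the Hessian of the loss. CG only needs Hessian-vector products H v, never H itself. A backtracking line search then shrinks the step length t until the Armijo sufficient-decrease condition f(x + t p) <= f(x) + c t g·p holds. The NumPy sketch below illustrates both pieces on a small quadratic. It is not the repository's code: the helper names (`conjugate_gradient`, `backtracking`, `newton_cg_step`) and the defaults for `tol`, `beta`, `c` and `max_iter` are hypothetical choices, and it omits refinements a production solver may use, such as an adaptive CG tolerance.

```python
import numpy as np

# Illustrative sketch with hypothetical helper names; not the NewtonCG repository's implementation.


def conjugate_gradient(hvp, b, max_iter=50, tol=1e-10):
    """Approximately solve H x = b using only Hessian-vector products hvp(v) = H v."""
    x = np.zeros_like(b)
    r = b.copy()              # residual b - H x (x starts at 0)
    p = r.copy()              # current conjugate search direction
    rs_old = r @ r
    for k in range(max_iter):
        if np.sqrt(rs_old) < tol:
            break
        Hp = hvp(p)
        curvature = p @ Hp
        if curvature <= 0:
            # Non-positive curvature: stop early; on the first iteration
            # fall back to b itself (the steepest-descent direction when b = -g).
            return x if k > 0 else b
        alpha = rs_old / curvature
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def backtracking(f, x, p, g, t=1.0, beta=0.5, c=1e-4, max_iter=30):
    """Shrink t by beta until the Armijo condition f(x + t p) <= f(x) + c t g.p holds."""
    fx = f(x)
    slope = g @ p             # directional derivative, negative for a descent direction
    for _ in range(max_iter):
        if f(x + t * p) <= fx + c * t * slope:
            break
        t *= beta
    return t


def newton_cg_step(f, grad, hvp, x):
    """One Newton-CG step: solve H p = -g approximately with CG, then backtrack along p."""
    g = grad(x)
    p = conjugate_gradient(lambda v: hvp(x, v), -g)
    t = backtracking(f, x, p, g)
    return x + t * p


# Usage on a convex quadratic f(x) = 0.5 x.A.x - b.x, whose minimizer solves A x = b.
A = np.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])  # symmetric positive definite
b = np.array([1.0, 2.0, 3.0])
f = lambda x: 0.5 * x @ A @ x - b @ x
grad = lambda x: A @ x - b
hvp = lambda x, v: A @ v      # the Hessian of a quadratic is A everywhere
x1 = newton_cg_step(f, grad, hvp, np.zeros(3))
print(np.allclose(A @ x1, b))  # True
```

On a quadratic the Hessian is constant and CG solves the Newton system within n iterations (here n = 3), so the full step t = 1 passes the Armijo test and a single step lands on the minimizer. On a non-quadratic loss such as cross-entropy, the step is repeated, with the Hessian-vector products re-evaluated at each new point.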

Example: training a classifier with the NewtonCG optimizer

import torch
import torch.nn as nn
import torchvision
from newton_cg import NewtonCG
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

Load data

We use MNIST (60,000 training and 10,000 test images, each 28 × 28 = 784 pixels).

transform = transforms.ToTensor()
train_set = datasets.MNIST("~/Downloads/", train=True, download=True, transform=transform)
test_set = datasets.MNIST("~/Downloads/", train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=len(train_set))
test_loader = DataLoader(test_set, batch_size=len(test_set))

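NewtonCG assumes that each optimization step is computed over the full dataset, which is why both loaders use a batch size equal to the dataset size. Each loader therefore yields a single batch, which we flatten into 784-dimensional vectors and move to the device.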

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_inputs, train_targets = next(iter(train_loader))
train_inputs = train_inputs.reshape(-1, 784).to(device)
train_targets = train_targets.to(device)
test_inputs, test_targets = next(iter(test_loader))
test_inputs = test_inputs.reshape(-1, 784).to(device)
test_targets = test_targets.to(device)

Define the model, loss function, and optimizer

We use softmax regression (a linear model without a bias term) with cross-entropy loss, since the resulting objective is convex in the weights.
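Convexity matters for Newton-CG because it keeps the curvature v·Hv seen by CG non-negative. For softmax regression the per-sample Hessian with respect to the logits is diag(p) - p p^T, so the curvature along a direction is the variance of the logit change under the softmax probabilities, which cannot be negative. The NumPy sketch below illustrates this; it is not the repository's code, and the function names are hypothetical. It computes the mean cross-entropy for logits X W (no bias, mirroring `torch.mm(inputs, weights)` with `nn.CrossEntropyLoss`'s default mean reduction), its gradient, and a Hessian-vector product, then checks the product against finite differences of the gradient.

```python
import numpy as np

# Illustrative sketch with hypothetical function names; not the NewtonCG repository's code.


def softmax(Z):
    Z = Z - Z.max(axis=1, keepdims=True)  # shift logits for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def softmax_ce_loss(W, X, y):
    """Mean cross-entropy of softmax regression with logits X @ W (no bias)."""
    P = softmax(X @ W)
    return -np.mean(np.log(P[np.arange(len(y)), y]))


def softmax_ce_grad(W, X, y):
    """Gradient X^T (P - one_hot(y)) / n."""
    P = softmax(X @ W)
    P[np.arange(len(y)), y] -= 1.0
    return X.T @ P / len(y)


def softmax_ce_hvp(W, V, X, y):
    """Hessian-vector product H[V] without forming the (d*k) x (d*k) Hessian."""
    P = softmax(X @ W)
    dZ = X @ V                             # change of the logits along direction V
    # Apply the per-sample logit Hessian diag(p) - p p^T to each row of dZ.
    HdZ = P * dZ - P * (P * dZ).sum(axis=1, keepdims=True)
    return X.T @ HdZ / len(y)


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 5))               # 50 samples, 5 features
y = rng.integers(0, 3, size=50)            # 3 classes
W = rng.normal(size=(5, 3))
V = rng.normal(size=(5, 3))

# Convexity: the curvature <V, H V> is non-negative for any direction V.
print(np.sum(V * softmax_ce_hvp(W, V, X, y)) >= 0)  # True

# Central finite difference of the gradient along V approximates H V.
eps = 1e-6
fd = (softmax_ce_grad(W + eps * V, X, y) - softmax_ce_grad(W - eps * V, X, y)) / (2 * eps)
print(np.allclose(fd, softmax_ce_hvp(W, V, X, y), atol=1e-6))  # True
```

For MNIST, X would be the 60,000 × 784 training matrix and W the 784 × 10 weight matrix, so each Hessian-vector product costs a few matrix multiplications instead of forming a 7,840 × 7,840 Hessian. The NewtonCG optimizer operates on PyTorch tensors and may obtain these products differently, for example through autograd.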

weights = torch.zeros(784, 10, requires_grad=True, device=device)
criterion = nn.CrossEntropyLoss()
optimizer = NewtonCG([weights])

Train the model

num_epochs = 10
for epoch in range(num_epochs):
    # compute test accuracy with the current weights (no gradient tracking needed)
    with torch.no_grad():
        outputs = torch.mm(test_inputs, weights)
        predicted = outputs.argmax(dim=1)
        accuracy = 100 * (predicted == test_targets).sum().item() / test_targets.size(0)

    # optimizer step
    optimizer.zero_grad()
    outputs = torch.mm(train_inputs, weights)
    loss = criterion(outputs, train_targets)
    loss.backward()
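    # closure re-evaluates the training loss at the current weights
    # (presumably used by the line search to evaluate trial steps)
    closure = lambda: criterion(torch.mm(train_inputs, weights), train_targets)
    loss = optimizer.step(closure)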

    print("epoch: {},  loss: {:.2e},  "
          "test accuracy: {:.2f}".format(epoch, loss, accuracy))

Output:

epoch: 0,  loss: 2.30e+00,  test accuracy: 9.80
epoch: 1,  loss: 5.14e-01,  test accuracy: 85.49
epoch: 2,  loss: 3.65e-01,  test accuracy: 90.25
epoch: 3,  loss: 3.15e-01,  test accuracy: 91.45
epoch: 4,  loss: 2.88e-01,  test accuracy: 92.03
epoch: 5,  loss: 2.78e-01,  test accuracy: 92.18
epoch: 6,  loss: 2.73e-01,  test accuracy: 92.31
epoch: 7,  loss: 2.68e-01,  test accuracy: 92.31
epoch: 8,  loss: 2.65e-01,  test accuracy: 92.37
epoch: 9,  loss: 2.62e-01,  test accuracy: 92.38
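The epoch 0 values can be checked by hand. The reported loss corresponds to the initial zero weights: every logit is 0, the softmax is uniform over the 10 classes, and the cross-entropy is ln 10 ≈ 2.30. The accuracy is also measured before the first step, when all 10 outputs tie for every image. Assuming ties resolve to the first index, as the logged 9.80 suggests, every test image is classified as digit 0, and the MNIST test set contains 980 zeros out of 10,000 images. The short computation below reproduces both numbers; the variable names are illustrative.

```python
import math

# With zero weights, softmax gives each of the 10 classes probability 1/10,
# so every sample's cross-entropy is -log(1/10) = log(10).
num_classes = 10
initial_loss = -math.log(1 / num_classes)
print(f"{initial_loss:.2e}")  # 2.30e+00

# All outputs tie, so (assuming first-index tie-breaking) every image is predicted as 0;
# the MNIST test set has 980 images of the digit 0 out of 10,000.
zeros_in_test_set = 980
test_set_size = 10_000
initial_accuracy = 100 * zeros_in_test_set / test_set_size
print(f"{initial_accuracy:.2f}")  # 9.80
```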