# NewtonCG

An implementation of the Newton-CG algorithm with backtracking line search for PyTorch, provided as a `torch.optim.Optimizer`. CG refers to the conjugate gradient method, which the optimizer uses to solve its sub-problem.
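To illustrate the sub-problem, here is a minimal plain-Python sketch of the conjugate gradient method solving `H d = -g` for a Newton direction `d`. It is not the code in `newton_cg`; the function name `conjugate_gradient`, the `hvp` callable (a Hessian-vector product), and the toy 2×2 matrix are assumptions made for illustration. The sketch assumes `H` is symmetric positive definite, which holds for strictly convex problems.

```python
def conjugate_gradient(hvp, b, max_iter=50, tol=1e-10):
    """Approximately solve H x = b for a symmetric positive-definite H.

    `hvp(v)` (hypothetical name) returns the product H v, so H itself is
    never formed. Vectors are plain Python lists of floats.
    """
    def dot(u, v):
        return sum(ui * vi for ui, vi in zip(u, v))

    x = [0.0] * len(b)
    r = list(b)            # residual b - H x, with x = 0
    p = list(r)            # first search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old <= tol:  # squared residual norm is small enough
            break
        Hp = hvp(p)
        alpha = rs_old / dot(p, Hp)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hpi for ri, hpi in zip(r, Hp)]
        rs_new = dot(r, r)
        # next direction is conjugate to the previous ones with respect to H
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Toy problem (illustrative values only): Hessian H and gradient g
H = [[4.0, 1.0], [1.0, 3.0]]
g = [-1.0, -2.0]

def toy_hvp(v):
    return [sum(h * vj for h, vj in zip(row, v)) for row in H]

# The Newton direction solves H d = -g
direction = conjugate_gradient(toy_hvp, [-gi for gi in g])
```

Because the solver only needs products `H v`, it never has to store the Hessian, which is what makes this approach practical when the number of parameters is large.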

# An example use case of the NewtonCG optimizer

```
import torch
import torch.nn as nn
import torchvision
from newton_cg import NewtonCG
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
```

## Load data

### We use MNIST

```
transform = transforms.ToTensor()
train_set = datasets.MNIST("~/Downloads/", transform=transform)
test_set = datasets.MNIST("~/Downloads/", transform=transform, train=False)
train_loader = DataLoader(train_set, batch_size=len(train_set))
test_loader = DataLoader(test_set, batch_size=len(test_set))
```

### NewtonCG assumes optimization steps are performed over the full dataset

```
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# each loader yields a single batch containing the whole dataset
train_inputs, train_targets = next(iter(train_loader))
# flatten the 28x28 images into 784-dimensional vectors
train_inputs = train_inputs.reshape(-1, 784).to(device)
train_targets = train_targets.to(device)
test_inputs, test_targets = next(iter(test_loader))
test_inputs = test_inputs.reshape(-1, 784).to(device)
test_targets = test_targets.to(device)
```

## Define model, loss function and optimizer

### We use softmax regression with cross-entropy loss, as the resulting objective is convex

```
weights = torch.zeros(784, 10, requires_grad=True, device=device)
criterion = nn.CrossEntropyLoss()
optimizer = NewtonCG([weights])
```

## Train the model

```
num_epochs = 10
for epoch in range(num_epochs):
    # compute test accuracy (no gradients are needed for evaluation)
    correct = 0
    total = 0
    with torch.no_grad():
        outputs = torch.mm(test_inputs, weights)
        _, predicted = torch.max(outputs, 1)
    total += test_targets.size(0)
    correct += (predicted == test_targets).sum().item()
    accuracy = 100 * correct / total
    # optimizer step
    optimizer.zero_grad()
    outputs = torch.mm(train_inputs, weights)
    loss = criterion(outputs, train_targets)
    loss.backward()
    # the closure re-evaluates the training loss at the current weights
    closure = lambda: criterion(torch.mm(train_inputs, weights), train_targets)
    loss = optimizer.step(closure)
    print("epoch: {}, loss: {:.2e}, "
          "test accuracy: {:.2f}".format(epoch, loss, accuracy))
```

```
epoch: 0, loss: 2.30e+00, test accuracy: 9.80
epoch: 1, loss: 5.14e-01, test accuracy: 85.49
epoch: 2, loss: 3.65e-01, test accuracy: 90.25
epoch: 3, loss: 3.15e-01, test accuracy: 91.45
epoch: 4, loss: 2.88e-01, test accuracy: 92.03
epoch: 5, loss: 2.78e-01, test accuracy: 92.18
epoch: 6, loss: 2.73e-01, test accuracy: 92.31
epoch: 7, loss: 2.68e-01, test accuracy: 92.31
epoch: 8, loss: 2.65e-01, test accuracy: 92.37
epoch: 9, loss: 2.62e-01, test accuracy: 92.38
```
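The training loop passes a `closure` that re-evaluates the loss. A backtracking line search typically needs such a function to compare the loss at trial step sizes. The sketch below shows Armijo backtracking in plain Python. It is an assumption about the general technique, not a description of how `NewtonCG` is implemented; the name `backtracking_line_search`, its parameters, and the toy objective are hypothetical. Unlike the closure above, which takes no arguments and reads the current `weights`, the `loss_at` callable here takes the trial point explicitly to keep the example self-contained.

```python
def backtracking_line_search(loss_at, x, direction, grad,
                             step=1.0, shrink=0.5, c=1e-4, max_steps=30):
    """Return a step size that satisfies the Armijo sufficient-decrease test.

    `loss_at(point)` (hypothetical name) returns the loss at `point`.
    Returns 0.0 if no trial step is accepted within `max_steps`.
    """
    f0 = loss_at(x)
    # directional derivative of the loss along `direction`
    slope = sum(gi * di for gi, di in zip(grad, direction))
    for _ in range(max_steps):
        trial = [xi + step * di for xi, di in zip(x, direction)]
        if loss_at(trial) <= f0 + c * step * slope:
            return step
        step *= shrink     # step too long: shrink and try again
    return 0.0


# Toy objective f(x) = x^4, starting at x = 1 (illustrative values only)
def toy_loss(point):
    return point[0] ** 4

x0 = [1.0]
grad0 = [4.0]              # f'(1) = 4
descent = [-4.0]           # steepest-descent direction, for simplicity
accepted = backtracking_line_search(toy_loss, x0, descent, grad0)
```

In this toy run the full step and the half step both fail the sufficient-decrease test, and the quarter step, which lands on the minimizer, is accepted.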