PyTorch torch.no_grad()

What is torch.no_grad() for?

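When testing a PyTorch model after training, we call model.eval() to switch it into evaluation mode. In this mode, dropout layers are disabled and simply pass their inputs through unchanged. Batch normalization layers stop using the current batch's statistics and instead normalize with the running mean and variance collected during training.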
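The sketch below shows what model.eval() changes. It uses a hypothetical toy model with arbitrary layer sizes. In training mode, dropout draws a new random mask on every call. In eval mode, repeated calls on the same input give identical results.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model containing the two layer types affected by the mode switch.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))
x = torch.randn(8, 4)

model.train()  # dropout active, batchnorm normalizes with the current batch's statistics
out_train_1, out_train_2 = model(x), model(x)

model.eval()   # dropout is a no-op, batchnorm uses its stored running statistics
out_eval_1, out_eval_2 = model(x), model(x)

# Training mode: each call draws a fresh dropout mask, so the outputs differ.
print(torch.equal(out_train_1, out_train_2))
# Eval mode: the same input always yields the same output.
print(torch.equal(out_eval_1, out_eval_2))  # True
# The mode flag is propagated to every submodule.
print(all(not m.training for m in model.modules()))  # True
```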

Reading more code, I found that torch.no_grad() is also widely used in PyTorch projects. What is it for?


According to the discussion at https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615, model.eval() and with torch.no_grad() serve different purposes:

  • model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval mode instead of training mode.
  • torch.no_grad() impacts the autograd engine and deactivate it. It will reduce memory usage and speed up computations but you won’t be able to backprop (which you don’t want in an eval script).
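The difference can be checked directly. This minimal sketch uses a hypothetical one-layer model. model.eval() leaves autograd running, so the output still carries a graph. Inside torch.no_grad(), no graph is recorded, so a later backward() call fails.

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # hypothetical toy model
x = torch.randn(2, 3)

# model.eval() only changes layer behaviour; autograd still records the graph.
model.eval()
y_eval = model(x)
print(y_eval.requires_grad, y_eval.grad_fn is not None)  # True True

# torch.no_grad() disables graph recording for everything inside the block.
with torch.no_grad():
    y_nograd = model(x)
print(y_nograd.requires_grad, y_nograd.grad_fn)  # False None

# Because no graph was built, backpropagating from y_nograd is impossible.
backward_failed = False
try:
    y_nograd.sum().backward()
except RuntimeError as err:
    backward_failed = True
    print("backward failed:", err)
```

In an evaluation script the two are usually combined: model.eval() for correct layer behaviour, and torch.no_grad() to skip the autograd bookkeeping.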

— answered by @albanD

Example of torch.no_grad()

  • SPATIAL TRANSFORMER NETWORKS TUTORIAL:
    https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
    import torch
    import torch.nn.functional as F

    # model, test_loader and device are defined earlier in the tutorial.
    def test():
        with torch.no_grad():
            model.eval()
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)

                # sum up batch loss (reduction='sum' replaces the deprecated size_average=False)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                # get the index of the max log-probability
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()

            test_loss /= len(test_loader.dataset)
            print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
                  .format(test_loss, correct, len(test_loader.dataset),
                          100. * correct / len(test_loader.dataset)))

I also reproduced this code in my GitHub repo: https://github.com/ShuoGH/deepLearningAl
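The tutorial's test() depends on model, test_loader and device, which are defined elsewhere in the tutorial. The self-contained sketch below runs the same kind of loop on a hypothetical toy classifier fed random data. It also applies torch.no_grad() as a decorator, which has the same effect as wrapping the function body in a with block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for the tutorial's model and test_loader:
# a linear classifier over 10 features with 3 classes, fed random data.
model = nn.Sequential(nn.Linear(10, 3), nn.LogSoftmax(dim=1))
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,)))
test_loader = DataLoader(dataset, batch_size=16)
device = torch.device("cpu")


@torch.no_grad()  # same effect as wrapping the whole body in `with torch.no_grad():`
def evaluate(model, loader):
    model.eval()
    total_loss, correct = 0.0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        total_loss += F.nll_loss(output, target, reduction="sum").item()
        correct += (output.argmax(dim=1) == target).sum().item()
    n = len(loader.dataset)
    return total_loss / n, correct / n


avg_loss, accuracy = evaluate(model, test_loader)
print(f"Average loss: {avg_loss:.4f}, Accuracy: {accuracy:.2%}")
```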


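Wrapping inference code in torch.no_grad() speeds up the forward pass and lowers memory usage. Autograd no longer records the operations or keeps the intermediate tensors that a backward pass would need.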
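This rough timing sketch runs the same forward passes with and without autograd on a hypothetical small MLP. The size of the gap depends on hardware and model size, and it may be small for tiny models on CPU. The outputs are numerically the same; the only difference is that the no_grad result has no graph attached.

```python
import time
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy network and input batch; sizes are chosen arbitrarily.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).eval()
x = torch.randn(256, 512)


def time_forward(n_iters=50):
    # Runs n_iters forward passes and returns the elapsed time and the last output.
    start = time.perf_counter()
    for _ in range(n_iters):
        out = model(x)
    return time.perf_counter() - start, out


t_grad, out_grad = time_forward()
with torch.no_grad():  # grad mode also applies inside functions called from this block
    t_nograd, out_nograd = time_forward()

print(f"with autograd: {t_grad:.3f}s, no_grad: {t_nograd:.3f}s")
# Same values, but only the first output carries a graph back to the weights.
print(torch.allclose(out_grad, out_nograd), out_grad.grad_fn is not None, out_nograd.grad_fn)
```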

Reference

  1. PyTorch Forums: ‘model.eval()’ vs ‘with torch.no_grad()’: https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615