In the quickstart tutorial, optimizer.zero_grad() is called once per mini-batch so that gradients from one batch don't accumulate into the next, which is what keeps the batches independent. In the developer documentation, they call module.zero_grad() instead. This was the first time I’d seen the latter.
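As a refresher, here is roughly where that call sits in a training loop (a minimal sketch of the usual pattern; model, loss_fn, and dataloader are stand-ins, not taken from the tutorial verbatim):

import torch

def train_one_epoch(model, loss_fn, optimizer, dataloader):
    for inputs, targets in dataloader:
        preds = model(inputs)
        loss = loss_fn(preds, targets)
        loss.backward()          # accumulates into p.grad for each parameter
        optimizer.step()         # updates parameters using the accumulated grads
        optimizer.zero_grad()    # reset grads so the next batch starts fresh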
The implementation of Optimizer.zero_grad() is:
def zero_grad(self, set_to_none: bool = False):
    if not hasattr(self, "_zero_grad_profile_name"):
        self._hook_for_profile()
    with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    if set_to_none:
                        p.grad = None
                    else:
                        if p.grad.grad_fn is not None:
                            p.grad.detach_()
                        else:
                            p.grad.requires_grad_(False)
                        p.grad.zero_()
And for Module.zero_grad():
def zero_grad(self, set_to_none: bool = False) -> None:
    if getattr(self, '_is_replica', False):
        warnings.warn(...[blah blah multithreading blah blah]...)
    for p in self.parameters():
        if p.grad is not None:
            if set_to_none:
                p.grad = None
            else:
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()
The body of the innermost for loop is the same for each:
# PyTorch can use either None or a tensor of zeros for the zero gradient
if p.grad is not None:
    if set_to_none:
        p.grad = None
    else:
        if p.grad.grad_fn is not None:
            # Detaches the gradient tensor from the graph that created it
            p.grad.detach_()
        else:
            # Disable the computational graph so that zeroing the gradient
            # doesn't count as a computation
            p.grad.requires_grad_(False)
        # Zero out the gradient
        p.grad.zero_()
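To make the two branches concrete, here is a small example of what each one leaves behind (the tensor and optimizer here are my own, not from either tutorial; I pass set_to_none explicitly since its default has changed across PyTorch versions):

import torch

w = torch.randn(3, requires_grad=True)
(w * 2).sum().backward()
print(w.grad)                        # tensor([2., 2., 2.])

opt = torch.optim.SGD([w], lr=0.1)

opt.zero_grad(set_to_none=False)     # keep the tensor, fill it with zeros
print(w.grad)                        # tensor([0., 0., 0.])

(w * 2).sum().backward()
opt.zero_grad(set_to_none=True)      # drop the tensor entirely
print(w.grad)                        # None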
The difference is that the one in Module only covers its own parameters (everything returned by self.parameters(), which includes submodules). The one in Optimizer is a little fancier:
- It times the operation for its built-in profiler
- It iterates over all the parameters in all the parameter groups known to this optimizer
- Which are simply the parameters (or parameter groups) passed to the Optimizer constructor, e.g. model.parameters() — not anything discovered by traversing the autograd graph (see the short sketch after this list)
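Here is a sketch of what those parameter groups look like (the model and learning rates are made up for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Two groups with different learning rates; zero_grad() walks every group
# and every parameter inside it.
optimizer = torch.optim.SGD(
    [
        {'params': model[0].parameters()},
        {'params': model[2].parameters(), 'lr': 1e-3},
    ],
    lr=1e-2,
)

for group in optimizer.param_groups:
    print(group['lr'], [p.shape for p in group['params']])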
Typically, you’ll use the one on Optimizer. The one on Module is really just there for specialized situations, such as when you might want to zero out a submodule’s gradient in the middle of a larger operation.
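For instance (a contrived sketch, not from the docs), you might clear only part of a model's gradients between backward passes:

import torch
import torch.nn as nn

class TwoPartModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 1)

    def forward(self, x):
        return self.head(self.encoder(x))

model = TwoPartModel()
model(torch.randn(2, 4)).sum().backward()

model.encoder.zero_grad()            # reset only the encoder's grads
print(model.encoder.weight.grad)     # zeros (or None, if set_to_none defaults to True)
print(model.head.weight.grad)        # still holds the gradient from backward()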
Although the gradient is stored on the parameter tensors themselves (code, docs), it is associated with the computational graph, so Tensor does not have a zero_grad method. From an OO standpoint, this is pathological, but the PyTorch team prioritizes practical simplicity over engineering purity.
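So if you ever want to clear the gradient of a single leaf tensor, you do it by hand (a trivial sketch, mirroring the two flavours in the loop body above):

import torch

t = torch.randn(3, requires_grad=True)
(t ** 2).sum().backward()
print(t.grad)          # holds 2 * t

t.grad.zero_()         # the in-place flavour...
# t.grad = None        # ...or the set-to-none flavour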