# Source code for abel.utils
import torch
from typing import Iterable
def get_weight_norm(param_groups: Iterable) -> torch.Tensor:
    """
    Returns the squared L2 norm of all parameters in the param groups.

    Despite the function name, the value is the *squared* norm: the sum over
    every parameter ``p`` of ``torch.norm(p, 2) ** 2``, i.e. the sum of the
    squares of all parameter elements.

    Args:
        param_groups (Iterable): Iterable of param-group dicts, each holding
            an iterable of tensors under the ``'params'`` key (the layout of
            ``torch.optim.Optimizer.param_groups``). Note this is NOT a flat
            iterable of tensors such as ``model.parameters()``.

    Returns:
        torch.Tensor: 0-dim tensor with the squared norm. When there are no
        parameters at all, a zero scalar tensor is returned (previously
        ``None``, which contradicted the return annotation).
    """
    norm = None
    for group in param_groups:
        for p in group['params']:
            # torch.norm with p=2 and no dim flattens the tensor, so this is
            # the vector 2-norm of all elements (not a matrix/spectral norm).
            sq_norm = torch.norm(p, 2) ** 2
            # Seed from the first parameter so the result keeps that
            # parameter's dtype/device instead of a default-typed accumulator.
            norm = sq_norm if norm is None else norm + sq_norm
    if norm is None:
        # No parameters: the sum of an empty set of squares is 0.
        return torch.zeros(())
    return norm