# Source code for torchbnn.functional

import math
import torch

from .modules import *

def _kl_loss(mu_0, log_sigma_0, mu_1, log_sigma_1):
    """
    Calculate the KL divergence KL(N_0 || N_1) between two Normal
    distributions, summed over all elements.

    Per element the closed form is::

        log(sigma_1 / sigma_0) + (sigma_0^2 + (mu_0 - mu_1)^2) / (2 * sigma_1^2) - 0.5

    Arguments:
        mu_0 (Tensor): mean of the first normal distribution.
        log_sigma_0 (Tensor): log(standard deviation) of the first normal
            distribution. Must be a tensor because ``torch.exp`` is applied
            to it.
        mu_1 (float or Tensor): mean of the second normal distribution.
        log_sigma_1 (float or Tensor): log(standard deviation) of the second
            normal distribution. Tensor values must broadcast against
            ``mu_0``.

    Returns:
        Tensor: 0-dim tensor holding the element-wise KL terms summed up.
    """
    # ``math.exp`` only accepts Python scalars (or one-element tensors, which
    # it silently converts to a float and thereby detaches from autograd).
    # Use ``torch.exp`` for tensor-valued priors so that multi-element priors
    # work and gradients flow; keep the scalar path unchanged for floats.
    if torch.is_tensor(log_sigma_1):
        var_1 = torch.exp(log_sigma_1) ** 2
    else:
        var_1 = math.exp(log_sigma_1) ** 2

    var_0 = torch.exp(log_sigma_0) ** 2
    kl = log_sigma_1 - log_sigma_0 + \
        (var_0 + (mu_0 - mu_1) ** 2) / (2 * var_1) - 0.5
    return kl.sum()

def _bayes_param_groups(module):
    """Return the ``(mu, log_sigma)`` pairs of a Bayesian layer, or ``[]``."""
    groups = []
    if isinstance(module, (BayesLinear, BayesConv2d)):
        groups.append((module.weight_mu, module.weight_log_sigma))
        if module.bias:
            groups.append((module.bias_mu, module.bias_log_sigma))
    if isinstance(module, BayesBatchNorm2d) and module.affine:
        groups.append((module.weight_mu, module.weight_log_sigma))
        groups.append((module.bias_mu, module.bias_log_sigma))
    return groups


def bayesian_kl_loss(model, reduction='mean', last_layer_only=False):
    """
    Calculate the KL divergence of all Bayesian layers in the model.

    Every ``BayesLinear``, ``BayesConv2d`` and affine ``BayesBatchNorm2d``
    found in ``model.modules()`` contributes the KL divergence between its
    weight/bias distributions and the layer's prior
    (``prior_mu``, ``prior_log_sigma``).

    Arguments:
        model (nn.Module): a model to be calculated for KL-divergence.
        reduction (string, optional): Specifies the reduction to apply to the
            output:
            ``'mean'``: the sum of the output will be divided by the number
            of elements of the output.
            ``'sum'``: the output will be summed.
        last_layer_only (Bool): True for return only the last layer's KL
            divergence (weight and bias terms summed). ``reduction`` is not
            applied in this case.

    Returns:
        Tensor: the reduced KL divergence (shape ``[1]``); the last Bayesian
        layer's summed KL divergence if ``last_layer_only`` is set; or a
        zero tensor of shape ``[1]`` if the model has no Bayesian parameters.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'sum'``
            (only checked when a reduction is actually applied).
    """
    # Use the exact device of the model's parameters. Mapping to the bare
    # string "cuda" would select the *current* CUDA device, which breaks for
    # models living on e.g. cuda:1. Fall back to CPU for parameter-less
    # models instead of raising StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    last_layer_kl = torch.zeros(1, device=device)
    kl_sum = torch.zeros(1, device=device)
    # Plain int: a float32 counter loses precision above 2**24 elements.
    n = 0

    for m in model.modules():
        groups = _bayes_param_groups(m)
        if not groups:
            continue

        layer_kl = None
        for mu, log_sigma in groups:
            kl = _kl_loss(mu, log_sigma, m.prior_mu, m.prior_log_sigma)
            kl_sum += kl
            n += mu.numel()
            layer_kl = kl if layer_kl is None else layer_kl + kl

        # Track the whole layer (weight + bias), not just the parameter
        # group that happened to be processed last.
        last_layer_kl = layer_kl

    if last_layer_only or n == 0:
        return last_layer_kl

    if reduction == 'mean':
        return kl_sum / n
    elif reduction == 'sum':
        return kl_sum
    else:
        raise ValueError(reduction + " is not valid")