# Source code for torchbnn.modules.loss

import warnings

from torch.nn import Module
from torch.nn import functional as F
from torch.nn import _reduction as _Reduction

from .. import functional as BF

class _Loss(Module):
    """Minimal base class shared by the loss modules of this file.

    It only records the requested reduction mode on the instance, as the
    ``reduction`` attribute, for subclasses to read.

    Arguments:
        reduction (string, optional): reduction mode to store.
            Default: ``'mean'``.
    """

    def __init__(self, reduction='mean'):
        # Zero-argument super() resolves exactly like super(_Loss, self).
        super().__init__()
        self.reduction = reduction
            
class BKLLoss(_Loss):
    """
    Loss for calculating KL divergence of bayesian neural network model.

    Arguments:
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'mean'``: the sum of the output will be divided by the number of
            elements of the output.
            ``'sum'``: the output will be summed.
            Default: ``'mean'``. The value is not validated here; it is passed
            straight through to ``bayesian_kl_loss``.
        last_layer_only (Bool): True for return only the last layer's KL divergence.
            Default: ``False``.
    """
    # TorchScript: treat ``reduction`` as a constant module attribute.
    __constants__ = ['reduction']

    def __init__(self, reduction='mean', last_layer_only=False):
        super(BKLLoss, self).__init__(reduction)
        self.last_layer_only = last_layer_only

    def forward(self, model):
        """
        Compute the KL-divergence loss of ``model``.

        Arguments:
            model (nn.Module): a model to be calculated for KL-divergence.

        Returns:
            The result of ``bayesian_kl_loss`` for ``model`` with this module's
            ``reduction`` and ``last_layer_only`` settings (presumably a scalar
            tensor — confirm against the functional implementation).
        """
        # All computation is delegated to the functional implementation.
        return BF.bayesian_kl_loss(
            model,
            reduction=self.reduction,
            last_layer_only=self.last_layer_only,
        )