Source code for icenet.deep.da

# Deep domain adaptation functions
#
# m.mieskolainen@imperial.ac.uk, 2024

import torch
from torch.autograd import Function


[docs] class GradientReversalFunction(Function): """ Unsupervised Domain Adaptation by Backpropagation https://arxiv.org/abs/1409.7495 Notes: The forward pass is an identity map. In the backprogation, the gradients are reversed by grad -> -alpha * grad. Example: net = nn.Sequential(nn.Linear(10, 10), GradientReversal(alpha=1.0)) """
[docs] @staticmethod def forward(ctx, x, alpha): ctx.alpha = alpha return x.clone()
[docs] @staticmethod def backward(ctx, grads): alpha = ctx.alpha alpha = grads.new_tensor(alpha) dx = -alpha * grads return dx, None
[docs] class GradientReversal(torch.nn.Module): def __init__(self, alpha = 1.0): super().__init__() self.alpha = alpha
[docs] def forward(self, x): return GradientReversalFunction.apply(x, self.alpha)