Source code for rlaopt.atoms.l2_norm

"""Implementation of the L2-norm atom."""

import torch

from rlaopt.atoms.nonsmooth_regularizer import NonsmoothRegularizer
from rlaopt.expression import Expression
from rlaopt.ext_tensordict import TensorDict


class L2Norm(NonsmoothRegularizer):
    """L2-norm regularization atom.

    Computes the scaled L2-norm:

        scaling * ||x||₂ = scaling * sqrt(Σᵢ xᵢ²)

    The norm is taken over all entries of the input. For a multi-dimensional
    input this is the norm of the flattened tensor, which is the Frobenius
    norm for matrices.

    Args:
        x: Expression to apply the L2-norm to.
        scaling: Scaling factor for the L2-norm (default: 1.0).

    Examples:
        >>> x = Variable((100,), name='weights')
        >>> l2 = L2Norm(x, scaling=0.01)
        >>> penalty = l2.forward()
    """

    def __init__(self, x: Expression, scaling: float | torch.Tensor = 1.0):
        """Initializes the L2-norm atom with optional scaling.

        Args:
            x: Expression to apply the L2-norm to.
            scaling: Scaling factor for the L2-norm (default: 1.0).
        """
        super().__init__(x, scaling_params={"scaling": scaling})

    def forward(self) -> torch.Tensor:
        """Evaluates the scaled L2-norm.

        Returns:
            ``scaling * ||x||_2``, computed over all entries of the input
            expression's value.
        """
        value = self.get_input("x").forward()
        # vector_norm with dim=None flattens the input, matching
        # sqrt(sum(x**2)). Unlike that hand-written form, its autograd
        # gradient at x = 0 is 0 rather than NaN (0 * inf). It is also the
        # same norm that _prox uses.
        return self.get_buffer("scaling") * torch.linalg.vector_norm(value)

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Compute the proximal operator of the scaled L2-norm.

        For f(x) = scaling * ||x||_2 and prox parameter τ = prox_scaling,
        the proximal operator (block soft-thresholding) is:

            prox_{τ f}(v) = (1 - λ / max(||v||_2, λ)) * v,
            where λ = scaling * τ

        If ||v||_2 ≤ λ, the result is 0. If λ = 0, the operator is the
        identity. This includes v = 0, where the formula would otherwise
        evaluate 0 / 0.

        Args:
            relevant_variable_values: Values to apply the operator to. The
                operator is applied through ``TensorDict.apply``, presumably
                once per stored tensor, so each tensor is shrunk by its own
                norm (TODO confirm ``apply`` semantics).
            prox_scaling: The prox step size τ.

        Returns:
            The TensorDict produced by ``apply`` with the proximal operator
            applied.
        """
        scaling = self.get_buffer("scaling")
        lam = scaling * prox_scaling

        def prox_l2(x: torch.Tensor) -> torch.Tensor:
            norm = torch.linalg.vector_norm(x)  # ||x||_2 over all entries
            # torch.maximum needs tensor operands. Converting lam to the
            # norm's (real) dtype and device keeps this robust if the buffer
            # is a plain number or lives on a different device.
            lam_t = torch.as_tensor(lam, dtype=norm.dtype, device=norm.device)
            denom = torch.maximum(norm, lam_t)  # max(||x||, λ)
            # denom == 0 only when ||x|| == 0 and λ == 0. The prox is then the
            # identity, so substitute 1 to get scale = 1 instead of 0/0 = NaN.
            safe_denom = torch.where(denom > 0, denom, torch.ones_like(denom))
            scale = 1 - lam_t / safe_denom  # 1 - (λ / max{||x||, λ})
            return scale * x

        return relevant_variable_values.apply(prox_l2)