Source code for rlaopt.atoms.l1_norm
"""Implementation of the L1-norm atom."""
import torch
from rlaopt.atoms.nonsmooth_regularizer import NonsmoothRegularizer
from rlaopt.expression import Expression
from rlaopt.ext_tensordict import TensorDict
[docs]
class L1Norm(NonsmoothRegularizer):
    """Scaled L1-norm regularizer.

    Evaluates to ``scaling * ||x||_1``, i.e. ``scaling`` times the sum of
    the absolute values of the entries of ``x``.

    Args:
        x: Expression the L1-norm is applied to.
        scaling: Multiplier on the L1-norm (default: 1.0).

    Examples:
        >>> x = Variable((100,), name='weights')
        >>> l1 = L1Norm(x, scaling=0.01)
        >>> penalty = l1.forward()
    """

    def __init__(self, x: Expression, scaling: float | torch.Tensor = 1.0):
        """Build the atom, handing the scaling factor to the base class.

        Args:
            x: Expression the L1-norm is applied to.
            scaling: Multiplier on the L1-norm (default: 1.0). It is passed
                to the parent constructor under the ``"scaling"`` key and is
                read back later with ``get_buffer("scaling")``.
        """
        scaling_params = {"scaling": scaling}
        super().__init__(x, scaling_params=scaling_params)

    def forward(self) -> torch.Tensor:
        """Return ``scaling`` times the sum of absolute values of input ``x``."""
        x_value = self.get_input("x").forward()
        scaling = self.get_buffer("scaling")
        return scaling * x_value.abs().sum()

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Apply the proximal operator of the scaled L1-norm.

        With ``f(x) = scaling * ||x||_1`` the operator is elementwise
        soft-thresholding:

            prox_f(v) = sign(v) * max(|v| - scaling * prox_scaling, 0)

        Args:
            relevant_variable_values: Values to threshold; the shrinkage
                function is mapped over them through ``apply``.
            prox_scaling: Extra factor multiplied into the threshold.

        Returns:
            The result of ``relevant_variable_values.apply`` with the
            soft-thresholding function.
        """
        # The shrinkage amount is fixed for the whole call, so compute it once.
        threshold = self.get_buffer("scaling") * prox_scaling

        def _shrink(v: torch.Tensor) -> torch.Tensor:
            # Magnitudes at or below the threshold collapse to zero; larger
            # ones are pulled toward zero by `threshold`, keeping their sign.
            shrunk_magnitude = torch.nn.functional.relu(v.abs() - threshold)
            return v.sign() * shrunk_magnitude

        return relevant_variable_values.apply(_shrink)