Source code for rlaopt.atoms.elastic_net
"""Elastic net regularization atom."""
import torch
from rlaopt.atoms.nonsmooth_regularizer import NonsmoothRegularizer
from rlaopt.expression import Expression
from rlaopt.ext_tensordict import TensorDict
[docs]
class ElasticNet(NonsmoothRegularizer):
    """Elastic net regularizer: a weighted sum of an L1 and a squared-L2 penalty.

    For an input expression ``x`` the penalty evaluates to::

        l1_scaling * ||x||₁ + (l2_scaling / 2) * ||x||₂²

    Args:
        x: Expression the penalty is applied to.
        l1_scaling: Weight on the L1-norm term. Defaults to 1.0.
        l2_scaling: Weight on the squared-L2-norm term. Defaults to 1.0.

    Raises:
        TypeError: If x is not an Expression.

    Examples:
        >>> x = Variable((100,), name='weights')
        >>> # Standard elastic net with equal L1 and L2 contribution
        >>> elastic = ElasticNet(x, l1_scaling=0.5, l2_scaling=0.5)
        >>> penalty = elastic.forward()
        >>> # Lasso-like (emphasize sparsity)
        >>> elastic_lasso = ElasticNet(x, l1_scaling=1.0, l2_scaling=0.1)
        >>> # Ridge-like (emphasize smoothness)
        >>> elastic_ridge = ElasticNet(x, l1_scaling=0.1, l2_scaling=1.0)
    """

    def __init__(
        self,
        x: Expression,
        l1_scaling: float | torch.Tensor = 1.0,
        l2_scaling: float | torch.Tensor = 1.0,
    ):
        """Build the atom and hand both penalty weights to the base class.

        Args:
            x: Expression the penalty is applied to.
            l1_scaling: Weight on the L1-norm term. Defaults to 1.0.
            l2_scaling: Weight on the squared-L2-norm term. Defaults to 1.0.
        """
        # The weights are passed under these exact keys; forward() and
        # _prox() read them back via get_buffer() using the same names.
        scaling_params = {"l1_scaling": l1_scaling, "l2_scaling": l2_scaling}
        super().__init__(x, scaling_params=scaling_params)

    def forward(self) -> torch.Tensor:
        """Compute the penalty for the input expression's current value.

        Returns:
            torch.Tensor: ``l1_scaling * ||x||₁ + (l2_scaling / 2) * ||x||₂²``.
        """
        current = self.get_input("x").forward()

        abs_total = torch.sum(torch.abs(current))  # ||x||₁
        square_total = torch.sum(current**2)  # ||x||₂² (no square root taken)

        l1_part = self.get_buffer("l1_scaling") * abs_total
        l2_part = (self.get_buffer("l2_scaling") / 2) * square_total
        return l1_part + l2_part

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Apply the elastic net proximal map to the given variable values.

        Each tensor ``v`` is soft-thresholded at ``l1_scaling * prox_scaling``
        and the result is divided by ``1 + prox_scaling * l2_scaling``.

        Args:
            relevant_variable_values: Values to transform; the map is passed to
                its ``apply`` method.
            prox_scaling: Step-size multiplier applied to both penalty weights.

        Returns:
            TensorDict: Whatever ``relevant_variable_values.apply`` produces for
            the shrink-and-rescale map.
        """
        denominator = 1 + prox_scaling * self.get_buffer("l2_scaling")
        cutoff = self.get_buffer("l1_scaling") * prox_scaling

        def _shrink_and_rescale(v: torch.Tensor) -> torch.Tensor:
            # Soft-thresholding written as a difference of two one-sided
            # clamps: the first keeps the excess above +cutoff, the second
            # the excess below -cutoff.
            upper_excess = torch.relu(v - cutoff)
            lower_excess = torch.relu(-v - cutoff)
            return (upper_excess - lower_excess) / denominator

        return relevant_variable_values.apply(_shrink_and_rescale)