# Source code for rlaopt.atoms.norm_balls

"""Norm ball constraint atoms for optimization."""

from abc import ABC, abstractmethod
from numbers import Real

import torch

from rlaopt.atoms.atom import Atom, AtomDecomposition
from rlaopt.atoms.box import Box
from rlaopt.expression import Expression, Variable
from rlaopt.ext_tensordict import TensorDict


class _NormBall(Atom, ABC):
    """Shared base for atoms encoding the indicator of a norm ball.

    The represented function is ``0`` when ``||x|| <= radius`` and ``+inf``
    otherwise. Concrete subclasses supply the norm through ``_norm`` and the
    projection onto the ball through ``_prox``.

    Args:
        x: Expression the constraint is applied to.
        radius: Non-negative ball radius, as a Python number or a
            single-element tensor (default: 1.0).
    """

    def __init__(self, x: Expression, radius: float | int | torch.Tensor = 1.0):
        """Validate ``radius`` and register ``x`` and the radius with the atom.

        Raises:
            ValueError: If ``radius`` is negative or a multi-element tensor.
            TypeError: If ``radius`` is neither a real number nor a tensor.
        """
        super().__init__(
            exprs={"x": x},
            buffers={"radius": _validate_radius(radius)},
        )

    def is_smooth(self) -> bool:
        """Return ``False``: the indicator jumps to ``+inf`` outside the ball."""
        return False

    def is_proxable(self) -> bool:
        """Return whether ``x`` is a bare ``Variable`` (required by ``_prox``)."""
        input_expr = self.get_input("x")
        return isinstance(input_expr, Variable)

    def forward(self) -> torch.Tensor:
        """Return ``0`` if the current value of ``x`` lies in the ball, else ``+inf``."""
        x_value = self.get_input("x").forward()
        device, dtype = x_value.device, x_value.dtype
        # Match the radius to the value so the comparison happens on one device.
        bound = self.get_buffer("radius").to(device=device, dtype=dtype)
        inside = (self._norm(x_value) <= bound).item()
        return _indicator(inside, device, dtype)

    def decompose(self) -> list[AtomDecomposition] | None:
        """Move an affine input into a fresh variable.

        Returns:
            ``None`` when ``x`` is not affine. Otherwise a one-element list
            pairing a new atom of the same class and radius, applied to a
            variable created with ``Variable.like(x)``, with the affine
            expression ``x`` itself.
        """
        affine_input = self.get_input("x")
        if affine_input.is_affine():
            replacement = Variable.like(affine_input)
            same_kind = type(self)
            sub_atom = same_kind(replacement, radius=self.get_buffer("radius"))
            return [AtomDecomposition(atom=sub_atom, affine_expr=affine_input)]
        return None

    @abstractmethod
    def _norm(self, value: torch.Tensor) -> torch.Tensor:
        """Return the norm of ``value`` as a single-element tensor.

        ``forward`` calls ``.item()`` on the comparison with the radius, so
        the result must contain exactly one element.
        """


class L1NormBall(_NormBall):
    """L1-norm ball constraint enforcing ||x||_1 <= radius.

    This atom represents the indicator function of the L1-norm ball:
    0 if ||x||_1 <= radius, +inf otherwise.

    Args:
        x: Expression to constrain.
        radius: Non-negative radius of the L1-norm ball (default: 1.0).
    """

    def _norm(self, value: torch.Tensor) -> torch.Tensor:
        """Return ``||value||_1``: the sum of absolute values of all entries."""
        return torch.abs(value).sum()

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Project onto the L1-norm ball (prox of the indicator).

        The projection is applied to every tensor in
        ``relevant_variable_values`` via ``TensorDict.apply``. ``prox_scaling``
        is unused: a positively scaled indicator is the same indicator, so its
        prox is the plain Euclidean projection for any such scaling.
        """
        radius = self.get_buffer("radius")

        def project_onto_l1_ball(x: torch.Tensor) -> torch.Tensor:
            r = radius.to(device=x.device, dtype=x.dtype)
            # A non-positive radius leaves only the origin in the ball.
            if r.item() <= 0:
                return torch.zeros_like(x)

            v = x.reshape(-1)
            mags = v.abs()
            # Already feasible: the projection is the identity.
            if (mags.sum() <= r).item():
                return x

            # Sort-based projection (Duchi et al., 2008): find the largest
            # rank j with u_j * j > (u_1 + ... + u_j) - r, then soft-threshold
            # every magnitude by theta = ((u_1 + ... + u_j) - r) / j.
            # For finite entries j = 1 always qualifies here (r > 0), so
            # ``active`` has at least one True entry.
            u, _ = torch.sort(mags, descending=True)
            running = torch.cumsum(u, dim=0)
            ranks = torch.arange(1, u.numel() + 1, device=v.device, dtype=v.dtype)
            active = u * ranks > (running - r)
            last = torch.nonzero(active, as_tuple=False)[-1].item()
            theta = (running[last] - r) / (last + 1)
            shrunk = torch.nn.functional.relu(mags - theta)
            return (torch.sign(v) * shrunk).reshape(x.shape)

        return relevant_variable_values.apply(project_onto_l1_ball)
class L2NormBall(_NormBall):
    """L2-norm (Euclidean) ball constraint enforcing ||x||_2 <= radius.

    This atom represents the indicator function of the Euclidean ball:
    0 if ||x||_2 <= radius, +inf otherwise.

    Args:
        x: Expression to constrain.
        radius: Non-negative radius of the L2-norm ball (default: 1.0).
    """

    def _norm(self, value: torch.Tensor) -> torch.Tensor:
        """Return the 2-norm of ``value`` with all entries flattened together."""
        return torch.linalg.norm(value)

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Project onto the L2-norm ball (prox of the indicator).

        The projection is applied to every tensor in
        ``relevant_variable_values`` via ``TensorDict.apply``. ``prox_scaling``
        is unused because the prox of an indicator is the projection onto
        its set regardless of positive scaling.
        """
        radius = self.get_buffer("radius")

        def project_onto_l2_ball(x: torch.Tensor) -> torch.Tensor:
            r = radius.to(device=x.device, dtype=x.dtype)
            # A non-positive radius leaves only the origin in the ball.
            if r.item() <= 0:
                return torch.zeros_like(x)
            length = torch.linalg.norm(x)
            if (length <= r).item():
                return x
            # Outside the ball (so length > r > 0): rescale radially onto
            # the sphere of radius r.
            return (r / length) * x

        return relevant_variable_values.apply(project_onto_l2_ball)
class LInfNormBall(Box):
    """L-infinity norm ball constraint enforcing ||x||_inf <= radius.

    This atom represents the indicator function of the L-infinity norm ball:
    0 if ||x||_inf <= radius, +inf otherwise. It is implemented as a ``Box``
    with symmetric bounds ``lower = -radius`` and ``upper = radius``.

    Args:
        x: Variable to constrain.
        radius: Non-negative radius of the L-infinity norm ball (default: 1.0).
    """

    def __init__(self, x: Variable, radius: float | int | torch.Tensor = 1.0):
        """Validate ``radius`` and pass the bounds ``[-radius, radius]`` to ``Box``.

        Raises:
            ValueError: If ``radius`` is negative or a multi-element tensor.
            TypeError: If ``radius`` is neither a real number nor a tensor.
        """
        bound = _validate_radius(radius)
        # Tensors are forwarded unchanged; anything else is coerced to float.
        upper = bound if torch.is_tensor(bound) else float(bound)
        super().__init__(x, lower=-upper, upper=upper)
def _validate_radius(radius: float | int | torch.Tensor) -> float | torch.Tensor: """Validate and normalize the radius parameter.""" if isinstance(radius, Real): if radius < 0: raise ValueError("radius must be non-negative") return float(radius) if torch.is_tensor(radius): if radius.numel() != 1: raise ValueError("radius must be a scalar tensor") if torch.any(radius < 0): raise ValueError("radius must be non-negative") return radius raise TypeError( f"radius must be float, int, or Tensor, got {type(radius).__name__}" ) def _indicator( satisfied: bool, device: torch.device, dtype: torch.dtype ) -> torch.Tensor: """Return 0 if satisfied, infinity otherwise.""" if satisfied: return torch.tensor(0.0, device=device, dtype=dtype) return torch.tensor(torch.inf, device=device, dtype=dtype)