Source code for rlaopt.atoms.box
"""Box constraint atom for optimization."""
import torch
from rlaopt.atoms.polyhedron import Polyhedron
from rlaopt.expression import Variable
from rlaopt.ext_tensordict import TensorDict
[docs]
class Box(Polyhedron):
    """Box constraint atom representing elementwise bounds.

    A box constraint restricts each element of a variable to lie within
    specified lower and upper bounds: ``lower <= x <= upper``.

    This is a special case of :class:`Polyhedron` with identity inequality
    constraints (``C = I``), but with an efficient closed-form proximal
    operator (projection via clamping).

    Args:
        x: Variable to constrain.
        lower: Lower bound vector (optional). If None, defaults to -infinity.
        upper: Upper bound vector (optional). If None, defaults to +infinity.

    Examples:
        >>> # Standard box constraint: 0 <= x <= 1
        >>> x = Variable((10,), name='x')
        >>> box = Box(x, lower=0.0, upper=1.0)
        >>> # One-sided constraint: x >= 0 (non-negativity)
        >>> box_nonneg = Box(x, lower=0.0)
        >>> # One-sided constraint: x <= 1
        >>> box_upper = Box(x, upper=1.0)
        >>> # Use proximal operator for projection
        >>> out_of_bounds = torch.randn(10)
        >>> projected = box.prox(out_of_bounds, prox_scaling=1.0)
    """

    def __init__(
        self,
        x: Variable,
        lower: torch.Tensor | int | float | None = None,
        upper: torch.Tensor | int | float | None = None,
    ):
        """Initialize the box constraint atom.

        Args:
            x: Variable to constrain.
            lower (torch.Tensor | int | float | None): Lower bound vector
                (optional).
            upper (torch.Tensor | int | float | None): Upper bound vector
                (optional).
        """
        # A box has no equality constraints (A, b) and no explicit inequality
        # matrix C; the bounds are forwarded unchanged to Polyhedron, which
        # owns their storage (read back below via get_buffer).
        super().__init__(x, A=None, b=None, C=None, lower=lower, upper=upper)

    def is_proxable(self) -> bool:
        """Check if the box constraint has a computable proximal operator.

        Returns:
            bool: Always True, as box constraints have a closed-form
            proximal operator (projection via clamping).
        """
        return True

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Compute the proximal operator of the box constraint.

        The proximal operator projects onto the box by clamping each element
        to lie within ``[lower, upper]``. The ``prox_scaling`` parameter is
        unused because the projection is independent of scaling.

        Args:
            relevant_variable_values: Values to project onto the box.
            prox_scaling: Proximal step size (ignored).

        Returns:
            TensorDict: The projected values, as new tensors.
        """
        lower = self.get_buffer("lower")
        upper = self.get_buffer("upper")
        if lower is None and upper is None:
            # torch.clamp raises when both min and max are None. A box with
            # no bounds is the whole space, so its projection is the
            # identity; clone so the result never aliases the input (the
            # clamp path below also returns fresh tensors).
            return relevant_variable_values.apply(lambda value: value.clone())
        return relevant_variable_values.apply(
            lambda value: torch.clamp(value, lower, upper)
        )