Source code for rlaopt.atoms.polyhedron

"""Polyhedron constraint atom for optimization."""

from enum import Enum, auto
from functools import partial
from math import isclose
from typing import Callable
from warnings import warn

import torch
from typing_extensions import Self

from rlaopt.atoms.atom import Atom, AtomDecomposition
from rlaopt.expression import Variable
from rlaopt.ext_tensordict import TensorDict


class _PolyhedronType(Enum):
    """Which families of linear constraints a Polyhedron atom carries.

    The member is chosen once, when the atom is constructed, and is used
    afterwards to decide how the atom is decomposed.
    """

    # Equalities only: A @ x = b
    EQUALITY_ONLY = auto()
    # Two-sided inequalities only: lower <= C @ x <= upper
    INEQUALITY_ONLY = auto()
    # Equalities and inequalities together
    MIXED = auto()

class Polyhedron(Atom):
    """Polyhedral constraint atom for linear equality and inequality constraints.

    A polyhedron is defined by:

    - Equality constraints: ``A @ x = b``
    - Inequality constraints: ``lower <= C @ x <= upper``

    The atom evaluates to 0 if all constraints are satisfied, and infinity
    otherwise (indicator function of the polyhedral set).

    Args:
        x: Variable to constrain.
        A: Equality constraint matrix (optional). If provided, b must also be
            provided.
        b: Equality constraint vector (optional). Required if A is provided.
        C: Inequality constraint matrix (optional). If None, uses identity
            (box constraints).
        lower: Lower bound vector for inequalities (optional).
        upper: Upper bound vector for inequalities (optional).

    Raises:
        ValueError: If exactly one of A and b is provided.
        ValueError: If constraint dimensions are inconsistent.
        ValueError: If no constraints are provided (trivial polyhedron).
    """

    def __init__(
        self,
        x: Variable,
        A: torch.Tensor | None = None,
        b: torch.Tensor | None = None,
        C: torch.Tensor | None = None,
        lower: torch.Tensor | int | float | None = None,
        upper: torch.Tensor | int | float | None = None,
    ):
        """Initialize the polyhedral constraint atom.

        Python ``int``/``float`` bounds are converted to 0-dim tensors on the
        device of ``x.forward()``. If only one of ``lower``/``upper`` is given,
        the missing side is filled with ``-inf``/``+inf`` using the device and
        dtype of the side that was given.

        Args:
            x: Variable to constrain.
            A: Equality constraint matrix (optional).
            b: Equality constraint vector (optional).
            C: Inequality constraint matrix (optional). Only used when at
                least one of ``lower``/``upper`` is given; otherwise it is
                ignored and a warning is emitted.
            lower: Lower bound vector for inequalities (optional).
            upper: Upper bound vector for inequalities (optional).

        Raises:
            ValueError: If exactly one of A and b is provided.
            ValueError: If constraint dimensions are inconsistent.
            ValueError: If no constraints are provided.
        """
        if (A is not None) and (b is None):
            raise ValueError("b cannot be None when A is not None")
        elif (A is None) and (b is not None):
            raise ValueError("A cannot be None when b is not None")

        # Convert float/int bounds to tensors FIRST: everything below relies
        # on tensor attributes such as .dim(), .device and .dtype.
        lower_is_scalar = isinstance(lower, (int, float))
        upper_is_scalar = isinstance(upper, (int, float))
        if lower_is_scalar or upper_is_scalar:
            # Evaluate the variable at most once to learn its device.
            device = x.forward().device
            if lower_is_scalar:
                lower = torch.tensor(float(lower), device=device)
            if upper_is_scalar:
                upper = torch.tensor(float(upper), device=device)

        # Validate input dimensional consistency
        _validate(A, C, b, lower, upper)

        # A one-sided bound becomes two-sided with an infinite other side.
        if (upper is not None) and (lower is None):
            lower = torch.tensor(-torch.inf, device=upper.device, dtype=upper.dtype)
        elif (lower is not None) and (upper is None):
            upper = torch.tensor(torch.inf, device=lower.device, dtype=lower.dtype)

        has_equality = A is not None and b is not None
        has_inequality = lower is not None  # implies upper is not None

        # _validate does not reject the "nothing provided" case, so fail here,
        # before the atom is registered with the base class, rather than
        # labelling it INEQUALITY_ONLY and failing later in _build_eval.
        if not (has_equality or has_inequality):
            raise ValueError(
                "Provided constraints define a trivial polyhedron (no constraints)."
            )

        # Without bounds C takes no part in evaluation or decomposition;
        # say so instead of dropping it silently.
        if (C is not None) and not has_inequality:
            warn(
                "C was provided without lower or upper bounds and is ignored.",
                stacklevel=2,
            )

        if has_equality and has_inequality:
            self._constraint_type = _PolyhedronType.MIXED
        elif has_equality:
            self._constraint_type = _PolyhedronType.EQUALITY_ONLY
        else:
            self._constraint_type = _PolyhedronType.INEQUALITY_ONLY

        super().__init__(
            exprs={"x": x},
            buffers={"A": A, "b": b, "C": C, "lower": lower, "upper": upper},
            variable_names=["x"],
        )

        # Build the evaluator from the registered buffers rather than from the
        # local arguments.
        self._eval = _build_eval(
            self.get_buffer("A"),
            self.get_buffer("C"),
            self.get_buffer("b"),
            self.get_buffer("lower"),
            self.get_buffer("upper"),
        )

    def forward(self) -> torch.Tensor:
        """Evaluate the polyhedral constraint at the current variable value.

        Returns:
            torch.Tensor: 0.0 if constraints are satisfied, infinity otherwise.
        """
        value = self.get_input("x").forward()
        return self._eval(value)

    def is_smooth(self) -> bool:
        """Check if the polyhedral constraint is smooth.

        Returns:
            bool: Always False, as indicator functions are non-smooth.
        """
        return False

    def is_proxable(self) -> bool:
        """Check if the polyhedral constraint has a computable proximal operator.

        Returns:
            bool: Always False (general polyhedral projection not implemented).
        """
        return False

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Polyhedral constraint does not have a prox operator in general.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Polyhedron is not proxable")

    def decompose(self) -> list[AtomDecomposition]:
        """Decompose the Polyhedron atom.

        Returns:
            list[AtomDecomposition]: One decomposition (a Box atom plus the
            affine expression it applies to) per constraint family present;
            for mixed constraints the equality part comes first.
        """
        x = self.get_input("x")
        if self._constraint_type == _PolyhedronType.EQUALITY_ONLY:
            return [self._decompose_equality(x)]
        if self._constraint_type == _PolyhedronType.INEQUALITY_ONLY:
            return [self._decompose_inequality(x)]
        # MIXED
        return [self._decompose_equality(x), self._decompose_inequality(x)]

    def _decompose_equality(self, input_var: Variable) -> AtomDecomposition:
        """Decompose the equality constraint.

        ``A @ x = b`` is rewritten as a Box with zero lower and upper bounds
        applied to the affine expression ``A @ x - b``.
        """
        input_expr = self.get_buffer("A") @ input_var - self.get_buffer("b")
        new_var = Variable.like(input_expr)

        # Imported locally (presumably to avoid a circular import - confirm).
        from rlaopt.atoms.box import Box

        # Evaluate once only to obtain the shape/dtype/device of the bounds.
        input_value = input_expr.forward()
        new_atom = Box(
            new_var,
            lower=torch.zeros_like(input_value),
            upper=torch.zeros_like(input_value),
        )
        return AtomDecomposition(atom=new_atom, affine_expr=input_expr)

    def _decompose_inequality(self, input_var: Variable) -> AtomDecomposition:
        """Decompose the inequality constraint.

        ``lower <= C @ x <= upper`` is rewritten as a Box with the same bounds
        applied to ``C @ x`` (or to ``x`` itself when C is None).
        """
        C = self.get_buffer("C")
        lower = self.get_buffer("lower")
        upper = self.get_buffer("upper")

        # C could be None (identity)
        input_expr = C @ input_var if C is not None else input_var
        new_var = Variable.like(input_expr)

        # Imported locally (presumably to avoid a circular import - confirm).
        from rlaopt.atoms.box import Box

        new_atom = Box(
            new_var,
            lower=lower,
            upper=upper,
        )
        return AtomDecomposition(atom=new_atom, affine_expr=input_expr)

    def _scale(self, scaling: float) -> Self:
        """Scale the polyhedral constraint.

        Args:
            scaling: Scaling factor.

        Returns:
            Self: This same atom, unchanged.
        """
        # math.isclose needs an explicit abs_tol when comparing against 0.0:
        # with the default (abs_tol=0.0) the test degenerates to `scaling == 0`.
        if isclose(scaling, 0.0, abs_tol=1e-12):
            warn(
                f"Scaling a {self.__class__.__name__} constraint by zero has no effect.",  # noqa: E501
            )
        return self  # Scaling does not change the constraint set
def _validate(A, C, b, lower, upper):
    """Check the shapes of the constraint data before an atom is built.

    Equality data is checked only when both A and b are present; inequality
    data is checked only when C is present.

    Args:
        A: Equality constraint matrix (optional).
        C: Inequality constraint matrix (optional).
        b: Equality constraint vector (optional).
        lower: Lower bound vector (optional).
        upper: Upper bound vector (optional).

    Raises:
        ValueError: If dimensions are inconsistent.
    """
    equality_given = not (A is None or b is None)
    if equality_given:
        _validate_equality_constraints(A, b)

    if C is None:
        return
    _validate_inequality_constraints(C, lower, upper)


def _validate_equality_constraints(A, b):
    """Dispatch the A @ x = b checks on the dimensionality of A."""
    if A.dim() == 0:
        raise ValueError("A must be at least 1-dimensional")

    # 1-D A is a single hyperplane; anything larger is treated as a matrix.
    checker = _validate_equality_vector if A.dim() == 1 else _validate_equality_matrix
    checker(A, b)


def _validate_equality_vector(A, b):
    """Check a hyperplane constraint a^T x = b (1-D A, scalar b)."""
    b_is_scalar = b.dim() == 0
    if not b_is_scalar:
        raise ValueError("For 1D A (hyperplane), b must be a scalar")

    if _is_zero(A):
        raise ValueError("To define a valid constraint, a must be non-zero.")


def _validate_equality_matrix(A, b):
    """Check a matrix equality constraint A @ x = b.

    A is rejected when it has more rows than columns; b must be 1-D with one
    entry per row of A; A must not be identically zero.
    """
    n_rows, n_cols = A.shape[0], A.shape[1]
    if n_rows > n_cols:
        raise ValueError("Valid A must have more columns than rows!")

    if b.dim() != 1:
        raise ValueError("For 2D A, b must be 1D")

    if n_rows != b.shape[0]:
        raise ValueError("A and b must have matching row counts")

    if _is_zero(A):
        raise ValueError("To define a valid constraint, A must be non-zero")


def _validate_inequality_constraints(C, lower, upper):
    """Dispatch the lower <= C @ x <= upper checks on the dimensionality of C."""
    if C.dim() == 0:
        raise ValueError("C must be at least 1-dimensional")

    # 1-D C is a single halfspace/slab; anything larger is treated as a matrix.
    checker = (
        _validate_inequality_vector if C.dim() == 1 else _validate_inequality_matrix
    )
    checker(C, lower, upper)


def _validate_inequality_vector(C, lower, upper):
    """Check a halfspace constraint lower <= c^T x <= upper (1-D C)."""
    lower_ok = lower is None or lower.dim() == 0
    if not lower_ok:
        raise ValueError("For 1D C (halfspace), lower must be a scalar")

    upper_ok = upper is None or upper.dim() == 0
    if not upper_ok:
        raise ValueError("For 1D C (halfspace), upper must be a scalar")

    if _is_zero(C):
        raise ValueError("To define a valid constraint, c must be non-zero")


def _validate_inequality_matrix(C, lower, upper):
    """Check a matrix inequality constraint lower <= C @ x <= upper.

    A bound that is None or 0-dim is accepted as is; any other bound must
    have as many leading entries as C has rows.
    """

    def _row_mismatch(bound):
        # Short-circuits exactly like the inline condition it replaces, so
        # C.shape is only touched for a present, non-scalar bound.
        return bound is not None and bound.dim() > 0 and C.shape[0] != bound.shape[0]

    if _row_mismatch(lower):
        raise ValueError("C and lower must have matching row counts")

    if _row_mismatch(upper):
        raise ValueError("C and upper must have matching row counts")

    if _is_zero(C):
        raise ValueError("To define a valid inequality constraint, C must be non-zero")


def _is_zero(X: torch.Tensor) -> bool:
    """Return True when every entry of X equals 0."""
    return (X == 0).all().item()


def _build_eval(A, C, b, lower, upper) -> Callable[[torch.Tensor], torch.Tensor]:
    """Assemble the indicator function of the polyhedron.

    One evaluator is selected per constraint family that is present (equality
    first, then inequality); the returned callable adds their results.

    Args:
        A: Equality constraint matrix (optional).
        C: Inequality constraint matrix (optional).
        b: Equality constraint vector (optional).
        lower: Lower bound vector (optional).
        upper: Upper bound vector (optional).

    Returns:
        Callable: Function that evaluates the indicator function.

    Raises:
        ValueError: If no constraints are provided.
    """
    checks = []

    if A is not None and b is not None:
        if A.dim() > 1:
            checks.append(partial(_eval_eq, A=A, b=b))
        else:
            checks.append(partial(_eval_hyperplane, a=A, b=b))

    # A non-None lower is taken to mean that upper is present as well.
    if lower is not None:
        if C is None:
            # No matrix: the bounds apply to x directly.
            checks.append(partial(_eval_id_ineq, lower=lower, upper=upper))
        elif C.dim() > 1:
            checks.append(partial(_eval_ineq, C=C, lower=lower, upper=upper))
        else:
            checks.append(partial(_eval_halfspace, c=C, lower=lower, upper=upper))

    if not checks:
        raise ValueError(
            "Provided constraints define a trivial polyhedron (no constraints)."
        )

    def _eval(x: torch.Tensor) -> torch.Tensor:
        """Evaluate all constraints.

        Args:
            x: Input vector.

        Returns:
            torch.Tensor: Sum of indicator functions (0 if all satisfied,
            inf otherwise).
        """
        total = 0
        for check in checks:
            total = total + check(x)
        return total

    return _eval


def _indicator(
    satisfied: bool, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Turn a yes/no feasibility verdict into an indicator value.

    Args:
        satisfied: Whether the constraint is satisfied.
        device: Device for result tensor.
        dtype: Data type for result tensor.

    Returns:
        torch.Tensor: 0.0 if satisfied, infinity otherwise.
    """
    value = 0.0 if satisfied else torch.inf
    return torch.tensor(value, device=device, dtype=dtype)


def _eval_id_ineq(
    x: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Indicator of the box lower <= x <= upper."""
    inside = (lower <= x) & (x <= upper)
    return _indicator(inside.all().item(), x.device, x.dtype)


def _eval_ineq(
    x: torch.Tensor, C: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Indicator of the set lower <= C @ x <= upper."""
    image = C @ x
    inside = (lower <= image) & (image <= upper)
    return _indicator(inside.all().item(), x.device, x.dtype)


def _eval_halfspace(
    x: torch.Tensor, c: torch.Tensor, lower: torch.Tensor, upper: torch.Tensor
) -> torch.Tensor:
    """Indicator of the slab lower <= c^T x <= upper."""
    projection = torch.dot(c, x)
    above_lower = lower <= projection
    # Python `and` on purpose: it only evaluates the upper test when the
    # lower one holds and yields a tensor either way.
    verdict = above_lower and (projection <= upper)
    return _indicator(verdict.item(), x.device, x.dtype)


def _eval_eq(x: torch.Tensor, A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Indicator of the affine set A @ x = b (exact, entrywise comparison)."""
    matches = (A @ x) == b
    return _indicator(matches.all().item(), x.device, x.dtype)


def _eval_hyperplane(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Indicator of the hyperplane a^T x = b (exact comparison)."""
    on_plane = torch.dot(a, x) == b
    return _indicator(on_plane.item(), x.device, x.dtype)