# Source code for rlaopt.atoms.linear_equality

"""Linear equality constraint atom for optimization."""

import torch

from rlaopt.atoms.polyhedron import Polyhedron
from rlaopt.expression import Variable
from rlaopt.ext_tensordict import TensorDict


class LinearEquality(Polyhedron):
    """Linear equality constraint atom enforcing ``A @ x = b``.

    Represents a system of linear equality constraints. Unlike the general
    Polyhedron class, this provides an efficient closed-form proximal
    operator (projection onto the affine subspace) via QR factorization.

    At initialization, the rank of ``A`` is determined from a QR
    factorization of ``A.T``. If ``A`` does not have full row rank, an error
    is raised. Full row rank is a sufficient condition for the system to be
    solvable for every ``b``, and it makes ``A @ A.T`` invertible. When ``A``
    has full row rank, the triangular factor ``R`` (which satisfies
    ``A @ A.T == R.T @ R``) is cached and reused for efficiently computing
    the projection onto the constraint set.

    The projection solves:

        argmin_z ||z - location||²  subject to  A @ z = b

    whose closed-form solution is
    ``z = location - A.T @ (A @ A.T)^{-1} @ (A @ location - b)``.

    Args:
        x: Variable to constrain.
        A: Constraint matrix defining the linear system.
        b: Right-hand side vector of the equality constraints.

    Examples:
        >>> # Single equality constraint: x[0] + x[1] = 1
        >>> x = Variable((2,), name='x')
        >>> A = torch.tensor([[1.0, 1.0]])
        >>> b = torch.tensor([1.0])
        >>> constraint = LinearEquality(x, A, b)

        >>> # Multiple equality constraints
        >>> x = Variable((5,), name='x')
        >>> A = torch.randn(3, 5)
        >>> b = torch.randn(3)
        >>> constraint = LinearEquality(x, A, b)

        >>> # Use proximal operator for projection onto affine subspace
        >>> unconstrained_point = torch.randn(5)
        >>> projected = constraint.prox(unconstrained_point, prox_scaling=1.0)
        >>> # Verify: A @ projected should equal b
    """

    def __init__(self, x: Variable, A: torch.Tensor, b: torch.Tensor):
        """Initialize the affine equality constraint atom.

        Args:
            x: Variable to constrain.
            A: Constraint matrix defining the linear system.
            b: Right-hand side vector of the equality constraints.

        Raises:
            ValueError: If ``A`` has an unsupported dtype or is numerically
                rank deficient (raised by ``_factor_and_check_feasibility``).
        """
        super().__init__(x, A=A, b=b, C=None, lower=None, upper=None)
        # Triangular factor of A.T = Q @ R; satisfies A @ A.T == R.T @ R.
        # NOTE(review): stored as a plain attribute rather than through the
        # buffer API used for A and b — confirm it follows A if the atom is
        # moved to another device/dtype.
        self._R = _factor_and_check_feasibility(self.get_buffer("A"))

    def is_proxable(self) -> bool:
        """Check if the constraint has a computable proximal operator.

        Returns:
            bool: Always True, as affine equality constraints have a
            closed-form proximal operator (projection computed with the
            cached QR factor of ``A.T``).
        """
        return True

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Compute the proximal operator of the affine constraint.

        Projects the given location onto the affine subspace
        ``{x : A @ x = b}`` by solving the constrained least-squares problem.

        ``prox_scaling`` is accepted for interface compatibility but unused:
        the proximal operator of an indicator function is the Euclidean
        projection, independent of the scaling.

        NOTE(review): the correction below is a vector of length
        ``A.shape[1]`` subtracted from each leaf; this presumably assumes the
        constrained variable is a flat vector — TODO confirm for
        multi-dimensional Variables.
        """
        A = self.get_buffer("A")
        b = self.get_buffer("b")
        # Constraint residual at the current point.
        residual = A @ relevant_variable_values.to_flat_tensor() - b
        # Solve (A @ A.T) w = residual as R.T @ (R @ w) = residual: a forward
        # solve with the lower-triangular R.T, then a back solve with R.
        y = torch.linalg.solve_triangular(
            self._R.T, residual.unsqueeze(-1), upper=False
        )
        w = torch.linalg.solve_triangular(self._R, y, upper=True).squeeze(-1)
        # Same for every leaf, so compute it once outside the closure.
        correction = A.T @ w

        def projection(location: torch.Tensor) -> torch.Tensor:
            return location - correction

        return relevant_variable_values.apply(projection)
# NOTE(Zach): This helper may be moved or refactored later on. def _factor_and_check_feasibility(A: torch.Tensor) -> torch.Tensor: _, R = torch.linalg.qr(A.T, mode="r") diag_R = torch.diagonal(R) m, n = R.shape # Set relative tolerance based on dtype if R.dtype not in (torch.bfloat16, torch.float32, torch.float64): raise ValueError(f"Unsupported dtype: {R.dtype}") rtol = torch.finfo(R.dtype).eps # Scale by matrix size rtol = rtol * max(m, n) max_diag = torch.max(torch.abs(diag_R)) tol = rtol * max_diag rank = torch.sum(torch.abs(diag_R) > tol).item() if rank == n: # Should equal number of columns return R else: raise ValueError( "The provided constraint matrix A is rank deficient, and so does not " "define a valid equality constraint" )