"""Base class for optimization atoms."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from typing_extensions import Self
from rlaopt.expression import Expression, ExprTree, Variable
from rlaopt.ext_tensordict import TensorDict
@dataclass(frozen=True, kw_only=True)
class AtomDecomposition:
    """Outcome of splitting an atom r(Ax - b) into r(z) and the affine map Ax - b.

    Instances are immutable and must be constructed with keyword arguments.

    Attributes:
        atom: The rewritten atom r(z), which acts on a newly introduced variable z.
        affine_expr: The affine expression Ax - b whose place z has taken.
    """

    # Rewritten atom r(z) over the auxiliary variable z.
    atom: "Atom"
    # Original affine argument Ax - b that z now stands in for.
    affine_expr: Expression
class Atom(Expression, ABC):
    """Abstract base class for optimization atoms.

    An atom represents a mathematical function that can be used in optimization
    problems. Atoms have various properties (smooth, proxable) and can be
    composed to form more complex objective functions.

    Atoms extend Expression with:
        - An input registration system for Variables and Expressions.
        - Buffer management for constants and hyperparameters.

    Subclasses must implement:
        - is_smooth() - whether the function is differentiable everywhere
        - is_proxable() - whether the proximal operator is computable
        - forward() - evaluation of the atom
        - _prox() - prox operator of the atom
    """

    def __init__(
        self,
        exprs: dict[str, Expression],
        buffers: dict[str, torch.Tensor | float | None],
        variable_names: list[str] | None = None,
    ):
        """Initialize the atom.

        Subclasses should call this constructor to ensure proper initialization.

        Args:
            exprs: Mapping from input name to the Expression that is registered
                as a child module under that name.
            buffers: Mapping from buffer name to the value registered as a
                buffer under that name. Accepted values are a real scalar
                (int or float), a Tensor, or None.
            variable_names: Names in ``exprs`` whose value must be a plain
                ``Variable`` rather than an arbitrary Expression. Names that
                do not occur in ``exprs`` have no effect.

        Raises:
            TypeError: If an input is not an Expression, if an input listed in
                ``variable_names`` is not a Variable, or if a buffer value has
                an unsupported type.
        """
        super().__init__()
        # Inputs restricted to bare Variables instead of general Expressions.
        variable_only = set(variable_names or ())
        for name, expr in exprs.items():
            self._register_input(name, expr, name in variable_only)
        for name, buffer in buffers.items():
            self._register_atom_buffer(name, buffer)

    @abstractmethod
    def is_proxable(self) -> bool:
        """Check if the expression has a computable proximal operator.

        The proximal operator is used in proximal gradient methods and ADMM
        for non-smooth optimization. An expression is proxable if its proximal
        operator can be computed efficiently in closed form.

        Returns:
            bool: True if the expression is proxable, False otherwise.
        """

    def prox(self, variable_values: TensorDict, prox_scaling: float) -> TensorDict:
        """Proximal operator corresponding to the atom.

        The subset of ``variable_values`` relevant to this atom is selected via
        ``select_relevant_variables`` and passed to the subclass's ``_prox``.

        Args:
            variable_values: TensorDict containing the location at which to
                evaluate the proximal operator.
            prox_scaling: Scaling parameter for the proximal operator.

        Returns:
            TensorDict: Result of applying the proximal operator at the given
            variable_values. The returned TensorDict may contain only a subset
            of the variables in variable_values, depending on which variables
            are relevant to the atom.

        Raises:
            NotImplementedError: If the atom is not proxable.
        """
        if not self.is_proxable():
            raise NotImplementedError(
                f"Proximal operator for {self.__class__.__name__} with given inputs "
                "is not implemented."
            )
        relevant_vars = self.select_relevant_variables(variable_values)
        return self._prox(relevant_vars, prox_scaling)

    @abstractmethod
    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Proximal operator implementation for the atom.

        Called by ``prox`` only after ``is_proxable()`` returned True, with the
        variables already narrowed to those relevant to this atom.
        """

    def get_input(self, name: str) -> Expression:
        """Retrieve a registered input expression by name.

        Args:
            name: Name of the input expression to retrieve.

        Returns:
            Expression: The registered input expression.

        Raises:
            KeyError: If no input expression with the given name exists. This
                includes names that refer to buffers, methods, or any other
                attribute that is not an Expression.
        """
        # A bare hasattr/getattr would also hand back buffers, bound methods,
        # and module flags, so only Expression-valued attributes qualify.
        expr = getattr(self, name, None)
        if not isinstance(expr, Expression):
            raise KeyError(f"No input expression named '{name}' found.")
        return expr

    def decompose(self) -> list[AtomDecomposition] | None:
        """Decompose the atom r(Ax - b) into r(z) and Ax - b.

        This is useful for splitting methods like ADMM, which introduce an
        auxiliary variable z with the constraint z = Ax - b and then optimize
        over r(z) separately.

        An overriding implementation should check whether the atom's input is
        an affine expression rather than just a Variable. If so, it should:

        - create a new variable z with the same shape as the input;
        - build a new atom r(z);
        - return both the new atom and the original affine expression as a
          list of AtomDecomposition instances.

        The base implementation does not decompose and always returns None.

        Returns:
            list[AtomDecomposition] | None: Decomposition containing the new
            atom r(z) and affine expression if decomposable, None otherwise.
        """
        return None

    def _register_atom_buffer(self, name: str, buffer):
        """Register a buffer (non-trainable constant) with the atom.

        Buffers store constants, hyperparameters, or fixed data that should be
        tracked by the module but not optimized. Unlike parameters, buffers do
        not receive gradients.

        Args:
            name: Name for the buffer.
            buffer: Value to register. A Tensor (including a Parameter) is
                registered as-is; an int or float is wrapped in a scalar
                tensor via ``torch.tensor(float(buffer))``; None registers an
                empty buffer slot.

        Raises:
            TypeError: If ``buffer`` is of any other type, including bool.
        """
        if isinstance(buffer, torch.Tensor):
            self.register_buffer(name, buffer)
        # bool is an int subclass; keep rejecting it so flags are not silently
        # turned into 0.0/1.0 tensors.
        elif isinstance(buffer, (int, float)) and not isinstance(buffer, bool):
            self.register_buffer(name, torch.tensor(float(buffer)))
        elif buffer is None:
            self.register_buffer(name, None)
        else:
            raise TypeError(
                f"Expected float, Tensor, or None, but got {type(buffer).__name__}"
            )

    def _register_input(self, name: str, x: Expression, variable_only: bool):
        """Register an input (Expression) with the atom as a child module.

        Args:
            name: Attribute/child-module name for the input.
            x: The input expression.
            variable_only: If True, ``x`` must be a Variable.

        Raises:
            TypeError: If ``variable_only`` is set and ``x`` is not a Variable,
                or if ``x`` is not an Expression.
        """
        if variable_only and not isinstance(x, Variable):
            raise TypeError(f"Expected Variable, but got {type(x).__name__} instead.")
        if not isinstance(x, Expression):
            raise TypeError(f"Expected Expression, but got {type(x).__name__}")
        self.add_module(name, x)

    def _scale(self, scaling: float) -> Self:
        """Scale the atom by a scalar constant.

        This method should be overridden by subclasses that support scalar
        multiplication. The default implementation returns NotImplemented.

        Args:
            scaling: Scalar value to multiply the atom by.

        Returns:
            Self: A new atom scaled by the given value, or NotImplemented if
            scalar multiplication is not supported.
        """
        return NotImplemented

    def tree(self) -> ExprTree:
        """Return tree representation for Atom.

        If the atom has child Expressions (its registered inputs), their trees
        become the children of the node. Otherwise, the atom class name is
        returned as a leaf node.

        Returns:
            ExprTree: Tree with the atom class name and optional input children.
        """
        child_trees = [
            child.tree() for child in self.children() if isinstance(child, Expression)
        ]
        if not child_trees:
            return ExprTree(self.__class__.__name__)
        return ExprTree(self.__class__.__name__, *child_trees)