"""Alternating direction method of multipliers (ADMM) implementation.
Our ADMM implementation is based on the description in
"GeNIOS: an (almost) second-order operator-splitting solver for
large-scale convex optimization" by Diamandis et al., 2023.
This implements an inexact ADMM solver that can handle large-scale
problems by solving the ADMM linear system approximately using
preconditioned conjugate gradient (PCG).
"""
import math
import time
from dataclasses import dataclass
from typing import Callable
from warnings import warn
import torch
from pydantic import Field
from rlaopt.expression import Expression
from rlaopt.ext_tensordict import TensorDict
from rlaopt.linalg import (
LinSys,
NystromConfig,
Preconditioner,
PreconditionerConfig,
get_preconditioner,
)
from rlaopt.solvers.configs_base import SolverConfig, StoppingCriteria
from rlaopt.solvers.pcg import _PCG, PCGStoppingCriteria
from rlaopt.solvers.solver_base import (
ConvergenceStatus,
OptimResult,
OptimSolver,
SolverState,
)
from rlaopt.splitting import ADMMSplit
# Fallback base relative tolerance for the inner PCG solve of the ADMM
# x-subproblem. It is substituted when the residual-based tolerance
# min(sqrt(||r_p|| * ||r_d||), 1) is zero, and (like the residual-based value)
# is then divided by (iter + 1) ** gamma.
PCG_TOL_EPS: float = 1.0e-3
class ADMMConfig(SolverConfig):
    """Configuration for the ADMM solver.

    Groups the parameters consumed by the ADMM solver:

    - the initial penalty ``rho`` and its primal-dual balancing schedule
      (``rho_update_factor``, ``rho_update_threshold``, ``rho_update_freq``);
    - the over-relaxation factor ``alpha``;
    - the inexact x-subproblem parameters ``sigma`` and ``gamma``;
    - the preconditioner used by the inner PCG solve and how often it is
      rebuilt.
    """

    rho: float = Field(
        1.0,
        description="Augmented Lagrangian penalty at initialization.",
        gt=0.0,
    )
    # Every ``rho_update_freq`` iterations: if the primal residual norm exceeds
    # ``rho_update_threshold`` times the dual residual norm, rho is multiplied
    # by ``rho_update_factor`` (and the scaled duals divided by it); in the
    # symmetric case rho is divided (and the duals multiplied).
    rho_update_factor: float = Field(
        2.0,
        description="Factor to update rho in primal-dual balancing.",
        gt=1.0,
    )
    rho_update_threshold: float = Field(
        10.0,
        description="Threshold for updating rho in primal-dual balancing.",
        gt=0.0,
    )
    rho_update_freq: int = Field(
        25,
        description="Frequency (in iterations) for updating rho.",
        ge=1,
    )
    # The relaxed point is alpha * A x + (1 - alpha) * (z + b); values of
    # alpha above 1 over-relax.
    alpha: float = Field(
        1.6,
        description="Over-relaxation parameter.",
        gt=0.0,
        lt=2.0,
    )
    # Passed as ``reg`` to the x-subproblem linear system and added as
    # sigma * x_k to its right-hand side.
    sigma: float = Field(
        1e-6,
        description="Regularization parameter for the inexact ADMM linear system.",
        ge=0.0,
    )
    # The PCG relative tolerance is divided by (iter + 1) ** gamma, so the
    # x-subproblem is solved more accurately as ADMM progresses.
    gamma: float = Field(
        1.2,
        description="Exponent for the linear system solve tolerance.",
        gt=1.0,
    )
    preconditioner_config: PreconditionerConfig = Field(
        NystromConfig(rank_init=50, base_damping=0.0),
        description="Configuration for the linear system preconditioner.",
    )
    # The preconditioner is rebuilt when iter % preconditioner_update_freq == 0
    # (or when none exists yet) and is reused otherwise.
    preconditioner_update_freq: int = Field(
        20,
        description="Frequency (in iterations) for updating the preconditioner.",
        ge=1,
    )
class ADMMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for the ADMM solver.

    After each step the solver stops once both residual norms are below

        eps_primal = sqrt(m) * eps_abs + eps_rel * max(||A x||, ||z||, ||b||)
        eps_dual   = sqrt(n) * eps_abs + eps_rel * ||rho * A^T u||

    where m and n are the row and column counts of A. Otherwise it stops after
    ``max_iters`` iterations; ``max_iters`` comes from the
    :class:`StoppingCriteria` base class.
    """

    eps_abs: float = Field(
        1e-4, description="Absolute tolerance for primal and dual residuals.", gt=0.0
    )
    eps_rel: float = Field(
        1e-4, description="Relative tolerance for primal and dual residuals.", gt=0.0
    )
@dataclass(frozen=True)
class ADMMState(SolverState):
    """State container for the ADMM solver.

    Instances are immutable (``frozen=True``). Each ADMM step constructs a
    fresh state instead of modifying the previous one.

    Attributes:
        iter_: Current iteration count (declared on the SolverState base).
        aux_variables: Auxiliary variables (z) in ADMM.
        dual_variables: Scaled dual variables (u) in ADMM.
        primal_residual_norm: Norm of the primal residual A x - z - b.
        dual_residual_norm: Norm of the dual residual grad f(x) + rho A^T u.
        rho: Current augmented Lagrangian penalty.
        _preconditioner: Preconditioner carried between steps for the
            x-subproblem PCG solve; ``None`` until the first step builds one.
    """

    aux_variables: TensorDict
    dual_variables: TensorDict
    primal_residual_norm: float
    dual_residual_norm: float
    # May change every rho_update_freq iterations via primal-dual balancing.
    rho: float
    # Defaulted so the initial state can be created without a preconditioner.
    _preconditioner: Preconditioner | None = None
@dataclass(frozen=True)
class ADMMResult(OptimResult):
    """Result container for the ADMM solver.

    Attributes:
        variable_values: Optimized variable values.
        convergence_status: Status indicating how the solver terminated.
        num_iters: Number of iterations performed.
        solver_time: Elapsed time taken by the solver in seconds.
        primal_residual_norm: Norm of the primal residual A x - z - b at the
            returned iterate.
        dual_residual_norm: Norm of the dual residual grad f(x) + rho A^T u
            at the returned iterate.
    """

    primal_residual_norm: float
    dual_residual_norm: float
class ADMM(OptimSolver):
    """Alternating Direction Method of Multipliers (ADMM) solver.

    Solves problems of the form:

        minimize f(x) + sum_i g_i(A_i x - b_i)

    where f is smooth (differentiable) and each g_i is proxable.
    """

    def __init__(self, obj: Expression, config: ADMMConfig, detach: bool = True):
        """Initialize the ADMM solver.

        Args:
            obj: The optimization problem to solve.
            config: Configuration for the ADMM solver.
            detach: Whether to detach params and state from computation graph
                between iterations. Set to False only if you need to differentiate
                through the entire solver. Default: True.

        Raises:
            ValueError: If ``obj`` is not an Expression or ``config`` is not an
                ADMMConfig.
        """
        if not isinstance(obj, Expression):
            raise ValueError("ADMM solver requires an Expression objective.")
        if not isinstance(config, ADMMConfig):
            raise ValueError("ADMM solver requires an ADMMConfig configuration.")
        super().__init__(obj, config, detach)
        op_split = ADMMSplit(obj)
        self._init_state = _build_init_state(op_split, config.rho)
        self._step = _build_step(
            op_split,
            config.rho_update_factor,
            config.rho_update_threshold,
            config.rho_update_freq,
            config.alpha,
            config.sigma,
            config.gamma,
            config.preconditioner_config,
            config.preconditioner_update_freq,
            detach,
        )
        # The stopping tolerances are only known at solve() time, so the solve
        # loop is built lazily per call.
        self._solve = lambda eps_abs, eps_rel, max_iters: _build_solve(
            op_split, self._init_state, self._step, eps_abs, eps_rel, max_iters
        )

    def init_state(self, variable_values: TensorDict) -> ADMMState:
        """Initialize the solver state.

        Args:
            variable_values: Initial variable values.

        Returns:
            Initial solver state.
        """
        return self._init_state(variable_values)

    def step(
        self, variables_values: TensorDict, state: ADMMState
    ) -> tuple[TensorDict, ADMMState]:
        """Perform a single ADMM optimization step.

        Args:
            variables_values: Current variable values.
            state: Current ADMM solver state.

        Returns:
            Tuple of updated variable values and solver state.
        """
        return self._step(variables_values, state)

    def solve(
        self,
        variable_values: TensorDict | None = None,
        stopping_criteria: ADMMStoppingCriteria | None = None,
    ) -> ADMMResult:
        """Solve the optimization problem using ADMM.

        Args:
            variable_values: Initial variable values. If None, the split's
                ``f_and_affine_variable_values`` are used.
            stopping_criteria: Stopping criteria for the solver. If None, a
                default ``ADMMStoppingCriteria()`` is used.

        Returns:
            ADMMResult: Result of the optimization containing optimized variable values
            among other metrics.
        """
        # Build the default per call instead of evaluating it once at
        # function-definition time.
        if stopping_criteria is None:
            stopping_criteria = ADMMStoppingCriteria()
        solve_fn = self._solve(
            eps_abs=stopping_criteria.eps_abs,
            eps_rel=stopping_criteria.eps_rel,
            max_iters=stopping_criteria.max_iters,
        )
        return solve_fn(variable_values)
def _build_init_state(
    op_split: ADMMSplit, rho: float
) -> Callable[[TensorDict], ADMMState]:
    """Create the closure that builds the initial ADMM state.

    Args:
        op_split: ADMMSplit instance for the optimization problem.
        rho: Initial augmented Lagrangian penalty.

    Returns:
        A function mapping initial variable values to an iteration-0
        ADMMState whose residual norms are evaluated at those values.
    """

    def init_state(variable_values: TensorDict) -> ADMMState:
        # z starts from the split's r-variable values; u from the split's own
        # dual initializer.
        z_init = op_split.r_variable_values
        u_init = op_split.init_dual_variables()
        norms = _compute_primal_and_dual_residual_norms(
            op_split, variable_values, z_init, u_init, rho
        )
        r_primal, r_dual = norms
        # No preconditioner yet: the first step builds one.
        return ADMMState(
            iter_=0,
            aux_variables=z_init,
            dual_variables=u_init,
            primal_residual_norm=r_primal,
            dual_residual_norm=r_dual,
            rho=rho,
            _preconditioner=None,
        )

    return init_state
def _build_step(
    op_split: ADMMSplit,
    rho_update_factor: float,
    rho_update_threshold: float,
    rho_update_freq: int,
    alpha: float,
    sigma: float,
    gamma: float,
    preconditioner_config: PreconditionerConfig,
    preconditioner_update_freq: int,
    detach: bool,
) -> Callable[[TensorDict, ADMMState], tuple[TensorDict, ADMMState]]:
    """Create the closure that advances ADMM by one iteration.

    One iteration runs, in order:
      1. an inexact x-update (PCG on the linearized x-subproblem);
      2. over-relaxation of A x_{k+1};
      3. the z-update via the proximal operator;
      4. the scaled dual (u) update;
      5. residual norms evaluated with the rho used in this iteration;
      6. primal-dual balancing of rho (which may rescale u);
      7. optional detaching from the autograd graph.

    Args:
        op_split: ADMMSplit instance for the optimization problem.
        rho_update_factor: Multiplicative factor for rho balancing.
        rho_update_threshold: Residual ratio that triggers a rho update.
        rho_update_freq: Iteration period of rho updates.
        alpha: Over-relaxation parameter.
        sigma: Regularization for the x-subproblem linear system.
        gamma: Exponent controlling how fast the PCG tolerance shrinks.
        preconditioner_config: Configuration for the PCG preconditioner.
        preconditioner_update_freq: Iteration period of preconditioner rebuilds.
        detach: Whether to detach the returned tensors from the graph.

    Returns:
        A function ``(variable_values, state) -> (new_variable_values, new_state)``.
    """

    def step(
        variable_values: TensorDict, state: ADMMState
    ) -> tuple[TensorDict, ADMMState]:
        """Perform a single ADMM optimization step."""
        z_flat = state.aux_variables.to_flat_tensor()
        u_flat = state.dual_variables.to_flat_tensor()
        current_rho = state.rho

        # 1. x-update; the preconditioner is threaded through the state.
        x_next, precond = _solve_x_subproblem(
            op_split=op_split,
            variable_values=variable_values,
            aux_variables_flat=z_flat,
            dual_variables_flat=u_flat,
            rho=current_rho,
            sigma=sigma,
            gamma=gamma,
            iter_=state.iter_,
            primal_residual_norm=state.primal_residual_norm,
            dual_residual_norm=state.dual_residual_norm,
            preconditioner=state._preconditioner,
            preconditioner_config=preconditioner_config,
            preconditioner_update_freq=preconditioner_update_freq,
            detach=detach,
        )

        # 2. Over-relaxed point alpha * A x + (1 - alpha) * (z + b).
        relaxed_flat = _compute_overrelaxation(
            op_split=op_split,
            alpha=alpha,
            new_variable_values_flat=x_next.to_flat_tensor(),
            aux_variables_flat=z_flat,
        )

        # 3. z-update.
        z_next = _update_aux_variables(
            op_split=op_split,
            variable_values_overrelaxed_flat=relaxed_flat,
            dual_variables_flat=u_flat,
            aux_variables=state.aux_variables,
            rho=current_rho,
        )

        # 4. u-update.
        u_next = _update_dual_variables(
            op_split=op_split,
            dual_variables=state.dual_variables,
            variable_values_overrelaxed_flat=relaxed_flat,
            new_aux_variables=z_next,
        )

        # 5. Residual norms with the rho that produced this iterate.
        r_primal, r_dual = _compute_primal_and_dual_residual_norms(
            op_split=op_split,
            variable_values=x_next,
            aux_variables=z_next,
            dual_variables=u_next,
            rho=current_rho,
        )

        # 6. Primal-dual balancing; rescales u whenever rho changes.
        next_rho, u_next = _update_rho(
            rho=current_rho,
            dual_variables=u_next,
            primal_residual_norm=r_primal,
            dual_residual_norm=r_dual,
            iter_=state.iter_,
            rho_update_factor=rho_update_factor,
            rho_update_threshold=rho_update_threshold,
            rho_update_freq=rho_update_freq,
        )

        # 7. Cut the autograd history so it does not grow across iterations.
        if detach:
            x_next = x_next.detach()
            z_next = z_next.detach()
            u_next = u_next.detach()

        next_state = ADMMState(
            iter_=state.iter_ + 1,
            aux_variables=z_next,
            dual_variables=u_next,
            primal_residual_norm=r_primal,
            dual_residual_norm=r_dual,
            rho=next_rho,
            _preconditioner=precond,
        )
        return x_next, next_state

    return step
def _build_solve(
    op_split: ADMMSplit,
    init_state_fn: Callable[[TensorDict], ADMMState],
    step_fn: Callable[[TensorDict, ADMMState], tuple[TensorDict, ADMMState]],
    eps_abs: float,
    eps_rel: float,
    max_iters: int,
) -> Callable[[TensorDict | None], ADMMResult]:
    """Build the solve function with ADMM stopping criteria.

    Implements the Boyd et al. (Section 3.3.1) stopping criteria, checked after
    every step:

    - Primal feasibility: ||r_primal|| <= eps_primal
    - Dual feasibility: ||r_dual|| <= eps_dual

    where, with m and n the row and column counts of A:

    - eps_primal = sqrt(m) * eps_abs + eps_rel * max(||Ax||, ||z||, ||b||)
    - eps_dual = sqrt(n) * eps_abs + eps_rel * ||rho * A^T * u||

    Args:
        op_split: ADMMSplit providing ``A``, ``A_T`` and ``b``.
        init_state_fn: Builds the initial ADMMState from variable values.
        step_fn: Performs one ADMM iteration.
        eps_abs: Absolute tolerance.
        eps_rel: Relative tolerance.
        max_iters: Maximum number of ADMM iterations.

    Returns:
        A function mapping optional initial variable values to an ADMMResult.
    """

    def solve(var_vals: TensorDict | None = None) -> ADMMResult:
        """Solve the optimization problem using ADMM."""
        # perf_counter is monotonic, so the reported time is not affected by
        # system clock adjustments (unlike time.time()).
        ts = time.perf_counter()
        if var_vals is None:
            var_vals = op_split.f_and_affine_variable_values
        state = init_state_fn(var_vals)

        m = op_split.A.shape[0]  # Constraint dimension
        n = op_split.A.shape[1]  # Variable dimension
        # Loop-invariant pieces of the thresholds, computed once rather than on
        # every iteration.
        eps_primal_abs = (m**0.5) * eps_abs
        eps_dual_abs = (n**0.5) * eps_abs
        b_norm = torch.linalg.norm(op_split.b).item()

        convergence_status = ConvergenceStatus.NOT_CONVERGED
        while state.iter_ < max_iters:
            var_vals, state = step_fn(var_vals, state)

            # Primal threshold
            Ax_norm = torch.linalg.norm(op_split.A @ var_vals.to_flat_tensor()).item()
            z_norm = state.aux_variables.flat_norm().item()
            eps_primal = eps_primal_abs + eps_rel * max(Ax_norm, z_norm, b_norm)

            # Dual threshold. `rho * A_T @ u` would parse as (rho * A_T) @ u and
            # build a scaled copy of A^T; scaling the matvec result avoids that.
            ATu_norm = torch.linalg.norm(
                state.rho * (op_split.A_T @ state.dual_variables.to_flat_tensor())
            ).item()
            eps_dual = eps_dual_abs + eps_rel * ATu_norm

            if (
                state.primal_residual_norm <= eps_primal
                and state.dual_residual_norm <= eps_dual
            ):
                convergence_status = ConvergenceStatus.CONVERGED
                break

        total_time = time.perf_counter() - ts
        return ADMMResult(
            variable_values=var_vals,
            convergence_status=convergence_status,
            num_iters=state.iter_,
            solver_time=total_time,
            primal_residual_norm=state.primal_residual_norm,
            dual_residual_norm=state.dual_residual_norm,
        )

    return solve
def _solve_x_subproblem(
    op_split: ADMMSplit,
    variable_values: TensorDict,
    aux_variables_flat: torch.Tensor,
    dual_variables_flat: torch.Tensor,
    rho: float,
    sigma: float,
    gamma: float,
    iter_: int,
    primal_residual_norm: float,
    dual_residual_norm: float,
    preconditioner: Preconditioner | None,
    preconditioner_config: PreconditionerConfig,
    preconditioner_update_freq: int,
    detach: bool,
) -> tuple[TensorDict, Preconditioner]:
    """Solve the x-subproblem approximately using PCG.

    With x_k the current iterate, the right-hand side is

        rhs = -grad f(x_k) + hvp_f(x_k, x_k) + sigma * x_k + rho * A^T (z - u + b)

    and the system ``LinSys(hvp_f_ATA_linop(x_k, rho), B=rhs, reg=sigma, w=x_k)``
    is solved with PCG. NOTE(review): assuming ``hvp_f_ATA_linop`` applies
    ``H + rho A^T A`` (H the Hessian of f at x_k) and ``reg`` adds ``sigma I``,
    this is the ADMM x-update with f replaced by its second-order model plus a
    proximal term (sigma / 2)||x - x_k||^2. Confirm against rlaopt.linalg.

    The preconditioner is rebuilt every ``preconditioner_update_freq``
    iterations, or when none is supplied, and is reused otherwise. The PCG
    relative tolerance is ``min(sqrt(||r_p|| * ||r_d||), 1)``, replaced by
    ``PCG_TOL_EPS`` when it is zero, then divided by ``(iter_ + 1) ** gamma``.
    If PCG does not converge, a warning is emitted and its last iterate is
    still returned.

    Returns:
        Tuple of updated variable values and preconditioner.
    """
    x_k = variable_values.to_flat_tensor()
    rhs = -op_split.grad_f(variable_values).to_flat_tensor()
    rhs += op_split.hvp_f(variable_values, x_k)
    rhs += sigma * x_k
    # Scale the matvec result, not the operator: (rho * A_T) @ v would build a
    # scaled copy of A^T on every call.
    rhs += rho * (op_split.A_T @ (aux_variables_flat - dual_variables_flat + op_split.b))
    lin_sys = LinSys(
        op_split.hvp_f_ATA_linop(variable_values, rho),
        B=rhs,
        reg=sigma,
        w=x_k,
    )

    # Compute preconditioner if needed
    if iter_ % preconditioner_update_freq == 0 or preconditioner is None:
        preconditioner = get_preconditioner(
            preconditioner_config,
            lin_sys.A,
            lin_sys.dtype,
        )

    # Solve the linear system using PCG
    pcg_solver = _PCG(lin_sys, preconditioner=preconditioner, detach=detach)
    rel_tol = min((primal_residual_norm * dual_residual_norm) ** 0.5, 1.0)
    # NOTE(review): with its default tolerances (rel_tol=1e-9, abs_tol=0.0),
    # math.isclose(x, 0.0) is true only when x == 0.0 exactly. Tiny non-zero
    # values are therefore not replaced by PCG_TOL_EPS.
    if math.isclose(rel_tol, 0.0):
        rel_tol = PCG_TOL_EPS
    # Tighten the inner tolerance as ADMM progresses.
    rel_tol /= (iter_ + 1) ** gamma
    stopping_criteria = PCGStoppingCriteria(
        max_iters=lin_sys.A.shape[0],
        tol=rel_tol,
    )
    pcg_solve_result = pcg_solver.solve(stopping_criteria=stopping_criteria)
    if pcg_solve_result.convergence_status != ConvergenceStatus.CONVERGED:
        warn(
            f"PCG for ADMM did not converge in iteration {iter_}. "
            f"Status: {pcg_solve_result.convergence_status}. "
            "Consider changing the preconditioner."
        )

    # Need to squeeze last dimension since PCG returns 2D tensors
    new_variable_values_flat = pcg_solve_result.solution.squeeze(-1)
    new_variable_values = variable_values.from_flat_tensor(new_variable_values_flat)
    return new_variable_values, preconditioner
def _compute_overrelaxation(
    op_split: ADMMSplit,
    alpha: float,
    new_variable_values_flat: torch.Tensor,
    aux_variables_flat: torch.Tensor,
) -> torch.Tensor:
    """Apply over-relaxation to the variable values.

    Computes ``alpha * A x_{k+1} + (1 - alpha) * (z_k + b)``, the relaxed
    replacement for ``A x_{k+1}`` used in the z- and u-updates.

    Args:
        op_split: ADMMSplit providing ``A`` and ``b``.
        alpha: Over-relaxation parameter.
        new_variable_values_flat: Flattened new iterate x_{k+1}.
        aux_variables_flat: Flattened current auxiliary variables z_k.

    Returns:
        Over-relaxed variable values.
    """
    # `alpha * A @ x` would parse as (alpha * A) @ x and build a scaled copy of
    # A; scaling the matvec result avoids that allocation.
    return alpha * (op_split.A @ new_variable_values_flat) + (1 - alpha) * (
        aux_variables_flat + op_split.b
    )
def _update_aux_variables(
    op_split: ADMMSplit,
    variable_values_overrelaxed_flat: torch.Tensor,
    dual_variables_flat: torch.Tensor,
    aux_variables: TensorDict,
    rho: float,
) -> TensorDict:
    """Perform the ADMM z-update through the split's proximal operator.

    Evaluates ``prox(x_hat + u - b, 1 / rho)``, where x_hat is the
    over-relaxed point and u the scaled duals.

    Args:
        op_split: ADMMSplit providing ``b`` and ``prox``.
        variable_values_overrelaxed_flat: Flattened over-relaxed point x_hat.
        dual_variables_flat: Flattened scaled dual variables u.
        aux_variables: Current auxiliary variables, used only as the layout
            template for un-flattening.
        rho: Current augmented Lagrangian penalty.

    Returns:
        Updated auxiliary variables.
    """
    prox_point = aux_variables.from_flat_tensor(
        variable_values_overrelaxed_flat + dual_variables_flat - op_split.b
    )
    prox_step = 1.0 / rho
    return op_split.prox(prox_point, prox_step)
def _update_dual_variables(
    op_split: ADMMSplit,
    dual_variables: TensorDict,
    variable_values_overrelaxed_flat: torch.Tensor,
    new_aux_variables: TensorDict,
) -> TensorDict:
    """Perform the scaled dual update ``u <- u + x_hat - z_{k+1} - b``.

    Args:
        op_split: ADMMSplit providing ``b``.
        dual_variables: Current scaled dual variables u (also the layout
            template for the result).
        variable_values_overrelaxed_flat: Flattened over-relaxed point x_hat.
        new_aux_variables: Freshly updated auxiliary variables z_{k+1}.

    Returns:
        Updated dual variables.
    """
    u_flat = dual_variables.to_flat_tensor()
    z_next_flat = new_aux_variables.to_flat_tensor()
    u_next_flat = u_flat + variable_values_overrelaxed_flat - z_next_flat - op_split.b
    return dual_variables.from_flat_tensor(u_next_flat)
def _update_rho(
    rho: float,
    dual_variables: TensorDict,
    primal_residual_norm: float,
    dual_residual_norm: float,
    iter_: int,
    rho_update_factor: float,
    rho_update_threshold: float,
    rho_update_freq: int,
) -> tuple[float, TensorDict]:
    """Rebalance the penalty ``rho`` from the primal/dual residual ratio.

    Runs only when ``iter_ + 1`` is a multiple of ``rho_update_freq``. If the
    primal residual dominates by more than ``rho_update_threshold``, rho is
    multiplied by ``rho_update_factor``; if the dual residual dominates, rho is
    divided by it. The scaled duals are rescaled inversely so that
    ``rho * u`` is unchanged.

    Returns:
        Tuple of updated rho and dual variables.
    """
    if (iter_ + 1) % rho_update_freq != 0:
        return rho, dual_variables
    if primal_residual_norm > rho_update_threshold * dual_residual_norm:
        return rho * rho_update_factor, dual_variables / rho_update_factor
    if dual_residual_norm > rho_update_threshold * primal_residual_norm:
        return rho / rho_update_factor, dual_variables * rho_update_factor
    return rho, dual_variables
def _compute_primal_and_dual_residual_norms(
    op_split: ADMMSplit,
    variable_values: TensorDict,
    aux_variables: TensorDict,
    dual_variables: TensorDict,
    rho: float,
) -> tuple[float, float]:
    """Compute the primal and dual residual norms for ADMM.

    primal residual: r_p = A x - z - b
    dual residual:   r_d = grad f(x) + rho * A^T u

    Args:
        op_split: ADMMSplit providing ``A``, ``A_T``, ``b`` and ``grad_f``.
        variable_values: Current variables x.
        aux_variables: Current auxiliary variables z.
        dual_variables: Current scaled dual variables u.
        rho: Augmented Lagrangian penalty used to unscale u.

    Returns:
        ``(||r_p||, ||r_d||)`` as Python floats, computed with
        ``torch.linalg.norm`` on the flattened residuals.
    """
    primal_residual = (
        op_split.A @ variable_values.to_flat_tensor()
        - aux_variables.to_flat_tensor()
        - op_split.b
    )
    # `rho * A_T @ u` would parse as (rho * A_T) @ u and build a scaled copy of
    # A^T on every call; scaling the matvec result is equivalent and cheaper.
    dual_residual = op_split.grad_f(variable_values).to_flat_tensor() + rho * (
        op_split.A_T @ dual_variables.to_flat_tensor()
    )
    primal_residual_norm = torch.linalg.norm(primal_residual).item()
    dual_residual_norm = torch.linalg.norm(dual_residual).item()
    return primal_residual_norm, dual_residual_norm