"""Preconditioned Conjugate Gradient solver implementation."""
import time
from dataclasses import dataclass, replace
from typing import Callable
import torch
from pydantic import Field
from rlaopt.linalg import (
IdentityConfig,
LinSys,
Preconditioner,
PreconditionerConfig,
get_preconditioner,
)
from rlaopt.solvers.configs_base import SolverConfig, StoppingCriteria
from rlaopt.solvers.solver_base import (
ConvergenceStatus,
LinSysResult,
LinSysSolver,
SolverState,
)
class PCGConfig(SolverConfig):
    """Configuration for the Preconditioned Conjugate Gradient solver.

    Attributes:
        preconditioner_config: Configuration describing which preconditioner
            to build for the solver. Defaults to ``IdentityConfig()``.
    """

    preconditioner_config: PreconditionerConfig = IdentityConfig()
class PCGStoppingCriteria(StoppingCriteria):
    """Stopping criteria specific to the PCG solver.

    Attributes:
        max_iters: Maximum number of iterations (not declared here; expected
            to come from the ``StoppingCriteria`` base).
        tol: Relative tolerance for convergence. Must be strictly positive;
            defaults to ``1e-6``.
    """

    tol: float = Field(default=1e-6, gt=0)
@dataclass(frozen=True)
class PCGState(SolverState):
    """Immutable snapshot of the PCG iteration.

    Attributes:
        iter_: Iteration counter (not declared here; expected to come from
            the ``SolverState`` base).
        r: Residual block, ``B - A @ W``.
        z: Preconditioned residual block, ``P_inv @ r``.
        p: Block of search directions.
        rz: Matrix of inner products ``r.T @ z``; only the entries belonging
            to non-converged components are meaningful.
        res_norm: 2-norm of the residual, one value per component.
        mask: Boolean vector with True for every component that has not yet
            converged. Without a tolerance, every component stays active.
        tol: Tolerance used for the convergence test, or None to disable the
            test (the mask is then never updated).
    """

    r: torch.Tensor
    z: torch.Tensor
    p: torch.Tensor
    rz: torch.Tensor
    res_norm: torch.Tensor
    mask: torch.Tensor
    tol: float | None = None
@dataclass(frozen=True)
class PCGResult(LinSysResult):
    """Result container for the PCG solver.

    Attributes:
        solution: Solution parameters returned by the solver.
        convergence_status: Status indicating how the solver terminated.
        num_iters: Number of iterations performed.
        solver_time: Time taken by the solver in seconds.
        residual_norm: Final residual norm per component.

    Only ``residual_norm`` is declared here; the remaining attributes are
    expected to come from the ``LinSysResult`` base.
    """

    residual_norm: torch.Tensor
class _PCG:
    """Internal PCG implementation that operates on a given preconditioner.

    This is the core algorithm implementation. It doesn't inherit from
    LinSysSolver since it's an internal building block with a different
    interface. Use the public PCG class for the standard API.
    """

    def __init__(
        self,
        lin_sys: LinSys,
        preconditioner: Preconditioner,
        detach: bool = True,
    ):
        """Initialize PCG solver with a preconditioner.

        Args:
            lin_sys: The linear system to solve.
            preconditioner: Preconditioner instance to use.
            detach: Whether to detach params and state from computation graph
                between iterations.
        """
        self._init_state = _build_init_state(lin_sys, preconditioner)
        self._step = _build_step(lin_sys, preconditioner, detach)
        # The solve closure depends on the stopping criteria, so it is built
        # lazily, once per call to solve().
        self._solve = lambda tol, max_iters: _build_solve(
            lin_sys, self._init_state, self._step, tol, max_iters
        )

    def init_state(self, params: torch.Tensor) -> PCGState:
        """Initialize the solver state.

        Args:
            params: Initial parameters (solution estimate).

        Returns:
            Initial solver state.
        """
        return self._init_state(params)

    def step(
        self, params: torch.Tensor, state: PCGState
    ) -> tuple[torch.Tensor, PCGState]:
        """Perform a single PCG iteration step.

        Args:
            params: Current parameters (solution estimate).
            state: Current solver state.

        Returns:
            Tuple of updated parameters and solver state.
        """
        return self._step(params, state)

    def solve(
        self,
        params: torch.Tensor | None = None,
        stopping_criteria: PCGStoppingCriteria | None = None,
    ) -> PCGResult:
        """Solve the linear system AW = B using PCG.

        Args:
            params: Initial parameters (solution estimate). If None, defaults to
                parameters in linear system.
            stopping_criteria: Criteria to determine when to stop the solver.
                If None, a fresh ``PCGStoppingCriteria()`` is used.

        Returns:
            PCGResult containing the solution and convergence information.
        """
        # Build the default per call instead of in the signature, so no single
        # criteria instance is shared between calls.
        if stopping_criteria is None:
            stopping_criteria = PCGStoppingCriteria()
        solve_fn = self._solve(
            tol=stopping_criteria.tol, max_iters=stopping_criteria.max_iters
        )
        return solve_fn(params)
class PCG(LinSysSolver):
    """Block Preconditioned Conjugate Gradient solver for linear systems.

    Solves linear systems of the form:

        AW = B

    where A is a symmetric positive-definite matrix.

    The PCG method uses a preconditioner to improve convergence. The algorithm
    iteratively refines the solution by moving along conjugate search directions
    that are scaled by the preconditioner.
    """

    def __init__(self, lin_sys: LinSys, config: PCGConfig, detach: bool = True):
        """Initialize the PCG solver.

        Args:
            lin_sys: The linear system to solve.
            config: Configuration for the solver.
            detach: Whether to detach params and state from computation graph
                between iterations. Set to False only if you need to differentiate
                through the entire solver. Default: True.
        """
        # Call parent constructor first
        super().__init__(lin_sys, config, detach)

        # Create preconditioner from config
        preconditioner = get_preconditioner(
            config.preconditioner_config,
            lin_sys.A,
            lin_sys.dtype,
        )

        # Delegate to internal implementation (composition)
        self._impl = _PCG(lin_sys, preconditioner, detach)

    def init_state(self, params: torch.Tensor) -> PCGState:
        """Initialize the solver state.

        Args:
            params: Initial parameters (solution estimate).

        Returns:
            Initial solver state.
        """
        return self._impl.init_state(params)

    def step(
        self, params: torch.Tensor, state: PCGState
    ) -> tuple[torch.Tensor, PCGState]:
        """Perform a single PCG iteration step.

        Args:
            params: Current parameters (solution estimate).
            state: Current solver state.

        Returns:
            Tuple of updated parameters and solver state.
        """
        return self._impl.step(params, state)

    def solve(
        self,
        params: torch.Tensor | None = None,
        stopping_criteria: PCGStoppingCriteria | None = None,
    ) -> PCGResult:
        """Solve the linear system AW = B using PCG.

        Args:
            params: Initial parameters (solution estimate). If None, defaults to
                parameters in linear system.
            stopping_criteria: Criteria to determine when to stop the solver.
                If None, a fresh ``PCGStoppingCriteria()`` is used.

        Returns:
            PCGResult containing the solution and convergence information.
        """
        # Resolve the default here, per call, so the internal solver always
        # receives a concrete criteria object and no instance is shared
        # between calls.
        if stopping_criteria is None:
            stopping_criteria = PCGStoppingCriteria()
        return self._impl.solve(params, stopping_criteria)
def _compute_convergence_mask(
    res_norm: torch.Tensor, lin_sys: LinSys, tol: float
) -> torch.Tensor:
    """Flag the components whose residual norm still exceeds the threshold.

    Args:
        res_norm: Residual norm per component.
        lin_sys: Linear system; only its ``rhs_norm`` attribute is read.
        tol: Relative tolerance that scales ``lin_sys.rhs_norm``.

    Returns:
        Boolean tensor that is True where ``res_norm > tol * rhs_norm``,
        i.e. for the components that have not converged.
    """
    return res_norm > tol * lin_sys.rhs_norm
def _build_init_state(
    lin_sys: LinSys, P: Preconditioner
) -> Callable[[torch.Tensor], PCGState]:
    """Create the closure that produces the initial PCG state.

    Args:
        lin_sys: Linear system; its ``compute_residual`` method is used.
        P: Preconditioner; ``P.inv @ r`` applies it to the residual.

    Returns:
        Function mapping an initial solution estimate to a ``PCGState``.
    """

    def init_state(params: torch.Tensor) -> PCGState:
        """Initialize the PCG solver state.

        Args:
            params: Initial parameters (solution estimate).

        Returns:
            State at iteration 0 in which every component is active
            (mask is all True) and no tolerance is set.
        """
        # Residual of the starting guess and its preconditioned counterpart.
        residual = lin_sys.compute_residual(params)
        precond_residual = P.inv @ residual

        # One mask entry per column of the residual.
        n_components = residual.shape[1]

        return PCGState(
            iter_=0,
            r=residual,
            z=precond_residual,
            # The first search direction is the preconditioned residual.
            p=precond_residual.clone(),
            # Full r^T z matrix: entry [i, j] couples components i and j.
            rz=residual.T @ precond_residual,
            res_norm=torch.linalg.norm(residual, dim=0, ord=2),
            # Everything starts active; step() narrows the mask once a
            # tolerance has been attached to the state.
            mask=torch.ones(
                n_components, dtype=torch.bool, device=residual.device
            ),
        )

    return init_state
def _build_step(
    lin_sys: LinSys,
    P: Preconditioner,
    detach: bool,
) -> Callable[[torch.Tensor, PCGState], tuple[torch.Tensor, PCGState]]:
    """Create the closure that performs one block-PCG iteration.

    Args:
        lin_sys: Linear system; ``lin_sys(X)`` applies the operator to a
            block of directions, and it is also passed to the convergence
            check.
        P: Preconditioner; ``P.inv @ R`` applies it to a residual block.
        detach: If True, every tensor produced by a step is detached before
            it is returned.

    Returns:
        Function mapping ``(params, state)`` to the updated pair.
    """

    def step(
        params: torch.Tensor,
        state: PCGState,
    ) -> tuple[torch.Tensor, PCGState]:
        """Perform a single PCG iteration step.

        Args:
            params: Current parameters (solution estimate).
            state: Current solver state.

        Returns:
            Tuple of updated parameters and solver state. When no component
            is active, the inputs are returned unchanged.
        """
        active = state.mask

        # Nothing left to update: hand the inputs straight back.
        if not active.any():
            return params, state

        # Restrict the directions and the r^T z matrix to active components.
        p_act = state.p[:, active]
        rz_act = state.rz[active][:, active]

        # Operator applied to the active directions only.
        Ap_act = lin_sys(p_act)

        # Block step size: solve (p^T A p) alpha = r^T z.
        alpha = torch.linalg.solve(p_act.T @ Ap_act, rz_act)

        # Move the active columns of the solution along the directions.
        next_params = params.clone()
        next_params[:, active] += p_act @ alpha

        # Matching update of the active residual columns.
        next_r = state.r.clone()
        next_r[:, active] -= Ap_act @ alpha

        # Precondition the refreshed residual columns.
        z_act = P.inv @ next_r[:, active]
        next_z = state.z.clone()
        next_z[:, active] = z_act

        # New r^T z on the active set and the block conjugation coefficient.
        rz_act_next = next_r[:, active].T @ z_act
        beta = torch.linalg.solve(rz_act, rz_act_next)

        # New search directions for the active columns.
        next_p = state.p.clone()
        next_p[:, active] = z_act + p_act @ beta

        # Scatter the active r^T z block into a full matrix; entries that
        # involve an inactive component are left at zero.
        next_rz = torch.zeros_like(state.rz)
        next_rz[torch.outer(active, active)] = rz_act_next.flatten()

        next_res_norm = torch.linalg.norm(next_r, dim=0, ord=2)

        # Without a tolerance the active set never changes.
        next_mask = (
            active
            if state.tol is None
            else _compute_convergence_mask(next_res_norm, lin_sys, state.tol)
        )

        # Cut the autograd history so the graph does not grow per iteration.
        if detach:
            next_params, next_r, next_z, next_p, next_rz, next_res_norm = (
                t.detach()
                for t in (
                    next_params,
                    next_r,
                    next_z,
                    next_p,
                    next_rz,
                    next_res_norm,
                )
            )

        return next_params, PCGState(
            iter_=state.iter_ + 1,
            r=next_r,
            z=next_z,
            p=next_p,
            rz=next_rz,
            res_norm=next_res_norm,
            mask=next_mask,
            tol=state.tol,
        )

    return step
def _build_solve(
    lin_sys: LinSys,
    init_state_fn: Callable[[torch.Tensor], PCGState],
    step_fn: Callable[[torch.Tensor, PCGState], tuple[torch.Tensor, PCGState]],
    tol: float,
    max_iters: int,
) -> Callable[[torch.Tensor | None], PCGResult]:
    """Build the solve function with stopping criteria.

    Args:
        lin_sys: Linear system; supplies the default initial parameters
            (``lin_sys.w``) and is used by the convergence check.
        init_state_fn: Function creating the initial state from parameters.
        step_fn: Function performing one iteration on ``(params, state)``.
        tol: Relative tolerance used for the convergence check.
        max_iters: Upper bound on the number of iterations.

    Returns:
        Function that runs the solver from an optional initial guess and
        returns a ``PCGResult``.
    """

    def solve(params: torch.Tensor | None = None) -> PCGResult:
        """Solve the linear system AW = B.

        Args:
            params: Initial parameters. If None, a clone of ``lin_sys.w`` is
                used.

        Returns:
            PCGResult containing the solution and convergence information.
        """
        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # wall-clock adjustments.
        start = time.perf_counter()

        if params is None:
            params = lin_sys.w.clone()

        state = init_state_fn(params)

        # Attach the tolerance and evaluate the convergence test on the
        # initial residual. A component whose initial residual is exactly
        # zero has a zero search direction, which would make the p^T A p
        # system of the first step singular; such components must be
        # inactive before the first iteration.
        state = replace(
            state,
            tol=tol,
            mask=_compute_convergence_mask(state.res_norm, lin_sys, tol),
        )

        # Iterate until every component has converged or the budget is spent.
        while state.mask.any() and state.iter_ < max_iters:
            params, state = step_fn(params, state)

        if state.mask.any():
            convergence_status = ConvergenceStatus.NOT_CONVERGED
        else:
            convergence_status = ConvergenceStatus.CONVERGED

        return PCGResult(
            solution=params,
            convergence_status=convergence_status,
            num_iters=state.iter_,
            solver_time=time.perf_counter() - start,
            residual_norm=state.res_norm,
        )

    return solve