Source code for rlaopt.atoms.sum_squares
"""Implementation of the sum squared atom."""
import torch
from rlaopt.atoms.atom import Atom
from rlaopt.expression import Expression, Variable
from rlaopt.ext_tensordict import TensorDict
[docs]
class SumSquares(Atom):
    """Sum of squared elements atom.

    Represents ``f(x) = sum_i x_i**2``, i.e. the sum of the squares of every
    element of the wrapped expression's value.
    """

    def __init__(self, x: Expression):
        """Initializes the sum-of-squares atom.

        Args:
            x: Expression to apply the sum of squares to.
        """
        super().__init__(exprs={"x": x}, buffers={})

    def is_smooth(self) -> bool:
        """Returns whether this atom is smooth.

        Squaring and summing are smooth operations, so smoothness is
        delegated entirely to the input expression's own ``is_smooth``.
        """
        return self.get_input("x").is_smooth()

    def forward(self) -> torch.Tensor:
        """Computes the sum of squares of the input expression's value.

        Returns:
            ``torch.sum(value**2)`` reduced over all elements of the input's
            forward value (a 0-dim tensor, since no ``dim`` is given).
        """
        value = self.get_input("x").forward()
        return torch.sum(value**2)

    def is_proxable(self) -> bool:
        """Returns True if the input is a Variable.

        The closed-form prox in ``_prox`` is only valid for ``||v||^2``
        applied directly to a variable, not to a composite expression.
        """
        return isinstance(self.get_input("x"), Variable)

    def _prox(
        self, relevant_variable_values: TensorDict, prox_scaling: float
    ) -> TensorDict:
        """Proximal operator for the sum of squares.

        For ``f(v) = ||v||^2`` and step ``t`` the prox is
        ``argmin_u t * ||u||^2 + 0.5 * ||u - v||^2``. Setting the gradient
        ``2 * t * u + (u - v)`` to zero gives ``u = v / (1 + 2 * t)``.

        Args:
            relevant_variable_values: Variable values to apply the prox to;
                the same shrinkage is applied to every tensor it holds.
            prox_scaling: Prox step size ``t``.

        Returns:
            The result of ``relevant_variable_values.apply`` with every
            tensor scaled by ``1 / (1 + 2 * prox_scaling)``.
        """
        # Loop-invariant shrink factor; computed once instead of per tensor.
        shrink = 1 / (1 + 2 * prox_scaling)
        return relevant_variable_values.apply(lambda value: shrink * value)