Source code for rlaopt.expression.constant
"""Module for Constant class."""
import torch
from rlaopt.expression.expression import Expression
from rlaopt.expression.tree import ExprTree
[docs]
class Constant(Expression):
    """Constant value expression.

    Represents a constant (non-trainable) value in an expression tree.
    Constants are stored as buffers rather than parameters, so they
    don't receive gradients and don't appear in parameter optimization.

    Args:
        value: The constant value (float, int, or torch.Tensor).

    Attributes:
        _value: The constant value stored as a buffer.

    Examples:
        >>> c = Constant(3.14)
        >>> c.forward()
        tensor(3.1400)
        >>> c2 = Constant(torch.ones(5))
        >>> c2.forward().shape
        torch.Size([5])
    """

    def __init__(self, value: float | int | torch.Tensor):
        """Initialize a constant expression.

        Args:
            value: The constant value to store. Non-tensor values are
                converted with ``torch.tensor``. Tensor values are
                detached from any autograd graph; their storage is shared
                with the caller's tensor, not copied.
        """
        super().__init__()
        if isinstance(value, torch.Tensor):
            # A tensor that requires grad (e.g. an nn.Parameter or the result
            # of a differentiable op) would otherwise let gradients flow
            # through this "constant". detach() cuts it out of the autograd
            # graph without copying its storage.
            value = value.detach()
        else:
            value = torch.tensor(value)
        # Register as a buffer so it is visible on the module but is not a
        # parameter (and therefore not picked up by optimizers).
        self.register_buffer("_value", value)

    @property
    def value(self) -> torch.Tensor:
        """Get the constant value.

        Returns:
            torch.Tensor: The stored constant buffer.
        """
        return getattr(self, "_value")

    def is_smooth(self) -> bool:
        """Constants are smooth (trivially differentiable).

        Returns:
            bool: Always True.
        """
        return True

    def forward(self) -> torch.Tensor:
        """Evaluate the constant by returning its stored value.

        Note:
            The stored buffer itself is returned (not a copy), so in-place
            operations on the result modify this constant.

        Returns:
            torch.Tensor: The constant value.
        """
        return self.value

    def tree(self) -> ExprTree:
        """Return tree representation for Constant (leaf node).

        Returns:
            ExprTree: Leaf node with type 'Constant'.
        """
        return ExprTree("Constant")

    def is_affine(self) -> bool:
        """Constants are affine expressions.

        Returns:
            bool: Always True.
        """
        return True

    def __neg__(self) -> "Constant":
        """Negate the constant (keeps it as a constant).

        Returns:
            Constant: A new constant holding the negated value.
        """
        return Constant(-self.value)