Source code for rlaopt.expression.tree
"""Expression tree for testing and visualization."""
from collections import Counter
from typing_extensions import Self
[docs]
class ExprTree:
    """Tree representation of expression structure.

    Used for testing expression optimizations and visualizing expression
    structure. Each node has a type (class name) and zero or more children.
    """

    def __init__(
        self, node_type: str, *children: "Self", is_commutative: bool = False
    ):
        """Initialize an expression tree node.

        Args:
            node_type: The expression class name.
            *children: Child ExprTree nodes.
            is_commutative: Whether this operation is commutative (e.g.,
                addition, multiplication). For commutative operations, child
                order doesn't matter for equality/hashing.
        """
        self.node_type = node_type
        # Kept as the varargs tuple: hashable and compared element-wise.
        self.children = children
        self.is_commutative = is_commutative

    def __eq__(self, other: object) -> bool:
        """Check structural equality of two expression trees.

        For commutative operations, children can be in any order (but the
        number of occurrences of each child must match). For other
        operations, order matters.

        Args:
            other: Object to compare with.

        Returns:
            bool: True if trees have the same structure and node types.
            ``NotImplemented`` is returned for non-ExprTree operands so Python
            can try the reflected comparison; ``==`` then evaluates to False.
        """
        if not isinstance(other, ExprTree):
            return NotImplemented
        if (
            self.node_type != other.node_type
            or self.is_commutative != other.is_commutative
        ):
            return False
        if self.is_commutative:
            # Multiset comparison: ignores order, keeps duplicate counts.
            return Counter(self.children) == Counter(other.children)
        return self.children == other.children

    def __hash__(self) -> int:
        """Hash the tree for use in sets and dictionaries.

        For commutative operations the hash is order-independent, consistent
        with ``__eq__``. The hash is recomputed from ``node_type``,
        ``children`` and ``is_commutative`` on every call, so reassigning
        those attributes after inserting a node into a set/dict breaks lookup.

        Returns:
            int: Hash of the tree structure.
        """
        if self.is_commutative:
            # Counter preserves duplicates (a plain frozenset of children
            # would lose them); a frozenset of its items ignores order.
            children_key = frozenset(Counter(self.children).items())
        else:
            children_key = self.children
        return hash((self.node_type, children_key, self.is_commutative))

    def __repr__(self) -> str:
        """Return a code-like representation of the tree.

        Returns:
            str: String representation suitable for debugging.
        """
        parts = [repr(self.node_type)]
        parts.extend(repr(child) for child in self.children)
        if self.is_commutative:
            parts.append("is_commutative=True")
        return f"ExprTree({', '.join(parts)})"

    def __str__(self) -> str:
        """Return a pretty-printed tree visualization.

        Returns:
            str: Human-readable tree with indentation and branches, one node
            per line, terminated by a trailing newline.
        """
        return self._pretty(prefix="", is_root=True)

    def _pretty(self, prefix: str, is_root: bool = False, is_last: bool = True) -> str:
        """Recursively build pretty-printed tree.

        Args:
            prefix: Prefix string for current line (contains branch characters
                inherited from ancestors).
            is_root: Whether this is the root node.
            is_last: Whether this is the last child of its parent.

        Returns:
            str: Pretty-printed subtree, one line per node.
        """
        if is_root:
            # Root node: no prefix or branch; its children start unprefixed.
            parts = [self.node_type + "\n"]
            child_prefix = ""
        else:
            branch = "└─ " if is_last else "├─ "
            parts = [prefix + branch + self.node_type + "\n"]
            # Continue this node's column for its descendants: a vertical bar
            # while later siblings follow, blank space after the last one.
            # Both are as wide as the 3-char branch so every depth aligns.
            child_prefix = prefix + ("   " if is_last else "│  ")

        last_index = len(self.children) - 1
        for index, child in enumerate(self.children):
            parts.append(
                child._pretty(child_prefix, is_root=False, is_last=index == last_index)
            )
        return "".join(parts)

    def depth(self) -> int:
        """Calculate the depth of the tree.

        Returns:
            int: Maximum number of edges from this node to any leaf (a leaf
            has depth 0).
        """
        if not self.children:
            return 0
        return 1 + max(child.depth() for child in self.children)

    def count_nodes(self) -> int:
        """Count the total number of nodes in the tree.

        Returns:
            int: Total number of nodes (including this one).
        """
        return 1 + sum(child.count_nodes() for child in self.children)