Source code for rlaopt.expression.tree

"""Expression tree for testing and visualization."""

from collections import Counter

from typing_extensions import Self


[docs] class ExprTree: """Tree representation of expression structure. Used for testing expression optimizations and visualizing expression structure. Each node has a type (class name) and zero or more children. """
[docs] def __init__(self, node_type: str, *children: Self, is_commutative: bool = False): """Initialize an expression tree node. Args: node_type: The expression class name. *children: Child ExprTree nodes. is_commutative: Whether this operation is commutative (e.g., addition, multiplication). For commutative operations, child order doesn't matter for equality/hashing. """ self.node_type = node_type self.children = children self.is_commutative = is_commutative
[docs] def __eq__(self, other) -> bool: """Check structural equality of two expression trees. For commutative operations, children can be in any order. For other operations, order matters. Args: other: Another ExprTree to compare with. Returns: bool: True if trees have the same structure and node types. """ if not isinstance(other, ExprTree): return False if self.node_type != other.node_type: return False if self.is_commutative != other.is_commutative: return False # For commutative operations, ignore child order if self.is_commutative: return Counter(self.children) == Counter(other.children) # For other operations, order matters return self.children == other.children
[docs] def __hash__(self) -> int: """Hash the tree for use in sets and dictionaries. For commutative operations, hash is order-independent. Returns: int: Hash of the tree structure. """ if self.is_commutative: # Order-independent hash for commutative operations # Use Counter to preserve duplicates (frozenset would lose them) return hash( ( self.node_type, frozenset(Counter(self.children).items()), self.is_commutative, ) ) return hash((self.node_type, self.children, self.is_commutative))
[docs] def __repr__(self) -> str: """Return a code-like representation of the tree. Returns: str: String representation suitable for debugging. """ parts = [repr(self.node_type)] if self.children: parts.extend(repr(c) for c in self.children) if self.is_commutative: parts.append("is_commutative=True") return f"ExprTree({', '.join(parts)})"
[docs] def __str__(self) -> str: """Return a pretty-printed tree visualization. Returns: str: Human-readable tree with indentation and branches. """ return self._pretty(prefix="", is_root=True)
def _pretty(self, prefix: str, is_root: bool = False, is_last: bool = True) -> str: """Recursively build pretty-printed tree. Args: prefix: Prefix string for current line (contains branch characters). is_root: Whether this is the root node. is_last: Whether this is the last child of its parent. Returns: str: Pretty-printed subtree. """ # Current node if is_root: # Root node - no prefix or branch result = self.node_type + "\n" else: # Non-root node - add branch character branch = "└─ " if is_last else "├─ " result = prefix + branch + self.node_type + "\n" # Recursively add children for i, child in enumerate(self.children): is_last_child = i == len(self.children) - 1 if is_root: # Children of root - start with empty prefix child_prefix = "" else: # Children of non-root get extended prefix extension = " " if is_last else "│ " child_prefix = prefix + extension result += child._pretty(child_prefix, is_root=False, is_last=is_last_child) return result
[docs] def depth(self) -> int: """Calculate the depth of the tree. Returns: int: Maximum depth from this node to any leaf. """ if not self.children: return 0 return 1 + max(child.depth() for child in self.children)
[docs] def count_nodes(self) -> int: """Count the total number of nodes in the tree. Returns: int: Total number of nodes (including this one). """ return 1 + sum(child.count_nodes() for child in self.children)