mcts.CausalMCTS
- class mcts.CausalMCTS(initial_state, goal_state, scenario, reward_shaper, intervention_space, max_rollout_depth, termination_threshold, exploration_constant)
Bases: object

MCTS-based planner for causal intervention planning.
This class implements a complete MCTS loop for finding sequences of causal interventions that transform an initial symbolic state into a goal state. The planner operates on symbolic state representations loaded from JSON files.
- Attributes:
initial_state: Dictionary containing the initial symbolic state.
goal_state: Set of symbolic relationships defining the goal.
scenario: Dictionary containing scenario configuration.
reward_shaper: RewardShaper instance for computing rewards.
intervention_space: Dictionary mapping objects to available interventions.
max_rollout_depth: Maximum depth for random rollouts.
termination_threshold: Reward threshold for early termination.
exploration_constant: UCT exploration parameter (default: 0.5).
reward_history: List of rewards from completed rollouts.
- __init__(initial_state, goal_state, scenario, reward_shaper, intervention_space, max_rollout_depth, termination_threshold, exploration_constant)
Initialize the Causal MCTS planner.
- Args:
initial_state: Dictionary with initial symbolic state (from JSON).
goal_state: Set of symbolic goal predicates (e.g., {'On(2,1)', 'On(3,2)'}).
scenario: Complete scenario configuration dictionary.
reward_shaper: Instance of RewardShaper for reward computation.
intervention_space: Dict mapping object names to available intervention actions.
max_rollout_depth: Maximum number of interventions in a rollout.
termination_threshold: Reward value that indicates success.
exploration_constant: UCT exploration constant (higher = more exploration).
Methods
__init__(initial_state, goal_state, ...): Initialize the Causal MCTS planner.
backpropagate(node, reward): Backpropagate reward up the tree.
checkMisalignment(obj, state_manager): Check if an object is misaligned relative to a reference.
expand(node): Expand a node by trying an untried action.
get_legal_actions(interventions, node): Generate list of legal actions for a given state.
rollout(node): Perform a random rollout from the given node.
search_resolution([iterations]): Execute MCTS search for a fixed number of iterations.
select(node): Select a node to expand using UCT policy.
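The UCT policy that select applies is the standard upper-confidence rule, balancing a child's mean reward against an exploration bonus. A minimal sketch of the score (the function name and arguments below are illustrative, not part of this class's API):

```python
import math

def uct_score(child_total_reward, child_visits, parent_visits, c=0.5):
    """UCT value of a child: mean reward plus an exploration bonus
    scaled by c (0.5 matches the exploration_constant default above)."""
    if child_visits == 0:
        return float("inf")  # unvisited children are always tried first
    exploit = child_total_reward / child_visits
    explore = c * math.sqrt(math.log(parent_visits) / child_visits)
    return exploit + explore
```

A higher exploration constant widens the bonus term, so less-visited children are revisited more often, as the __init__ docstring notes.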
- backpropagate(node, reward)
Backpropagate reward up the tree.
Updates visit counts and total rewards for all nodes from the given node back to the root.
- Args:
node: Node where rollout ended.
reward: Reward value to propagate.
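Backpropagation follows the usual MCTS pattern of walking parent links back to the root. A self-contained sketch (the Node class here is a stand-in for the module's MCTSNode, with only the fields this step touches):

```python
class Node:
    """Minimal stand-in for MCTSNode (illustrative, not the module's class)."""
    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.total_reward = 0.0

def backpropagate(node, reward):
    """Walk parent links from the rollout node to the root,
    updating visit counts and accumulated reward along the way."""
    while node is not None:
        node.visits += 1
        node.total_reward += reward
        node = node.parent
```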
- checkMisalignment(obj, state_manager)
Check if an object is misaligned relative to a reference.
Determines whether an object’s position deviates significantly from an expected alignment (useful for detecting geometric violations).
- Args:
obj: Name of the object to check.
state_manager: Optional StateManager to use for checking. If None, uses the initial state.
- Returns:
True if object is misaligned, False otherwise.
- expand(node: MCTSNode)
Expand a node by trying an untried action.
Creates a new child node by applying one of the untried actions from the given node.
- Args:
node: Node to expand.
- Returns:
Newly created child node, or the input node if no actions available.
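The documented behavior, pop one untried action, attach a child for it, and fall back to the input node when nothing remains, can be sketched as follows (the data layout and names are illustrative, not the class's internals):

```python
def expand(node, untried_actions, children):
    """Expansion sketch: consume one untried action and create a child
    for it; return the input node unchanged when nothing is left to try."""
    if not untried_actions:
        return node
    action = untried_actions.pop()
    child = {"action": action, "visits": 0, "total_reward": 0.0, "parent": node}
    children.append(child)
    return child
```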
- get_legal_actions(interventions, node)
Generate list of legal actions for a given state.
Determines which interventions are valid to apply next, based on the current sequence of interventions and any known violations.
- Args:
interventions: Current sequence of (object, action) tuples.
node: Current MCTS node (used for pruning based on violations).
- Returns:
List of legal (object, action) tuples that can be applied next.
- rollout(node: MCTSNode) → float
Perform a random rollout from the given node.
Simulates applying the interventions in the node, then continues with random actions until max depth is reached or goal is achieved. Computes cumulative reward.
- Args:
node: Node from which to start the rollout.
- Returns:
Total reward accumulated during the rollout.
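A rollout of this kind, extend the current intervention sequence with random legal actions until the depth limit or a dead end, accumulating reward, can be sketched as below; the function and parameter names are illustrative, and the per-step reward accumulation is an assumption about how "cumulative reward" is computed:

```python
import random

def rollout(actions_so_far, legal_actions_fn, reward_fn, max_depth, rng=None):
    """Random rollout sketch: append random legal actions until
    max_depth is reached or no legal action remains, summing rewards."""
    rng = rng or random.Random(0)
    seq = list(actions_so_far)
    total = 0.0
    while len(seq) < max_depth:
        legal = legal_actions_fn(seq)
        if not legal:
            break  # dead end: no legal intervention to apply
        seq.append(rng.choice(legal))
        total += reward_fn(seq)
    return total
```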
- search_resolution(iterations: int = 300)
Execute MCTS search for a fixed number of iterations.
This is the main entry point for running the MCTS algorithm. It performs the complete MCTS loop: selection, expansion, rollout, and backpropagation.
- Args:
iterations: Number of MCTS iterations to perform.
- Returns:
- A tuple containing:
  - Best sequence of interventions found (list of (object, action) tuples)
  - Number of iterations completed before termination
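Taken together, the selection/expansion/rollout/backpropagation loop that search_resolution runs has roughly the following shape. This is a generic MCTS sketch over a toy action space, assuming a fixed action set and a sequence-level reward function; it is not the class's actual implementation:

```python
import math
import random

class Node:
    """Toy tree node; the real planner uses its own MCTSNode class."""
    def __init__(self, actions, parent=None, action=None):
        self.parent, self.action = parent, action
        self.children, self.untried = [], list(actions)
        self.visits, self.total_reward = 0, 0.0

def mcts_search(legal_actions, reward_fn, max_depth, iterations=300, c=0.5, seed=0):
    """Generic MCTS loop: select via UCT, expand one untried action,
    random rollout to max depth, backpropagate, then return the
    most-visited action sequence (an illustrative sketch)."""
    rng = random.Random(seed)
    root = Node(legal_actions)

    def path(node):  # actions from the root down to this node
        seq = []
        while node.parent is not None:
            seq.append(node.action)
            node = node.parent
        return list(reversed(seq))

    for _ in range(iterations):
        node = root
        # Selection: descend through fully expanded nodes using UCT.
        while not node.untried and node.children:
            node = max(node.children,
                       key=lambda ch: ch.total_reward / ch.visits
                       + c * math.sqrt(math.log(node.visits) / ch.visits))
        # Expansion: try one untried action, unless already at max depth.
        if node.untried and len(path(node)) < max_depth:
            child = Node(legal_actions, parent=node, action=node.untried.pop())
            node.children.append(child)
            node = child
        # Rollout: fill the remaining depth with random actions.
        seq = path(node)
        while len(seq) < max_depth:
            seq.append(rng.choice(legal_actions))
        reward = reward_fn(seq)
        # Backpropagation: update statistics up to the root.
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent

    # Read off the best sequence by following most-visited children.
    best, node = [], root
    while node.children:
        node = max(node.children, key=lambda ch: ch.visits)
        best.append(node.action)
    return best
```

On a toy problem where actions are 0/1 and the reward is the number of 1s in the sequence, the search concentrates visits on the all-ones sequence, mirroring how the planner concentrates on intervention sequences with high shaped reward.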