mcts.CausalMCTS

class mcts.CausalMCTS(initial_state, goal_state, scenario, reward_shaper, intervention_space, max_rollout_depth, termination_threshold, exploration_constant)

Bases: object

MCTS-based planner for causal intervention planning.

This class implements a complete MCTS loop for finding sequences of causal interventions that transform an initial symbolic state into a goal state. The planner operates on symbolic state representations loaded from JSON files.

Attributes:

initial_state: Dictionary containing the initial symbolic state.
goal_state: Set of symbolic relationships defining the goal.
scenario: Dictionary containing scenario configuration.
reward_shaper: RewardShaper instance for computing rewards.
intervention_space: Dictionary mapping objects to available interventions.
max_rollout_depth: Maximum depth for random rollouts.
termination_threshold: Reward threshold for early termination.
exploration_constant: UCT exploration parameter (default: 0.5).
reward_history: List of rewards from completed rollouts.

__init__(initial_state, goal_state, scenario, reward_shaper, intervention_space, max_rollout_depth, termination_threshold, exploration_constant)

Initialize the Causal MCTS planner.

Args:

initial_state: Dictionary with initial symbolic state (from JSON).
goal_state: Set of symbolic goal predicates (e.g., {'On(2,1)', 'On(3,2)'}).
scenario: Complete scenario configuration dictionary.
reward_shaper: Instance of RewardShaper for reward computation.
intervention_space: Dict mapping object names to available intervention actions.
max_rollout_depth: Maximum number of interventions in a rollout.
termination_threshold: Reward value that indicates success.
exploration_constant: UCT exploration constant (higher = more exploration).

Methods

__init__(initial_state, goal_state, ...)

Initialize the Causal MCTS planner.

backpropagate(node, reward)

Backpropagate reward up the tree.

checkMisalignment(obj, state_manager)

Check if an object is misaligned relative to a reference.

expand(node)

Expand a node by trying an untried action.

get_legal_actions(interventions, node)

Generate list of legal actions for a given state.

rollout(node)

Perform a random rollout from the given node.

search_resolution([iterations])

Execute MCTS search for a fixed number of iterations.

select(node)

Select a node to expand using UCT policy.

backpropagate(node, reward)

Backpropagate reward up the tree.

Updates visit counts and total rewards for all nodes from the given node back to the root.

Args:

node: Node where rollout ended.
reward: Reward value to propagate.
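The backpropagation step can be sketched as a walk along parent pointers. This is a minimal illustration, not the module's implementation; the `Node` fields here (`parent`, `visits`, `total_reward`) are hypothetical names assumed to mirror the bookkeeping described above.

```python
class Node:
    """Illustrative tree node with the visit/reward statistics UCT needs."""
    def __init__(self, parent=None):
        self.parent = parent
        self.visits = 0
        self.total_reward = 0.0

def backpropagate(node, reward):
    # Walk from the rollout node back to the root, crediting every node
    # on the path with the rollout's reward and one extra visit.
    while node is not None:
        node.visits += 1
        node.total_reward += reward
        node = node.parent
```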

checkMisalignment(obj, state_manager)

Check if an object is misaligned relative to a reference.

Determines whether an object’s position deviates significantly from an expected alignment (useful for detecting geometric violations).

Args:

obj: Name of the object to check.
state_manager: Optional StateManager to use for checking. If None, uses initial state.

Returns:

True if object is misaligned, False otherwise.

expand(node: MCTSNode)

Expand a node by trying an untried action.

Creates a new child node by applying one of the untried actions from the given node.

Args:

node: Node to expand.

Returns:

Newly created child node, or the input node if no actions available.
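Expansion amounts to popping one untried action and attaching the resulting child. The sketch below is illustrative only: it assumes each node keeps a list of untried (object, action) tuples, with hypothetical field names.

```python
class Node:
    """Illustrative node holding a pool of actions not yet expanded."""
    def __init__(self, action=None, parent=None, untried_actions=()):
        self.action = action
        self.parent = parent
        self.children = []
        self.untried_actions = list(untried_actions)

def expand(node):
    if not node.untried_actions:
        return node  # no actions available: return the input node unchanged
    action = node.untried_actions.pop()   # pick one untried action
    child = Node(action=action, parent=node)
    node.children.append(child)
    return child
```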

get_legal_actions(interventions, node)

Generate list of legal actions for a given state.

Determines which interventions are valid to apply next, based on the current sequence of interventions and any known violations.

Args:

interventions: Current sequence of (object, action) tuples.
node: Current MCTS node (used for pruning based on violations).

Returns:

List of legal (object, action) tuples that can be applied next.

rollout(node: MCTSNode) → float

Perform a random rollout from the given node.

Simulates applying the interventions in the node, then continues with random actions until max depth is reached or goal is achieved. Computes cumulative reward.

Args:

node: Node from which to start the rollout.

Returns:

Total reward accumulated during the rollout.
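A rollout of this shape can be sketched as follows. Everything here is a stand-in: `legal_actions` and `step_reward` are hypothetical callables substituting for the planner's real get_legal_actions and RewardShaper logic.

```python
import random

def rollout(interventions, legal_actions, step_reward, max_depth, rng=random):
    # Replay the node's existing interventions, then extend the sequence
    # with random legal actions until max depth or a dead end is reached.
    seq = list(interventions)
    total = sum(step_reward(a) for a in seq)  # reward for the replayed prefix
    while len(seq) < max_depth:
        choices = legal_actions(seq)
        if not choices:
            break  # no legal continuation from here
        action = rng.choice(choices)
        seq.append(action)
        total += step_reward(action)
    return total
```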

search_resolution(iterations: int = 300)

Execute MCTS search for a fixed number of iterations.

This is the main entry point for running the MCTS algorithm. It performs the complete MCTS loop: selection, expansion, rollout, and backpropagation.

Args:

iterations: Number of MCTS iterations to perform.

Returns:
A tuple containing:
  • Best sequence of interventions found (list of (object, action) tuples)

  • Number of iterations completed before termination
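The four phases named above compose into the standard MCTS loop. The following compact, self-contained sketch runs that loop on a toy problem (choose actions from {0, 1}; reward is the fraction of 1s in the sequence); the names and structure are illustrative, not this module's API, though the UCT constant matches the class's default exploration_constant of 0.5.

```python
import math
import random

class Node:
    """Toy tree node; actions are 0 or 1, expandable up to max_depth."""
    def __init__(self, parent=None, action=None, depth=0, max_depth=4):
        self.parent, self.action, self.depth = parent, action, depth
        self.children = []
        self.untried = [0, 1] if depth < max_depth else []
        self.visits, self.total = 0, 0.0

def uct(child, parent, c=0.5):
    # Exploitation (mean reward) plus exploration bonus.
    return (child.total / child.visits
            + c * math.sqrt(math.log(parent.visits) / child.visits))

def search(iterations=200, max_depth=4, seed=0):
    rng = random.Random(seed)
    root = Node(max_depth=max_depth)
    for _ in range(iterations):
        # Selection: descend via UCT while the node is fully expanded.
        node = root
        while not node.untried and node.children:
            node = max(node.children, key=lambda ch: uct(ch, node))
        # Expansion: attach a child for one untried action.
        if node.untried:
            a = node.untried.pop()
            child = Node(parent=node, action=a,
                         depth=node.depth + 1, max_depth=max_depth)
            node.children.append(child)
            node = child
        # Rollout: random actions to max depth; reward = fraction of 1s.
        seq, n = [], node
        while n.parent is not None:
            seq.append(n.action)
            n = n.parent
        while len(seq) < max_depth:
            seq.append(rng.choice([0, 1]))
        reward = sum(seq) / max_depth
        # Backpropagation: update statistics back to the root.
        while node is not None:
            node.visits += 1
            node.total += reward
            node = node.parent
    # Best sequence: follow the most-visited child at each level.
    best, node = [], root
    while node.children:
        node = max(node.children, key=lambda ch: ch.visits)
        best.append(node.action)
    return best
```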

select(node: MCTSNode)

Select a node to expand using UCT policy.

Traverses the tree from the given node using Upper Confidence Bound for Trees (UCT) until reaching a node that can be expanded.

Args:

node: Root node from which to start selection.

Returns:

Node selected for expansion.
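The per-child UCT score used during this traversal can be sketched as below; the function name is illustrative, and the default `c` matches the class's default exploration_constant of 0.5.

```python
import math

def uct_score(total_reward, visits, parent_visits, c=0.5):
    # Unvisited children score infinitely high, so each is tried once
    # before any child is revisited.
    if visits == 0:
        return float("inf")
    exploit = total_reward / visits                       # mean reward so far
    explore = c * math.sqrt(math.log(parent_visits) / visits)  # UCT bonus
    return exploit + explore
```

A higher `c` weights the exploration term more heavily, favoring rarely visited children over those with high mean reward.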