game_player/mcts.rs
1//! Monte Carlo Tree Search (MCTS) module
2//!
3//! This module implements the Monte Carlo Tree Search (MCTS) algorithm for making decisions in two-player perfect and hidden
4//! information games. It provides the core MCTS logic, including selection, expansion, simulation (rollout), and back-propagation
5//! phases. The MCTS algorithm builds a search tree incrementally and uses random simulations to evaluate the potential of
6//! different moves. The module is designed to be generic and works with any game state that implements the `State` trait, along
7//! with the `ResponseGenerator` and `Rollout` traits.
8
9use crate::state::State;
10use indextree::{Arena, NodeId};
11
/// Exploration constant `c` to pass to [`search`] when no tuned value is available.
///
/// √2 is the weight of the confidence term in the classic UCB1 bound for rewards in `[0, 1]`, the range into which
/// `search` maps rollout scores.
pub const DEFAULT_EXPLORATION_CONSTANT: f32 = std::f32::consts::SQRT_2;
14
15/// Response generator trait for MCTS search
16pub trait ResponseGenerator {
17 /// The type representing game states that this generator works with
18 type State: State;
19
20 /// Generates a list of all possible legal actions from the given state.
21 ///
22 /// # Arguments
23 /// * `state` - state to respond to
24 ///
25 /// # Returns
26 /// List of all possible legal actions for the given state.
27 ///
28 /// # Notes
29 /// - All returned actions must be legal in the provided state.
30 /// - The order of actions is not significant unless required by the implementation.
31 /// - Returning no actions indicates that the player cannot respond. It does not necessarily indicate that the game is
32 /// over or that the player has passed. If passing is allowed, then a pass must be a valid action.
33 fn generate(&self, state: &Self::State) -> Vec<<Self::State as State>::Action>;
34}
35
36/// Rollout trait for MCTS search
37///
38/// The Rollout trait defines the interface for performing rollouts in a game state. It is used by the MCTS algorithm to simulate
39/// game play and evaluate the potential outcomes of different moves. It is implemented for specific game states in the
40/// game-specific code.
41///
42/// # Type Parameters
43/// * `S` - Game state type
44///
45/// # Examples
46///
47/// ```rust
48/// # use game_player::mcts::{Rollout, ResponseGenerator};
49/// # use game_player::State;
50/// # #[derive(Debug, Clone, Default)]
51/// # struct TestGameState { value: i32 }
52/// # impl State for TestGameState {
53/// # type Action = TestAction;
54/// # fn fingerprint(&self) -> u64 { self.value as u64 }
55/// # fn whose_turn(&self) -> u8 { 0 }
56/// # fn is_terminal(&self) -> bool { false }
57/// # fn apply(&self, _action: &TestAction) -> Self { self.clone() }
58/// # }
59/// #[derive(Debug, Clone, Default)]
60/// struct TestAction;
61/// struct TestResponseGen;
62/// impl ResponseGenerator for TestResponseGen {
63/// type State = TestGameState;
64/// fn generate(&self, _state: &TestGameState) -> Vec<TestAction> { vec![TestAction] }
65/// }
66/// struct TestRollout;
67/// impl Rollout for TestRollout {
68/// type State = TestGameState;
69/// type ResponseGenerator = TestResponseGen;
70/// fn play(&self, state: &TestGameState, _rg: &TestResponseGen) -> f32 { state.value as f32 }
71/// }
72/// let state = TestGameState { value: 42 };
73/// let rollout = TestRollout;
74/// let rg = TestResponseGen;
75/// let score = rollout.play(&state, &rg);
76/// assert_eq!(score, 42.0);
77/// ```
78pub trait Rollout {
79 /// The type representing the game state
80 type State: State;
81 type ResponseGenerator: ResponseGenerator<State = Self::State>;
82
83 /// Performs a rollout (simulation) from the given state using the provided response generator, and returns the evaluated value.
84 ///
85 /// The rollout is a simulation of the game from the given state to a terminal state, following simple heuristics or random
86 /// moves. The result is a score in the range [-1.0, 1.0], where 1.0 indicates a win for the computer player, -1.0 indicates a
87 /// loss, and 0.0 indicates a draw.
88 ///
89 /// # Arguments
90 /// * `state` - The game state to perform the rollout from
91 /// * `response_generator` - The response generator for producing legal actions
92 ///
93 /// # Returns
94 /// [-1.0, 1.0] score representing the outcome of the rollout from the perspective of the computer player.
95 fn play(&self, state: &Self::State, response_generator: &Self::ResponseGenerator) -> f32;
96}
97
98/// Represents a node in the MCTS tree
99///
100/// # Type Parameters
101/// * `S` - Game state type
102///
103/// # Examples
104///
105/// ```rust
106/// # use game_player::mcts::ResponseGenerator;
107/// # use game_player::{State, StaticEvaluator};
108/// # use indextree::{Arena, NodeId};
109///
110/// # #[derive(Debug, Clone, Default)]
111/// # struct TestGameState { value: i32 }
112/// # impl State for TestGameState {
113/// # type Action = TestAction;
114/// # fn fingerprint(&self) -> u64 { self.value as u64 }
115/// # fn whose_turn(&self) -> u8 { 0 }
116/// # fn is_terminal(&self) -> bool { false }
117/// # fn apply(&self, _action: &TestAction) -> Self { self.clone() }
118/// # }
119/// # #[derive(Debug, Clone, Default)]
120/// # struct TestAction;
121/// # struct TestResponseGen;
122/// # impl ResponseGenerator for TestResponseGen {
123/// # type State = TestGameState;
124/// # fn generate(&self, _state: &TestGameState) -> Vec<TestAction> { vec![TestAction] }
125/// # }
126///
127/// let state = TestGameState { value: 42 };
128/// let response_gen = TestResponseGen;
129/// // This example shows basic usage of the test types
130/// assert_eq!(state.value, 42);
131/// ```
132/// ```
133struct Node<S>
134where
135 S: State,
136{
137 /// The game state represented by this node
138 state: S,
139 /// Action that led to this node
140 action: Option<S::Action>,
141 /// Untried actions that have not been expanded yet
142 untried_actions: Vec<S::Action>,
143 /// Number of times this node has been visited
144 visits: u32,
145 /// Sum of the values of all simulations that passed through this node
146 value_sum: f32,
147}
148
149impl<S> Node<S>
150where
151 S: State,
152{
153 // Creates a new node with the given game state and action
154 //
155 // # Arguments
156 // * `state` - The game state this node represents
157 // * `action` - The action that led to this state from the parent, None for the root
158 // * `rg` - Response generator to determine possible actions from this state
159 //
160 // # Returns
161 // A new Node instance with zero visits
162 fn new<G>(state: S, action: Option<S::Action>, rg: &G) -> Self
163 where
164 G: ResponseGenerator<State = S>,
165 {
166 let untried_actions = rg.generate(&state);
167 Self {
168 state,
169 action,
170 untried_actions,
171 visits: 0,
172 value_sum: 0.0,
173 }
174 }
175
176 // Checks if the node is fully expanded
177 //
178 // A node is fully expanded when all possible actions from this state have been
179 // tried and added as child nodes.
180 //
181 // # Returns
182 // `true` if no untried actions remain, `false` otherwise
183 fn fully_expanded(&self) -> bool {
184 self.untried_actions.is_empty()
185 }
186
187 // Calculates the UCT value for this node
188 //
189 // The UCT formula balances exploitation (average reward) with exploration (uncertainty).
190 // Higher UCT values indicate more promising nodes to explore.
191 //
192 // # Arguments
193 // * `arena` - The Arena containing all nodes
194 // * `c` - Exploration constant (typically sqrt(2) ≈ 1.414)
195 //
196 // # Panics
197 // Panics if the parent node does not exist or has zero visits.
198 //
199 // # Returns
200 // The UCT value for this node, or f32::INFINITY if unvisited
201 fn uct(&self, node_id: NodeId, arena: &Arena<Node<S>>, c: f32) -> f32 {
202 if let Some(parent_node) = arena[node_id].parent().and_then(|parent_id| arena.get(parent_id)) {
203 // If this node has never been visited, return infinity to ensure it gets visited
204 if self.visits == 0 {
205 return f32::INFINITY;
206 }
207 let parent_visits = parent_node.get().visits;
208 if parent_visits > 0 {
209 let confidence = c * ((parent_visits as f32).ln() / self.visits as f32).sqrt();
210 let mean_value = self.value_sum / self.visits as f32;
211 return mean_value + confidence;
212 }
213 }
214
215 panic!("UCT cannot be computed because the parent node does not exist or has no visits");
216 }
217}
218
219// Holds static information for the MCTS search
220struct Context<'a, G, R>
221where
222 G: ResponseGenerator,
223 R: Rollout<State = G::State>,
224{
225 /// Function to generate all possible child states
226 response_generator: &'a G,
227 /// Rollout implementation
228 rollout: &'a R,
229 /// Exploration constant for the UCT formula
230 c: f32,
231}
232
233/// Searches for the best action using the MCTS algorithm
234///
235/// Performs the four phases of MCTS (Selection, Expansion, Rollout, Back Propagation) for the given number of iterations,
236/// building up statistics in the search tree.
237///
238/// # Arguments
239/// * `s0` - Initial game state to serve as the root of the search tree
240/// * `rg` - Response generator that returns all possible actions from a state
241/// * `roll` - Rollout implementation for simulating games
242/// * `c` - Exploration constant for UCT calculation
243/// * `max_iterations` - Number of MCTS iterations to perform
244///
245/// # Returns
246/// Some(best_action) containing the action leading to the child of the root node with the most visits (most promising move),
247/// or None if the root state has no possible actions.
248///
249/// # Panics
250/// This function will panic if the UCT function ever returns NaN.
251pub fn search<S, G, R>(s0: &S, rg: &G, roll: &R, c: f32, max_iterations: u32) -> Option<S::Action>
252where
253 S: State + Clone,
254 G: ResponseGenerator<State = S>,
255 R: Rollout<State = S, ResponseGenerator = G>,
256{
257 // Create context for the search
258 let context = Context {
259 response_generator: rg,
260 rollout: roll,
261 c,
262 };
263
264 // Create the arena that will hold all nodes
265 let mut arena = Arena::new();
266
267 // Initialize root node
268 let root_node = Node::new(s0.clone(), None, rg);
269 let root_id = arena.new_node(root_node);
270 arena.get_mut(root_id).unwrap().get_mut().visits = 1; // Root node is automatically visited once
271
272 for _ in 0..max_iterations {
273 // Selection - traverse the tree to find the best leaf node to expand
274 let mut node_id = select(root_id, &arena, &context);
275
276 // Expansion - add another child to the node if it is not terminal and has untried actions
277 if let Some(child_id) = expand(node_id, &mut arena, &context) {
278 node_id = child_id;
279 }
280
281 // Rollout - evaluate the node using the Rollout implementation, and normalize the result to [0, 1]
282 let value = (rollout(node_id, &arena, &context) + 1.0) / 2.0;
283
284 // Back-propagation - update the node and its ancestors with the rollout result
285 back_propagate(node_id, &mut arena, value);
286 }
287
288 // If there are no responses to the root state then return None
289 if root_id.children(&arena).count() == 0 {
290 return None;
291 }
292
293 // Get the best child (of root) by the number of visits, or None if there are no children
294 let best_child_id = root_id.children(&arena).max_by(|&a, &b| {
295 let a_visits = arena[a].get().visits;
296 let b_visits = arena[b].get().visits;
297 a_visits.cmp(&b_visits)
298 });
299
300 // Return the action that led to the best child
301 best_child_id.and_then(|child_id| arena[child_id].get().action.clone())
302}
303
304// Helper function that determines if the node should be selected for expansion.
305// A node is selectable if it is not fully expanded, has no children, or represents a terminal game state.
306fn selectable<S>(node_id: NodeId, arena: &Arena<Node<S>>) -> bool
307where
308 S: State,
309{
310 let has_children = node_id.children(arena).count() > 0;
311 let node = arena[node_id].get();
312 !node.fully_expanded() || !has_children || node.state.is_terminal()
313}
314
315// Selects the best node for expansion using the UCT value
316//
317// This method traverses the tree from the given node downward, selecting the child with the highest UCT value at each step
318// until it reaches a node that is either not fully expanded, has no children, or represents a terminal game state. That node is
319// returned.
320//
321// # Arguments
322// * `node_id` - The starting node for selection (typically the root)
323// * `arena` - The Arena containing all nodes
324// * `context` - The search context containing parameters
325//
326// # Returns
327// The node selected for expansion or evaluation
328fn select<G, R>(node_id: NodeId, arena: &Arena<Node<G::State>>, context: &Context<'_, G, R>) -> NodeId
329where
330 G: ResponseGenerator,
331 R: Rollout<State = G::State>,
332{
333 let c = context.c;
334 let mut selected = node_id;
335
336 // Traverse the tree until a selectable node is found
337 // If a node is not fully expanded, then select it for expansion.
338 // If a node is terminal or has no children, then select it for evaluation.
339 // Otherwise, descend to the child with the highest UCT value and continue.
340 while !selectable(selected, arena) {
341 let children: Vec<NodeId> = selected.children(arena).collect();
342 let best_child = *children
343 .iter()
344 .max_by(|&a, &b| {
345 let a_uct = arena.get(*a).unwrap().get().uct(*a, arena, c);
346 let b_uct = arena.get(*b).unwrap().get().uct(*b, arena, c);
347 a_uct.total_cmp(&b_uct)
348 })
349 .unwrap(); // Safe to unwrap because not_selectable ensures there are children
350 selected = best_child;
351 }
352
353 selected
354}
355
356// Expands a node by adding a child for one of its untried actions and returns the new child node
357//
358// If the node is fully expanded (no untried actions remain) or represents a terminal game state, this function returns None.
359//
360// # Arguments
361// * `node_id` - The node to expand
362// * `arena` - The arena containing all nodes
363// * `context` - The search context containing the response generator
364//
365// # Returns
366// Some(child_node_id) if expansion was successful, None otherwise
367fn expand<G, R>(node_id: NodeId, arena: &mut Arena<Node<G::State>>, context: &Context<'_, G, R>) -> Option<NodeId>
368where
369 G: ResponseGenerator,
370 R: Rollout<State = G::State>,
371{
372 // Get the next untried action, or return None if there are no untried actions
373 let action = arena[node_id].get_mut().untried_actions.pop()?;
374
375 // Apply the action to the node's state to get a new state and create a new child node
376 let child_state = arena[node_id].get().state.apply(&action);
377
378 // Create the new child node and add it to the arena
379 let child_node = Node::new(child_state, Some(action), context.response_generator);
380 let child_id = arena.new_node(child_node);
381
382 // Add the new child to the parent node using indextree's append
383 node_id.append(child_id, arena);
384
385 // Return the new child node ID
386 Some(child_id)
387}
388
389// Performs a rollout (simulation) from the given node and returns the evaluated value
390//
391// In the current implementation, rollout simply evaluates the current node's state using the static evaluator rather than
392// performing random simulation.
393//
394// # Arguments
395// * `node_id` - The node to evaluate
396// * `arena` - The arena containing all nodes
397// * `context` - The search context containing the rollout implementation
398//
399// # Returns
400// [-1.0, 1.0] as the evaluation score for the node's game state
401fn rollout<G, R>(node_id: NodeId, arena: &Arena<Node<G::State>>, context: &Context<'_, G, R>) -> f32
402where
403 G: ResponseGenerator,
404 R: Rollout<State = G::State, ResponseGenerator = G>,
405{
406 context.rollout.play(&arena[node_id].get().state, context.response_generator)
407}
408
409// Back-propagates the value up the tree
410//
411// Updates the visit count and value sum for the given node and all of its ancestors up to the root.
412//
413// # Arguments
414// * `node_id` - The starting node for back-propagation
415// * `arena` - The arena containing all nodes
416// * `value` - The value to propagate up the tree
417fn back_propagate<S>(node_id: NodeId, arena: &mut Arena<Node<S>>, value: f32)
418where
419 S: State,
420{
421 let mut current = Some(node_id);
422 while let Some(id) = current {
423 let node = arena[id].get_mut();
424 node.visits += 1;
425 node.value_sum += value;
426 current = arena[id].parent();
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    // Test implementations for testing
    #[derive(Debug, Clone, Default, PartialEq)]
    struct TestGameState {
        value: i32,
        terminal: bool,
    }

    impl State for TestGameState {
        type Action = TestAction;

        fn fingerprint(&self) -> u64 {
            self.value as u64
        }

        fn whose_turn(&self) -> u8 {
            0
        }

        fn is_terminal(&self) -> bool {
            self.terminal
        }

        fn apply(&self, action: &TestAction) -> Self {
            Self {
                value: self.value + action.increment,
                terminal: self.value + action.increment > 10,
            }
        }
    }

    #[derive(Debug, Clone, Default)]
    struct TestAction {
        increment: i32,
    }

    impl TestAction {
        fn new(increment: i32) -> Self {
            Self { increment }
        }
    }

    struct TestResponseGenerator;

    impl ResponseGenerator for TestResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal {
                vec![]
            } else {
                vec![TestAction::new(1), TestAction::new(2)]
            }
        }
    }

    struct EmptyResponseGenerator;

    impl ResponseGenerator for EmptyResponseGenerator {
        type State = TestGameState;

        fn generate(&self, _state: &TestGameState) -> Vec<TestAction> {
            vec![]
        }
    }

    struct SingleResponseGenerator;

    impl ResponseGenerator for SingleResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal { vec![] } else { vec![TestAction::new(1)] }
        }
    }

    struct VariableResponseGenerator;

    impl ResponseGenerator for VariableResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal {
                vec![]
            } else if state.value < 3 {
                vec![TestAction::new(1), TestAction::new(2), TestAction::new(3)]
            } else if state.value < 6 {
                vec![TestAction::new(1), TestAction::new(2)]
            } else {
                vec![TestAction::new(1)]
            }
        }
    }

    struct TestRollout;

    impl Rollout for TestRollout {
        type State = TestGameState;
        type ResponseGenerator = TestResponseGenerator;

        fn play(&self, _state: &TestGameState, _rg: &TestResponseGenerator) -> f32 {
            0.5 // Simple fixed rollout value for testing
        }
    }

    // Rollouts paired with the other generators so `search` can be exercised with them
    struct EmptyRollout;

    impl Rollout for EmptyRollout {
        type State = TestGameState;
        type ResponseGenerator = EmptyResponseGenerator;

        fn play(&self, _state: &TestGameState, _rg: &EmptyResponseGenerator) -> f32 {
            0.0
        }
    }

    struct SingleRollout;

    impl Rollout for SingleRollout {
        type State = TestGameState;
        type ResponseGenerator = SingleResponseGenerator;

        fn play(&self, _state: &TestGameState, _rg: &SingleResponseGenerator) -> f32 {
            0.0
        }
    }

    // Tests for ResponseGenerator trait
    #[test]
    fn test_mcts_response_generator_basic() {
        let generator = TestResponseGenerator;
        let state = TestGameState {
            value: 5,
            terminal: false,
        };
        let terminal_state = TestGameState {
            value: 15,
            terminal: true,
        };

        let actions = generator.generate(&state);
        assert_eq!(actions.len(), 2);
        assert_eq!(actions[0].increment, 1);
        assert_eq!(actions[1].increment, 2);

        let terminal_actions = generator.generate(&terminal_state);
        assert!(terminal_actions.is_empty());
    }

    #[test]
    fn test_mcts_response_generator_empty() {
        let generator = EmptyResponseGenerator;
        let state = TestGameState {
            value: 0,
            terminal: false,
        };

        let actions = generator.generate(&state);
        assert!(actions.is_empty());
    }

    #[test]
    fn test_mcts_response_generator_single() {
        let generator = SingleResponseGenerator;
        let state = TestGameState {
            value: 3,
            terminal: false,
        };

        let actions = generator.generate(&state);
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].increment, 1);
    }

    #[test]
    fn test_mcts_response_generator_variable() {
        let generator = VariableResponseGenerator;

        // Low value state should have 3 actions
        let low_state = TestGameState {
            value: 1,
            terminal: false,
        };
        let actions = generator.generate(&low_state);
        assert_eq!(actions.len(), 3);

        // Medium value state should have 2 actions
        let med_state = TestGameState {
            value: 4,
            terminal: false,
        };
        let actions = generator.generate(&med_state);
        assert_eq!(actions.len(), 2);

        // High value state should have 1 action
        let high_state = TestGameState {
            value: 7,
            terminal: false,
        };
        let actions = generator.generate(&high_state);
        assert_eq!(actions.len(), 1);

        // Terminal state should have no actions
        let terminal_state = TestGameState {
            value: 15,
            terminal: true,
        };
        let actions = generator.generate(&terminal_state);
        assert!(actions.is_empty());
    }

    #[test]
    fn test_mcts_search_basic() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Test with minimal iterations
        let result = search(&state, &generator, &rollout, 1.0, 1);
        // Since we have actions available, should return Some action
        assert!(result.is_some());
    }

    #[test]
    fn test_mcts_search_terminal_state() {
        let state = TestGameState {
            value: 0,
            terminal: true,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Terminal state should return None (no actions)
        let result = search(&state, &generator, &rollout, 1.0, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_mcts_search_zero_iterations() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // With no iterations the root is never expanded, so there is no child to choose
        let result = search(&state, &generator, &rollout, 1.0, 0);
        assert!(result.is_none());
    }

    #[test]
    fn test_mcts_search_no_responses() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };

        // A non-terminal root without any legal response still yields no action
        let result = search(&state, &EmptyResponseGenerator, &EmptyRollout, DEFAULT_EXPLORATION_CONSTANT, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_mcts_search_single_response() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };

        // With exactly one legal action, it must be the one returned
        let result = search(&state, &SingleResponseGenerator, &SingleRollout, DEFAULT_EXPLORATION_CONSTANT, 10);
        assert_eq!(result.map(|action| action.increment), Some(1));
    }

    #[test]
    fn test_mcts_search_different_c_values() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Test with different exploration constants; both must find an action for a non-terminal root
        let result1 = search(&state, &generator, &rollout, 0.1, 5);
        let result2 = search(&state, &generator, &rollout, 2.0, 5);

        assert!(result1.is_some());
        assert!(result2.is_some());
    }

    #[test]
    fn test_mcts_search_multiple_iterations() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Test with multiple iterations; the root is not terminal, so an action must be returned
        let result = search(&state, &generator, &rollout, 1.4, 50);
        assert!(result.is_some());
    }

    #[test]
    fn test_mcts_search_consistency() {
        let state = TestGameState {
            value: 5,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // The search uses no randomness, so repeated searches on the same input must agree
        let result1 = search(&state, &generator, &rollout, 1.0, 10);
        let result2 = search(&state, &generator, &rollout, 1.0, 10);

        assert!(result1.is_some());
        assert_eq!(
            result1.map(|action| action.increment),
            result2.map(|action| action.increment)
        );
    }
}