game_player/
mcts.rs

1//! Monte Carlo Tree Search (MCTS) module
2//!
3//! This module implements the Monte Carlo Tree Search (MCTS) algorithm for making decisions in two-player perfect and hidden
4//! information games. It provides the core MCTS logic, including selection, expansion, simulation (rollout), and back-propagation
5//! phases. The MCTS algorithm builds a search tree incrementally and uses random simulations to evaluate the potential of
6//! different moves. The module is designed to be generic and works with any game state that implements the `State` trait, along
7//! with the `ResponseGenerator` and `Rollout` traits.
8
9use crate::state::State;
10use indextree::{Arena, NodeId};
11
/// Exploration constant `c` used by [`search`] callers that have no game-specific tuning: √2 (≈ 1.414).
///
/// Larger values make the UCT formula favour rarely visited moves; smaller values favour moves with a high average reward.
pub const DEFAULT_EXPLORATION_CONSTANT: f32 = std::f32::consts::SQRT_2;
14
15/// Response generator trait for MCTS search
16pub trait ResponseGenerator {
17    /// The type representing game states that this generator works with
18    type State: State;
19
20    /// Generates a list of all possible legal actions from the given state.
21    ///
22    /// # Arguments
23    /// * `state` - state to respond to
24    ///
25    /// # Returns
26    /// List of all possible legal actions for the given state.
27    ///
28    /// # Notes
29    /// - All returned actions must be legal in the provided state.
30    /// - The order of actions is not significant unless required by the implementation.
31    /// - Returning no actions indicates that the player cannot respond. It does not necessarily indicate that the game is
32    ///   over or that the player has passed. If passing is allowed, then a pass must be a valid action.
33    fn generate(&self, state: &Self::State) -> Vec<<Self::State as State>::Action>;
34}
35
36/// Rollout trait for MCTS search
37///
38/// The Rollout trait defines the interface for performing rollouts in a game state. It is used by the MCTS algorithm to simulate
39/// game play and evaluate the potential outcomes of different moves. It is implemented for specific game states in the
40/// game-specific code.
41///
42/// # Type Parameters
43/// * `S` - Game state type
44///
45/// # Examples
46///
47/// ```rust
48/// # use game_player::mcts::{Rollout, ResponseGenerator};
49/// # use game_player::State;
50/// # #[derive(Debug, Clone, Default)]
51/// # struct TestGameState { value: i32 }
52/// # impl State for TestGameState {
53/// #     type Action = TestAction;
54/// #     fn fingerprint(&self) -> u64 { self.value as u64 }
55/// #     fn whose_turn(&self) -> u8 { 0 }
56/// #     fn is_terminal(&self) -> bool { false }
57/// #     fn apply(&self, _action: &TestAction) -> Self { self.clone() }
58/// # }
59/// #[derive(Debug, Clone, Default)]
60/// struct TestAction;
61/// struct TestResponseGen;
62/// impl ResponseGenerator for TestResponseGen {
63///     type State = TestGameState;
64///     fn generate(&self, _state: &TestGameState) -> Vec<TestAction> { vec![TestAction] }
65/// }
66/// struct TestRollout;
67/// impl Rollout for TestRollout {
68///     type State = TestGameState;
69///     type ResponseGenerator = TestResponseGen;
70///     fn play(&self, state: &TestGameState, _rg: &TestResponseGen) -> f32 { state.value as f32 }
71/// }
72/// let state = TestGameState { value: 42 };
73/// let rollout = TestRollout;
74/// let rg = TestResponseGen;
75/// let score = rollout.play(&state, &rg);
76/// assert_eq!(score, 42.0);
77/// ```
78pub trait Rollout {
79    /// The type representing the game state
80    type State: State;
81    type ResponseGenerator: ResponseGenerator<State = Self::State>;
82
83    /// Performs a rollout (simulation) from the given state using the provided response generator, and returns the evaluated value.
84    ///
85    /// The rollout is a simulation of the game from the given state to a terminal state, following simple heuristics or random
86    /// moves. The result is a score in the range [-1.0, 1.0], where 1.0 indicates a win for the computer player, -1.0 indicates a
87    /// loss, and 0.0 indicates a draw.
88    ///
89    /// # Arguments
90    /// * `state` - The game state to perform the rollout from
91    /// * `response_generator` - The response generator for producing legal actions
92    ///
93    /// # Returns
94    /// [-1.0, 1.0] score representing the outcome of the rollout from the perspective of the computer player.
95    fn play(&self, state: &Self::State, response_generator: &Self::ResponseGenerator) -> f32;
96}
97
98/// Represents a node in the MCTS tree
99///
100/// # Type Parameters
101/// * `S` - Game state type
102///
103/// # Examples
104///
105/// ```rust
106/// # use game_player::mcts::ResponseGenerator;
107/// # use game_player::{State, StaticEvaluator};
108/// # use indextree::{Arena, NodeId};
109///
110/// # #[derive(Debug, Clone, Default)]
111/// # struct TestGameState { value: i32 }
112/// # impl State for TestGameState {
113/// #     type Action = TestAction;
114/// #     fn fingerprint(&self) -> u64 { self.value as u64 }
115/// #     fn whose_turn(&self) -> u8 { 0 }
116/// #     fn is_terminal(&self) -> bool { false }
117/// #     fn apply(&self, _action: &TestAction) -> Self { self.clone() }
118/// # }
119/// # #[derive(Debug, Clone, Default)]
120/// # struct TestAction;
121/// # struct TestResponseGen;
122/// # impl ResponseGenerator for TestResponseGen {
123/// #     type State = TestGameState;
124/// #     fn generate(&self, _state: &TestGameState) -> Vec<TestAction> { vec![TestAction] }
125/// # }
126///
127/// let state = TestGameState { value: 42 };
128/// let response_gen = TestResponseGen;
129/// // This example shows basic usage of the test types
130/// assert_eq!(state.value, 42);
131/// ```
132/// ```
133struct Node<S>
134where
135    S: State,
136{
137    /// The game state represented by this node
138    state: S,
139    /// Action that led to this node
140    action: Option<S::Action>,
141    /// Untried actions that have not been expanded yet
142    untried_actions: Vec<S::Action>,
143    /// Number of times this node has been visited
144    visits: u32,
145    /// Sum of the values of all simulations that passed through this node
146    value_sum: f32,
147}
148
149impl<S> Node<S>
150where
151    S: State,
152{
153    // Creates a new node with the given game state and action
154    //
155    // # Arguments
156    // * `state` - The game state this node represents
157    // * `action` - The action that led to this state from the parent, None for the root
158    // * `rg` - Response generator to determine possible actions from this state
159    //
160    // # Returns
161    // A new Node instance with zero visits
162    fn new<G>(state: S, action: Option<S::Action>, rg: &G) -> Self
163    where
164        G: ResponseGenerator<State = S>,
165    {
166        let untried_actions = rg.generate(&state);
167        Self {
168            state,
169            action,
170            untried_actions,
171            visits: 0,
172            value_sum: 0.0,
173        }
174    }
175
176    // Checks if the node is fully expanded
177    //
178    // A node is fully expanded when all possible actions from this state have been
179    // tried and added as child nodes.
180    //
181    // # Returns
182    // `true` if no untried actions remain, `false` otherwise
183    fn fully_expanded(&self) -> bool {
184        self.untried_actions.is_empty()
185    }
186
187    // Calculates the UCT value for this node
188    //
189    // The UCT formula balances exploitation (average reward) with exploration (uncertainty).
190    // Higher UCT values indicate more promising nodes to explore.
191    //
192    // # Arguments
193    // * `arena` - The Arena containing all nodes
194    // * `c` - Exploration constant (typically sqrt(2) ≈ 1.414)
195    //
196    // # Panics
197    // Panics if the parent node does not exist or has zero visits.
198    //
199    // # Returns
200    // The UCT value for this node, or f32::INFINITY if unvisited
201    fn uct(&self, node_id: NodeId, arena: &Arena<Node<S>>, c: f32) -> f32 {
202        if let Some(parent_node) = arena[node_id].parent().and_then(|parent_id| arena.get(parent_id)) {
203            // If this node has never been visited, return infinity to ensure it gets visited
204            if self.visits == 0 {
205                return f32::INFINITY;
206            }
207            let parent_visits = parent_node.get().visits;
208            if parent_visits > 0 {
209                let confidence = c * ((parent_visits as f32).ln() / self.visits as f32).sqrt();
210                let mean_value = self.value_sum / self.visits as f32;
211                return mean_value + confidence;
212            }
213        }
214
215        panic!("UCT cannot be computed because the parent node does not exist or has no visits");
216    }
217}
218
219// Holds static information for the MCTS search
220struct Context<'a, G, R>
221where
222    G: ResponseGenerator,
223    R: Rollout<State = G::State>,
224{
225    /// Function to generate all possible child states
226    response_generator: &'a G,
227    /// Rollout implementation
228    rollout: &'a R,
229    /// Exploration constant for the UCT formula
230    c: f32,
231}
232
233/// Searches for the best action using the MCTS algorithm
234///
235/// Performs the four phases of MCTS (Selection, Expansion, Rollout, Back Propagation) for the given number of iterations,
236/// building up statistics in the search tree.
237///
238/// # Arguments
239/// * `s0` - Initial game state to serve as the root of the search tree
240/// * `rg` - Response generator that returns all possible actions from a state
241/// * `roll` - Rollout implementation for simulating games
242/// * `c` - Exploration constant for UCT calculation
243/// * `max_iterations` - Number of MCTS iterations to perform
244///
245/// # Returns
246/// Some(best_action) containing the action leading to the child of the root node with the most visits (most promising move),
247/// or None if the root state has no possible actions.
248///
249/// # Panics
250/// This function will panic if the UCT function ever returns NaN.
251pub fn search<S, G, R>(s0: &S, rg: &G, roll: &R, c: f32, max_iterations: u32) -> Option<S::Action>
252where
253    S: State + Clone,
254    G: ResponseGenerator<State = S>,
255    R: Rollout<State = S, ResponseGenerator = G>,
256{
257    // Create context for the search
258    let context = Context {
259        response_generator: rg,
260        rollout: roll,
261        c,
262    };
263
264    // Create the arena that will hold all nodes
265    let mut arena = Arena::new();
266
267    // Initialize root node
268    let root_node = Node::new(s0.clone(), None, rg);
269    let root_id = arena.new_node(root_node);
270    arena.get_mut(root_id).unwrap().get_mut().visits = 1; // Root node is automatically visited once
271
272    for _ in 0..max_iterations {
273        // Selection - traverse the tree to find the best leaf node to expand
274        let mut node_id = select(root_id, &arena, &context);
275
276        // Expansion - add another child to the node if it is not terminal and has untried actions
277        if let Some(child_id) = expand(node_id, &mut arena, &context) {
278            node_id = child_id;
279        }
280
281        // Rollout - evaluate the node using the Rollout implementation, and normalize the result to [0, 1]
282        let value = (rollout(node_id, &arena, &context) + 1.0) / 2.0;
283
284        // Back-propagation - update the node and its ancestors with the rollout result
285        back_propagate(node_id, &mut arena, value);
286    }
287
288    // If there are no responses to the root state then return None
289    if root_id.children(&arena).count() == 0 {
290        return None;
291    }
292
293    // Get the best child (of root) by the number of visits, or None if there are no children
294    let best_child_id = root_id.children(&arena).max_by(|&a, &b| {
295        let a_visits = arena[a].get().visits;
296        let b_visits = arena[b].get().visits;
297        a_visits.cmp(&b_visits)
298    });
299
300    // Return the action that led to the best child
301    best_child_id.and_then(|child_id| arena[child_id].get().action.clone())
302}
303
304// Helper function that determines if the node should be selected for expansion.
305// A node is selectable if it is not fully expanded, has no children, or represents a terminal game state.
306fn selectable<S>(node_id: NodeId, arena: &Arena<Node<S>>) -> bool
307where
308    S: State,
309{
310    let has_children = node_id.children(arena).count() > 0;
311    let node = arena[node_id].get();
312    !node.fully_expanded() || !has_children || node.state.is_terminal()
313}
314
315// Selects the best node for expansion using the UCT value
316//
317// This method traverses the tree from the given node downward, selecting the child with the highest UCT value at each step
318// until it reaches a node that is either not fully expanded, has no children, or represents a terminal game state. That node is
319// returned.
320//
321// # Arguments
322// * `node_id` - The starting node for selection (typically the root)
323// * `arena` - The Arena containing all nodes
324// * `context` - The search context containing parameters
325//
326// # Returns
327// The node selected for expansion or evaluation
328fn select<G, R>(node_id: NodeId, arena: &Arena<Node<G::State>>, context: &Context<'_, G, R>) -> NodeId
329where
330    G: ResponseGenerator,
331    R: Rollout<State = G::State>,
332{
333    let c = context.c;
334    let mut selected = node_id;
335
336    // Traverse the tree until a selectable node is found
337    // If a node is not fully expanded, then select it for expansion.
338    // If a node is terminal or has no children, then select it for evaluation.
339    // Otherwise, descend to the child with the highest UCT value and continue.
340    while !selectable(selected, arena) {
341        let children: Vec<NodeId> = selected.children(arena).collect();
342        let best_child = *children
343            .iter()
344            .max_by(|&a, &b| {
345                let a_uct = arena.get(*a).unwrap().get().uct(*a, arena, c);
346                let b_uct = arena.get(*b).unwrap().get().uct(*b, arena, c);
347                a_uct.total_cmp(&b_uct)
348            })
349            .unwrap(); // Safe to unwrap because not_selectable ensures there are children
350        selected = best_child;
351    }
352
353    selected
354}
355
356// Expands a node by adding a child for one of its untried actions and returns the new child node
357//
358// If the node is fully expanded (no untried actions remain) or represents a terminal game state, this function returns None.
359//
360// # Arguments
361// * `node_id` - The node to expand
362// * `arena` - The arena containing all nodes
363// * `context` - The search context containing the response generator
364//
365// # Returns
366// Some(child_node_id) if expansion was successful, None otherwise
367fn expand<G, R>(node_id: NodeId, arena: &mut Arena<Node<G::State>>, context: &Context<'_, G, R>) -> Option<NodeId>
368where
369    G: ResponseGenerator,
370    R: Rollout<State = G::State>,
371{
372    // Get the next untried action, or return None if there are no untried actions
373    let action = arena[node_id].get_mut().untried_actions.pop()?;
374
375    // Apply the action to the node's state to get a new state and create a new child node
376    let child_state = arena[node_id].get().state.apply(&action);
377
378    // Create the new child node and add it to the arena
379    let child_node = Node::new(child_state, Some(action), context.response_generator);
380    let child_id = arena.new_node(child_node);
381
382    // Add the new child to the parent node using indextree's append
383    node_id.append(child_id, arena);
384
385    // Return the new child node ID
386    Some(child_id)
387}
388
389// Performs a rollout (simulation) from the given node and returns the evaluated value
390//
391// In the current implementation, rollout simply evaluates the current node's state using the static evaluator rather than
392// performing random simulation.
393//
394// # Arguments
395// * `node_id` - The node to evaluate
396// * `arena` - The arena containing all nodes
397// * `context` - The search context containing the rollout implementation
398//
399// # Returns
400// [-1.0, 1.0] as the evaluation score for the node's game state
401fn rollout<G, R>(node_id: NodeId, arena: &Arena<Node<G::State>>, context: &Context<'_, G, R>) -> f32
402where
403    G: ResponseGenerator,
404    R: Rollout<State = G::State, ResponseGenerator = G>,
405{
406    context.rollout.play(&arena[node_id].get().state, context.response_generator)
407}
408
409// Back-propagates the value up the tree
410//
411// Updates the visit count and value sum for the given node and all of its ancestors up to the root.
412//
413// # Arguments
414// * `node_id` - The starting node for back-propagation
415// * `arena` - The arena containing all nodes
416// * `value` - The value to propagate up the tree
417fn back_propagate<S>(node_id: NodeId, arena: &mut Arena<Node<S>>, value: f32)
418where
419    S: State,
420{
421    let mut current = Some(node_id);
422    while let Some(id) = current {
423        let node = arena[id].get_mut();
424        node.visits += 1;
425        node.value_sum += value;
426        current = arena[id].parent();
427    }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    // Counter game used by the tests. Player 0 always moves, and each action adds its increment to `value`. Applying an
    // action marks the state terminal once `value` exceeds 10.
    #[derive(Debug, Clone, Default, PartialEq)]
    struct TestGameState {
        value: i32,
        terminal: bool,
    }

    impl State for TestGameState {
        type Action = TestAction;

        fn fingerprint(&self) -> u64 {
            self.value as u64
        }

        fn whose_turn(&self) -> u8 {
            0
        }

        fn is_terminal(&self) -> bool {
            self.terminal
        }

        fn apply(&self, action: &TestAction) -> Self {
            Self {
                value: self.value + action.increment,
                terminal: self.value + action.increment > 10,
            }
        }
    }

    #[derive(Debug, Clone, Default)]
    struct TestAction {
        increment: i32,
    }

    impl TestAction {
        fn new(increment: i32) -> Self {
            Self { increment }
        }
    }

    // Offers +1 and +2 from any non-terminal state
    struct TestResponseGenerator;

    impl ResponseGenerator for TestResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal {
                vec![]
            } else {
                vec![TestAction::new(1), TestAction::new(2)]
            }
        }
    }

    // Never offers any action
    struct EmptyResponseGenerator;

    impl ResponseGenerator for EmptyResponseGenerator {
        type State = TestGameState;

        fn generate(&self, _state: &TestGameState) -> Vec<TestAction> {
            vec![]
        }
    }

    // Offers only +1 from any non-terminal state
    struct SingleResponseGenerator;

    impl ResponseGenerator for SingleResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal { vec![] } else { vec![TestAction::new(1)] }
        }
    }

    // Offers fewer actions as the value grows: three below 3, two below 6, and one otherwise
    struct VariableResponseGenerator;

    impl ResponseGenerator for VariableResponseGenerator {
        type State = TestGameState;

        fn generate(&self, state: &TestGameState) -> Vec<TestAction> {
            if state.terminal {
                vec![]
            } else if state.value < 3 {
                vec![TestAction::new(1), TestAction::new(2), TestAction::new(3)]
            } else if state.value < 6 {
                vec![TestAction::new(1), TestAction::new(2)]
            } else {
                vec![TestAction::new(1)]
            }
        }
    }

    // Evaluates every state as the same fixed value
    struct TestRollout;

    impl Rollout for TestRollout {
        type State = TestGameState;
        type ResponseGenerator = TestResponseGenerator;

        fn play(&self, _state: &TestGameState, _rg: &TestResponseGenerator) -> f32 {
            0.5 // Simple fixed rollout value for testing
        }
    }

    // Builds a root node (already counted as visited once) with a single, unvisited child reached by a +1 action
    fn root_with_child() -> (Arena<Node<TestGameState>>, NodeId, NodeId) {
        let generator = TestResponseGenerator;
        let mut arena = Arena::new();
        let root = arena.new_node(Node::new(TestGameState::default(), None, &generator));
        arena[root].get_mut().visits = 1;
        let child_state = TestGameState::default().apply(&TestAction::new(1));
        let child = arena.new_node(Node::new(child_state, Some(TestAction::new(1)), &generator));
        root.append(child, &mut arena);
        (arena, root, child)
    }

    // Tests for ResponseGenerator trait
    #[test]
    fn test_mcts_response_generator_basic() {
        let generator = TestResponseGenerator;
        let state = TestGameState {
            value: 5,
            terminal: false,
        };
        let terminal_state = TestGameState {
            value: 15,
            terminal: true,
        };

        let actions = generator.generate(&state);
        assert_eq!(actions.len(), 2);
        assert_eq!(actions[0].increment, 1);
        assert_eq!(actions[1].increment, 2);

        let terminal_actions = generator.generate(&terminal_state);
        assert!(terminal_actions.is_empty());
    }

    #[test]
    fn test_mcts_response_generator_empty() {
        let generator = EmptyResponseGenerator;
        let state = TestGameState {
            value: 0,
            terminal: false,
        };

        let actions = generator.generate(&state);
        assert!(actions.is_empty());
    }

    #[test]
    fn test_mcts_response_generator_single() {
        let generator = SingleResponseGenerator;
        let state = TestGameState {
            value: 3,
            terminal: false,
        };

        let actions = generator.generate(&state);
        assert_eq!(actions.len(), 1);
        assert_eq!(actions[0].increment, 1);
    }

    #[test]
    fn test_mcts_response_generator_variable() {
        let generator = VariableResponseGenerator;

        // Low value state should have 3 actions
        let low_state = TestGameState {
            value: 1,
            terminal: false,
        };
        let actions = generator.generate(&low_state);
        assert_eq!(actions.len(), 3);

        // Medium value state should have 2 actions
        let med_state = TestGameState {
            value: 4,
            terminal: false,
        };
        let actions = generator.generate(&med_state);
        assert_eq!(actions.len(), 2);

        // High value state should have 1 action
        let high_state = TestGameState {
            value: 7,
            terminal: false,
        };
        let actions = generator.generate(&high_state);
        assert_eq!(actions.len(), 1);

        // Terminal state should have no actions
        let terminal_state = TestGameState {
            value: 15,
            terminal: true,
        };
        let actions = generator.generate(&terminal_state);
        assert!(actions.is_empty());
    }

    // Tests for Node
    #[test]
    fn test_mcts_node_new_and_fully_expanded() {
        let generator = TestResponseGenerator;
        let node = Node::new(TestGameState::default(), None, &generator);
        assert_eq!(node.visits, 0);
        assert_eq!(node.value_sum, 0.0);
        assert_eq!(node.untried_actions.len(), 2);
        assert!(!node.fully_expanded());

        let terminal = TestGameState {
            value: 15,
            terminal: true,
        };
        assert!(Node::new(terminal, None, &generator).fully_expanded());
    }

    #[test]
    fn test_mcts_uct_values() {
        let (mut arena, _root, child) = root_with_child();

        // An unvisited child must always be tried before its visited siblings
        assert_eq!(arena[child].get().uct(child, &arena, 1.0), f32::INFINITY);

        // After one visit worth 0.75, the root has 2 visits, so UCT = 0.75 + 1.0 * sqrt(ln(2) / 1)
        back_propagate(child, &mut arena, 0.75);
        let expected = 0.75 + 2.0f32.ln().sqrt();
        assert!((arena[child].get().uct(child, &arena, 1.0) - expected).abs() < 1e-6);
    }

    // Tests for the individual MCTS phases
    #[test]
    fn test_mcts_expand_adds_children_until_fully_expanded() {
        let generator = TestResponseGenerator;
        let rollout = TestRollout;
        let context = Context {
            response_generator: &generator,
            rollout: &rollout,
            c: DEFAULT_EXPLORATION_CONSTANT,
        };
        let mut arena = Arena::new();
        let root = arena.new_node(Node::new(TestGameState::default(), None, &generator));

        // Untried actions are taken from the end of the generated list, so +2 is expanded first
        let first = expand(root, &mut arena, &context).expect("root has untried actions");
        assert_eq!(arena[first].get().state.value, 2);
        assert_eq!(arena[first].parent(), Some(root));

        let second = expand(root, &mut arena, &context).expect("root still has one untried action");
        assert_eq!(arena[second].get().state.value, 1);

        assert!(arena[root].get().fully_expanded());
        assert!(expand(root, &mut arena, &context).is_none());
    }

    #[test]
    fn test_mcts_back_propagate_updates_ancestors() {
        let (mut arena, root, child) = root_with_child();

        back_propagate(child, &mut arena, 0.75);

        assert_eq!(arena[child].get().visits, 1);
        assert_eq!(arena[child].get().value_sum, 0.75);
        assert_eq!(arena[root].get().visits, 2);
        assert_eq!(arena[root].get().value_sum, 0.75);
    }

    // Tests for search
    #[test]
    fn test_mcts_search_basic() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // A single iteration expands exactly one child: the last generated action (+2)
        let result = search(&state, &generator, &rollout, 1.0, 1);
        assert_eq!(result.map(|action| action.increment), Some(2));
    }

    #[test]
    fn test_mcts_search_terminal_state() {
        let state = TestGameState {
            value: 0,
            terminal: true,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Terminal state should return None (no actions)
        let result = search(&state, &generator, &rollout, 1.0, 10);
        assert!(result.is_none());
    }

    #[test]
    fn test_mcts_search_zero_iterations() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // With no iterations no child of the root is ever expanded, so there is nothing to choose from
        let result = search(&state, &generator, &rollout, 1.0, 0);
        assert!(result.is_none());
    }

    #[test]
    fn test_mcts_search_different_c_values() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // Both small and large exploration constants must still produce a move
        let result1 = search(&state, &generator, &rollout, 0.1, 5);
        let result2 = search(&state, &generator, &rollout, 2.0, 5);

        assert!(result1.is_some());
        assert!(result2.is_some());
    }

    #[test]
    fn test_mcts_search_multiple_iterations() {
        let state = TestGameState {
            value: 0,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // A non-terminal root with actions must yield an action
        let result = search(&state, &generator, &rollout, 1.4, 50);
        assert!(result.is_some());
    }

    #[test]
    fn test_mcts_search_consistency() {
        let state = TestGameState {
            value: 5,
            terminal: false,
        };
        let generator = TestResponseGenerator;
        let rollout = TestRollout;

        // The search itself uses no randomness, so identical inputs must produce identical choices
        let result1 = search(&state, &generator, &rollout, 1.0, 10);
        let result2 = search(&state, &generator, &rollout, 1.0, 10);

        assert!(result1.is_some());
        assert!(result2.is_some());
        assert_eq!(
            result1.map(|action| action.increment),
            result2.map(|action| action.increment)
        );
    }
}