1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use crate::ai::TrainAgentState;
use crate::simulation::agents::RailMovableAction;
use rurel::mdp::{Agent, State};

/// Reinforcement-learning agent that drives a single train through the simulation.
///
/// The agent owns its [`TrainAgentState`] and mutates it in response to
/// [`RailMovableAction`]s through its [`Agent`] implementation.
#[derive(Clone, Debug, Default)]
pub struct TrainAgentRL {
    /// Current observable state of the train.
    pub state: TrainAgentState,
    /// Reference top speed of the train in millimetres per second (mm/s).
    ///
    /// Serves as the denominator when the agent reports its speed as a percentage of
    /// this value. It is not enforced as a hard limit, so the actual speed may exceed
    /// it. The derived `Default` leaves it at `0`.
    pub max_speed_mm_s: i32,
}

impl TrainAgentRL {
    /// Length of one simulation step in milliseconds; each action is applied over this span.
    const TIME_DELTA_MS: u32 = 1_000;
}

impl Agent<TrainAgentState> for TrainAgentRL {
    /// Returns the train's current state.
    fn current_state(&self) -> &TrainAgentState {
        &self.state
    }

    /// Applies `action` over one simulation step of `TIME_DELTA_MS` milliseconds.
    ///
    /// * `Stop` sets the speed to zero.
    /// * `AccelerateForward` / `AccelerateBackward` add / subtract
    ///   `acceleration * TIME_DELTA_MS / 1000` to / from the speed.
    ///
    /// After every action:
    /// * `delta_distance_mm` is recomputed as the distance covered during the step at the new
    ///   speed.
    /// * `max_speed_percentage` is recomputed as the absolute speed, expressed as a truncated
    ///   whole-number percentage of `max_speed_mm_s`.
    fn take_action(&mut self, action: &RailMovableAction) {
        // Step length as i32 so it combines with the speed/acceleration arithmetic.
        let dt_ms = Self::TIME_DELTA_MS as i32;

        match action {
            RailMovableAction::Stop => {
                self.state.current_speed_mm_s = 0;
            }
            RailMovableAction::AccelerateForward { acceleration } => {
                self.state.current_speed_mm_s += acceleration * dt_ms / 1000;
            }
            RailMovableAction::AccelerateBackward { acceleration } => {
                self.state.current_speed_mm_s -= acceleration * dt_ms / 1000;
            }
        }

        // Recomputed for every action, including `Stop`: a stopped train covers no
        // distance, so the previous step's value must not linger in the state.
        self.state.delta_distance_mm = self.state.current_speed_mm_s * dt_ms / 1000;

        // Take the absolute value in floating point *before* casting.
        // `f64 as i32` saturates: +inf -> i32::MAX, -inf -> i32::MIN, NaN -> 0.
        // With `max_speed_mm_s == 0` and a negative speed, casting first yields i32::MIN,
        // and `i32::abs` on that overflows (a panic in debug builds).
        let ratio = self.state.current_speed_mm_s as f64 / self.max_speed_mm_s as f64;
        self.state.max_speed_percentage = (ratio * 100.0).abs() as i32;
    }

    /// Picks a random action for the current state, applies it to the agent and returns it.
    fn pick_random_action(&mut self) -> <TrainAgentState as State>::A {
        let action = self.current_state().random_action();
        self.take_action(&action);
        action
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an agent at rest whose reference top speed is `max_speed_mm_s`.
    fn agent_at_rest(max_speed_mm_s: i32) -> TrainAgentRL {
        TrainAgentRL {
            state: TrainAgentState {
                delta_distance_mm: 0,
                current_speed_mm_s: 0,
                max_speed_percentage: 0,
            },
            max_speed_mm_s,
        }
    }

    #[test]
    fn test_take_action() {
        let mut agent = TrainAgentRL {
            state: TrainAgentState {
                delta_distance_mm: 1000,
                current_speed_mm_s: 0,
                max_speed_percentage: 0,
            },
            max_speed_mm_s: (1000.0 * 160.0 / 3.6) as i32,
        };

        // Test AccelerateForward action
        agent.take_action(&RailMovableAction::AccelerateForward { acceleration: 1000 });
        assert_eq!(agent.state.current_speed_mm_s, 1000);
        assert_eq!(agent.state.delta_distance_mm, 1000);

        // Test AccelerateForward action
        agent.take_action(&RailMovableAction::AccelerateForward { acceleration: 500 });
        assert_eq!(agent.state.current_speed_mm_s, 1500);
        assert_eq!(agent.state.delta_distance_mm, 1500);

        // Test AccelerateBackward action
        agent.take_action(&RailMovableAction::AccelerateBackward { acceleration: 500 });
        assert_eq!(agent.state.current_speed_mm_s, 1000);
        assert_eq!(agent.state.delta_distance_mm, 1000);

        // Test Stop action
        agent.take_action(&RailMovableAction::Stop);
        assert_eq!(agent.state.current_speed_mm_s, 0);
    }

    #[test]
    fn stop_resets_delta_distance_and_percentage() {
        let mut agent = agent_at_rest(10_000);
        agent.take_action(&RailMovableAction::AccelerateForward { acceleration: 5000 });
        assert_eq!(agent.state.delta_distance_mm, 5000);
        assert_eq!(agent.state.max_speed_percentage, 50);

        agent.take_action(&RailMovableAction::Stop);
        assert_eq!(agent.state.current_speed_mm_s, 0);
        assert_eq!(agent.state.delta_distance_mm, 0);
        assert_eq!(agent.state.max_speed_percentage, 0);
    }

    #[test]
    fn reversing_reports_absolute_percentage() {
        let mut agent = agent_at_rest(10_000);
        agent.take_action(&RailMovableAction::AccelerateBackward { acceleration: 2500 });
        assert_eq!(agent.state.current_speed_mm_s, -2500);
        assert_eq!(agent.state.delta_distance_mm, -2500);
        assert_eq!(agent.state.max_speed_percentage, 25);
    }

    #[test]
    fn reversing_with_zero_max_speed_does_not_panic() {
        let mut agent = agent_at_rest(0);
        agent.take_action(&RailMovableAction::AccelerateBackward { acceleration: 500 });
        assert_eq!(agent.state.current_speed_mm_s, -500);
        assert_eq!(agent.state.delta_distance_mm, -500);
        assert_eq!(agent.state.max_speed_percentage, i32::MAX);
    }
}