1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
use crate::simulation::agents::RailMovableAction;
use rurel::mdp::State;

/// Represents the state of a train agent in the simulation.
#[derive(PartialEq, Eq, Hash, Clone, Debug, Default)]
pub struct TrainAgentState {
    /// The remaining distance in millimeters the train agent needs to travel.
    pub delta_distance_mm: i32,
    /// The current speed of the train agent in millimeters per second (mm/s).
    pub current_speed_mm_s: i32,
    /// The maximum speed percentage the train agent can reach (e.g., 100 for 100% of the maximum speed).
    pub max_speed_percentage: i32,
}

impl TrainAgentState {
    const MAX_ACCELERATION: i32 = 1000; // 1000 mm/s², approximately 1 m/s²
    const ACCELERATION_STEP: i32 = 20;

    fn speed_reward(&self) -> f64 {
        (self.max_speed_percentage as f64 / 100.0).powi(2)
    }

    fn distance_reward(&self) -> f64 {
        self.delta_distance_mm as f64
    }
}

impl State for TrainAgentState {
    type A = RailMovableAction;

    fn reward(&self) -> f64 {
        20.0 * self.speed_reward() + self.distance_reward()
    }

    fn actions(&self) -> Vec<Self::A> {
        let mut actions = vec![Self::A::Stop];
        for acceleration in 1..=(Self::MAX_ACCELERATION / Self::ACCELERATION_STEP) {
            actions.push(Self::A::AccelerateForward {
                acceleration: acceleration * Self::ACCELERATION_STEP,
            });
            actions.push(Self::A::AccelerateBackward {
                acceleration: acceleration * Self::ACCELERATION_STEP,
            });
        }
        actions
    }

    fn random_action(&self) -> Self::A {
        let actions = self.actions();
        let a_t = rand::random::<usize>() % actions.len();
        actions[a_t].clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train_agent_state() {
        let state = TrainAgentState {
            delta_distance_mm: 1000,
            current_speed_mm_s: 0,
            max_speed_percentage: 0,
        };

        assert_eq!(state.delta_distance_mm, 1000);
        assert_eq!(state.current_speed_mm_s, 0);
        assert_eq!(state.max_speed_percentage, 0);
    }
}