use crate::simulation::agents::RailMovableAction;
use rurel::mdp::State;
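/// MDP state of a train agent as used with `rurel`: the distance delta in
/// millimetres, the current speed in millimetres per second, and the reached
/// percentage of the maximum speed.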
#[derive(PartialEq, Eq, Hash, Clone, Debug, Default)]
pub struct TrainAgentState {
    pub delta_distance_mm: i32,
    pub current_speed_mm_s: i32,
    pub max_speed_percentage: i32,
}
impl TrainAgentState {
    /// Largest acceleration magnitude that is offered as an action.
    const MAX_ACCELERATION: i32 = 1000;
    /// Step between two consecutive acceleration actions.
    const ACCELERATION_STEP: i32 = 20;
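    /// Quadratic reward term: the reached percentage of the maximum speed,
    /// normalised to `0.0..=1.0` and squared.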
    fn speed_reward(&self) -> f64 {
        (self.max_speed_percentage as f64 / 100.0).powi(2)
    }
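    /// Linear reward term: the distance delta in millimetres.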
    fn distance_reward(&self) -> f64 {
        self.delta_distance_mm as f64
    }
}
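// `rurel` MDP `State` implementation: defines the reward signal and the set
// of actions available in this state.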
impl State for TrainAgentState {
    type A = RailMovableAction;
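    /// Combined reward: the squared speed percentage weighted by a factor of
    /// 20, plus the raw distance delta in millimetres.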
    fn reward(&self) -> f64 {
        20.0 * self.speed_reward() + self.distance_reward()
    }
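    /// Enumerates every available action: `Stop` plus forward and backward
    /// accelerations in steps of `ACCELERATION_STEP` up to `MAX_ACCELERATION`.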
    fn actions(&self) -> Vec<Self::A> {
        let mut actions = vec![Self::A::Stop];
        for acceleration in 1..=(Self::MAX_ACCELERATION / Self::ACCELERATION_STEP) {
            actions.push(Self::A::AccelerateForward {
                acceleration: acceleration * Self::ACCELERATION_STEP,
            });
            actions.push(Self::A::AccelerateBackward {
                acceleration: acceleration * Self::ACCELERATION_STEP,
            });
        }
        actions
    }
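    /// Picks one of the available actions at random.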
    fn random_action(&self) -> Self::A {
        let actions = self.actions();
        let a_t = rand::random::<usize>() % actions.len();
        actions[a_t].clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_train_agent_state() {
        let state = TrainAgentState {
            delta_distance_mm: 1000,
            current_speed_mm_s: 0,
            max_speed_percentage: 0,
        };
        assert_eq!(state.delta_distance_mm, 1000);
        assert_eq!(state.current_speed_mm_s, 0);
        assert_eq!(state.max_speed_percentage, 0);
    }
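    // A further check, sketched from the definitions above: `actions()`
    // yields `Stop` plus one forward and one backward acceleration for each
    // of the 1000 / 20 = 50 acceleration steps, and with
    // `max_speed_percentage == 0` the reward reduces to the distance delta.
    #[test]
    fn test_actions_and_reward() {
        let state = TrainAgentState {
            delta_distance_mm: 1000,
            current_speed_mm_s: 0,
            max_speed_percentage: 0,
        };
        let steps = TrainAgentState::MAX_ACCELERATION / TrainAgentState::ACCELERATION_STEP;
        assert_eq!(state.actions().len(), 1 + 2 * steps as usize);
        assert!((state.reward() - 1000.0).abs() < f64::EPSILON);
    }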
}