use crate::ai::TrainAgentState;
use crate::simulation::agents::RailMovableAction;
use rurel::mdp::{Agent, State};
#[derive(Default, Clone, Debug)]
pub struct TrainAgentRL {
pub state: TrainAgentState,
pub max_speed_mm_s: i32,
}
impl TrainAgentRL {
const TIME_DELTA_MS: u32 = 1000;
}
impl Agent<TrainAgentState> for TrainAgentRL {
fn current_state(&self) -> &TrainAgentState {
&self.state
}
fn take_action(&mut self, action: &RailMovableAction) {
match action {
RailMovableAction::Stop => {
self.state.current_speed_mm_s = 0;
}
RailMovableAction::AccelerateForward { acceleration } => {
self.state.current_speed_mm_s += acceleration * Self::TIME_DELTA_MS as i32 / 1000;
self.state.delta_distance_mm =
self.state.current_speed_mm_s * Self::TIME_DELTA_MS as i32 / 1000;
}
RailMovableAction::AccelerateBackward { acceleration } => {
self.state.current_speed_mm_s -= acceleration * Self::TIME_DELTA_MS as i32 / 1000;
self.state.delta_distance_mm =
self.state.current_speed_mm_s * Self::TIME_DELTA_MS as i32 / 1000;
}
}
self.state.max_speed_percentage =
(((self.state.current_speed_mm_s as f64 / self.max_speed_mm_s as f64) * 100.0) as i32)
.abs();
}
fn pick_random_action(&mut self) -> <TrainAgentState as State>::A {
let action = self.current_state().random_action();
self.take_action(&action);
action
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_take_action() {
let mut agent = TrainAgentRL {
state: TrainAgentState {
delta_distance_mm: 1000,
current_speed_mm_s: 0,
max_speed_percentage: 0,
},
max_speed_mm_s: (1000.0 * 160.0 / 3.6) as i32,
};
agent.take_action(&RailMovableAction::AccelerateForward { acceleration: 1000 });
assert_eq!(agent.state.current_speed_mm_s, 1000);
assert_eq!(agent.state.delta_distance_mm, 1000);
agent.take_action(&RailMovableAction::AccelerateForward { acceleration: 500 });
assert_eq!(agent.state.current_speed_mm_s, 1500);
assert_eq!(agent.state.delta_distance_mm, 1500);
agent.take_action(&RailMovableAction::AccelerateBackward { acceleration: 500 });
assert_eq!(agent.state.current_speed_mm_s, 1000);
assert_eq!(agent.state.delta_distance_mm, 1000);
agent.take_action(&RailMovableAction::Stop);
assert_eq!(agent.state.current_speed_mm_s, 0);
}
}