//! The `ai` module provides a reinforcement learning agent that controls a train in a railway network simulation.
//!
//! The module defines the `TrainAgentState` struct, which represents the state of a train agent in the simulation,
//! and the `TrainAgentAction` enum, which represents the possible actions a train agent can take in the simulation.
//!
//! The main entry point is the `TrainAgentAI` struct, which pairs a `TrainAgentRL` agent
//! (the `rurel::mdp::Agent` implementation) with a shared `AgentTrainer`. It drives the
//! reinforcement learning process, updating the train agent's state based on the actions
//! taken and the rewards received.
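//!
//! A minimal usage sketch (`graph` is assumed to be a `RailwayGraph` built elsewhere, and
//! the node ids are placeholders):
//!
//! ```ignore
//! let mut ai = TrainAgentAI::new(graph, TrainAgentState::default());
//! // Tell the agent where it is and where it is heading before training.
//! ai.observe(source_node, Some(target_node), Some(1000), Some(1000));
//! ai.train(10_000);
//! let action = ai.best_action(&ai.agent_rl.state);
//! ```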
//!
use std::any::Any;
use std::fmt;
use std::sync::{Arc, RwLock};

use crate::prelude::RailwayGraph;
use crate::simulation::agents::{DecisionAgent, RailMovableAction};
use crate::simulation::environment::ObservableEnvironment;
use crate::simulation::SimulationEnvironment;
use crate::types::{NodeId, RailwayObjectId};
use rurel::strategy::explore::RandomExploration;
use rurel::strategy::learn::QLearning;
use rurel::strategy::terminate::FixedIterations;
use rurel::AgentTrainer;
use uom::si::velocity::millimeter_per_second;

mod train_agent_state;
pub use train_agent_state::TrainAgentState;

mod train_agent_rl;
pub use train_agent_rl::TrainAgentRL;

/// A reinforcement learning agent that controls a train in the simulation.
#[derive(Default, Clone)]
pub struct TrainAgentAI {
    /// The id of the railway object.
    pub id: RailwayObjectId,
    /// The railway graph representing the train network.
    pub railway_graph: Option<RailwayGraph>,
    /// The current node.
    pub current_node: Option<NodeId>,
    /// The target node.
    pub target_node: Option<NodeId>,
    /// The reinforcement learning agent responsible for controlling the train.
    pub agent_rl: TrainAgentRL,
    /// The trainer responsible for training the reinforcement learning agent.
    pub trainer: Arc<RwLock<AgentTrainer<TrainAgentState>>>,
}

impl fmt::Debug for TrainAgentAI {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TrainAgentAI")
            .field("railway_graph", &self.railway_graph)
            .field("agent_rl", &self.agent_rl)
            // `AgentTrainer` has non-`Debug` fields, so we display the `Arc`'s pointer
            // value via `Arc::as_ptr()` instead.
            .field("trainer", &format_args!("{:p}", Arc::as_ptr(&self.trainer)))
            .finish()
    }
}

impl TrainAgentAI {
    /// Creates a new `TrainAgentAI` with the given railway graph and initial state.
    ///
    /// # Arguments
    ///
    /// * `railway_graph` - The railway graph representing the train network.
    /// * `initial_state` - The initial state of the train agent in the simulation.
    ///
    /// # Returns
    ///
    /// A new `TrainAgentAI` instance.
    pub fn new(railway_graph: RailwayGraph, initial_state: TrainAgentState) -> Self {
        let agent_rl = TrainAgentRL {
            state: initial_state,
            // 160 km/h expressed in mm/s; converting before the integer cast avoids
            // truncating the fractional part of 160 / 3.6.
            max_speed_mm_s: (160.0 / 3.6 * 1000.0) as i32,
        };
        let trainer = Arc::new(RwLock::new(AgentTrainer::new()));
        Self {
            id: 0,
            railway_graph: Some(railway_graph),
            current_node: None,
            target_node: None,
            agent_rl,
            trainer,
        }
    }

    /// Trains the reinforcement learning agent for the specified number of iterations.
    ///
    /// # Arguments
    ///
    /// * `iterations` - The number of iterations to train the agent.
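    ///
    /// Training runs `QLearning` with a random exploration strategy on a clone of the
    /// internal agent; only the values accumulated in the shared trainer are kept.
    ///
    /// # Example
    ///
    /// A minimal sketch (assumes `ai` is a `TrainAgentAI` that has already observed its
    /// position; see [`TrainAgentAI::observe`]):
    ///
    /// ```ignore
    /// ai.train(10_000);
    /// ```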
    pub fn train(&mut self, iterations: usize) {
        println!("Starting training for {} iterations...", iterations);
        // Train on a clone: only the learned values stored in the trainer persist.
        let mut agent = self.agent_rl.clone();
        let mut trainer = self.trainer.write().unwrap();
        trainer.train(
            &mut agent,
            // Hyperparameters: learning rate 0.2, discount factor 0.01, initial value 20.
            &QLearning::new(0.2, 0.01, 20.),
            &mut FixedIterations::new(iterations as u32),
            &RandomExploration::new(),
        );
    }

    /// Returns the best action for the given state according to the trained reinforcement
    /// learning agent, falling back to the default action when the trainer has no entry
    /// for the state.
    ///
    /// # Arguments
    ///
    /// * `state` - The current state of the train agent in the simulation.
    ///
    /// # Returns
    ///
    /// The best action for the given state.
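    ///
    /// # Example
    ///
    /// A minimal sketch (assumes a trained `ai: TrainAgentAI`; the state values are
    /// placeholders):
    ///
    /// ```ignore
    /// let state = TrainAgentState {
    ///     delta_distance_mm: 1000,
    ///     current_speed_mm_s: 1000,
    ///     max_speed_percentage: 20,
    /// };
    /// let action = ai.best_action(&state);
    /// ```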
    pub fn best_action(&self, state: &TrainAgentState) -> Option<RailMovableAction> {
        Some(
            self.trainer
                .read()
                .unwrap()
                .best_action(state)
                .unwrap_or_default(),
        )
    }

    /// Records the train's current position and optional target, and updates the agent's
    /// state with the observed speed and distance values, leaving any value that is `None`
    /// unchanged.
    ///
    /// # Arguments
    ///
    /// * `current_node` - The node the train agent is currently at.
    /// * `target_node` - The optional target node the train agent is heading to.
    /// * `speed_mm_s` - The observed speed in millimeters per second, if any.
    /// * `delta_distance_mm` - The distance travelled since the last update, in millimeters, if any.
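    ///
    /// # Example
    ///
    /// A minimal sketch (node ids are placeholders):
    ///
    /// ```ignore
    /// // At `source_node`, heading for `target_node`, moving at 1000 mm/s,
    /// // 1000 mm covered since the last update.
    /// ai.observe(source_node, Some(target_node), Some(1000), Some(1000));
    /// ```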
    pub fn observe(
        &mut self,
        current_node: NodeId,
        target_node: Option<NodeId>,
        speed_mm_s: Option<i32>,
        delta_distance_mm: Option<i32>,
    ) {
        self.current_node = Some(current_node);
        self.target_node = target_node;
        let mut agent_state = self.agent_rl.state.clone();
        if let Some(speed) = speed_mm_s {
            agent_state.current_speed_mm_s = speed;
            agent_state.max_speed_percentage = 100 * speed / self.agent_rl.max_speed_mm_s;
        }
        if let Some(delta_distance_mm) = delta_distance_mm {
            agent_state.delta_distance_mm = delta_distance_mm;
        }
        self.agent_rl.state = agent_state;
    }
}

impl DecisionAgent for TrainAgentAI {
    type A = RailMovableAction;

    fn next_action(&self, _delta_time: Option<std::time::Duration>) -> Self::A {
        self.best_action(&self.agent_rl.state).unwrap_or_default()
    }

    fn observe(&mut self, environment: &SimulationEnvironment) {
        if let Some(object) = environment.get_objects().iter().find(|o| o.id() == self.id) {
            self.current_node = object.position();
            self.target_node = object.next_target();
            let mut agent_state = self.agent_rl.state.clone();
            // Convert the object's `uom` speed quantity into mm/s for the agent state.
            let speed = object.speed();
            agent_state.current_speed_mm_s = speed.get::<millimeter_per_second>() as i32;
            agent_state.max_speed_percentage = (100.0 * speed.get::<millimeter_per_second>()
                / self.agent_rl.max_speed_mm_s as f64)
                as i32;
            self.agent_rl.state = agent_state;
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::test_graph_vilbel;

    #[test]
    fn test_train_agent_ai() {
        let graph = test_graph_vilbel();
        let mut train_agent_ai = TrainAgentAI::new(graph, Default::default());
        let source_node = 662529467;
        let target_node = 662529466;
        train_agent_ai.observe(source_node, Some(target_node), Some(1000), Some(1000));
        train_agent_ai.train(10000);
        let state = TrainAgentState {
            delta_distance_mm: 1000,
            current_speed_mm_s: 1000,
            max_speed_percentage: 20,
        };
        let action = train_agent_ai.best_action(&state);
        assert!(action.is_some());
    }
}