1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
//! This module contains a factory for creating decision agents.
//!
//! [`DecisionAgentOption`] lists the available decision agent implementations:
//! `TrainAgentAI` (only available with the `ai` feature) and `ForwardUntilTargetAgent`.
//! [`DecisionAgentFactory`] builds a boxed agent from a chosen option.
use super::DecisionAgent;
use super::ForwardUntilTargetAgent;
use super::RailMovableAction;
#[cfg(feature = "ai")]
use crate::ai::{TrainAgentAI, TrainAgentState};
use crate::simulation::environment::ObservableEnvironment;
use crate::simulation::SimulationEnvironment;
use crate::types::RailwayObjectId;
/// An enumeration of available decision agent implementations.
///
/// This is a plain fieldless enum, so it is cheap to copy and can be compared
/// or used as a map/set key.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecisionAgentOption {
    /// `TrainAgentAI` implementation; only exists when the `ai` feature is enabled.
    #[cfg(feature = "ai")]
    TrainAgentAI,
    /// `ForwardUntilTargetAgent` implementation. This is the default option.
    #[default]
    ForwardUntilTargetAgent,
}
/// A stateless factory for creating decision agents from a [`DecisionAgentOption`].
pub struct DecisionAgentFactory;

impl DecisionAgentFactory {
    /// Create a decision agent based on the provided option.
    ///
    /// # Arguments
    ///
    /// * `option` - The selected decision agent implementation.
    /// * `id` - The `RailwayObjectId` for the agent.
    /// * `environment` - A reference to a `SimulationEnvironment`. It is only read
    ///   by the `TrainAgentAI` option, which clones the environment's graph.
    ///
    /// # Returns
    ///
    /// * A boxed `DecisionAgent` trait object with the specified implementation.
    ///
    /// With the `TrainAgentAI` option, the agent is trained (`train(10000)`)
    /// inside this call, before it is returned.
    pub fn create_decision_agent(
        option: DecisionAgentOption,
        id: RailwayObjectId,
        environment: &SimulationEnvironment,
    ) -> Box<dyn DecisionAgent<A = RailMovableAction>> {
        // Without the "ai" feature the only arm that reads `environment` is
        // compiled out; mark it as used so the build stays warning-free.
        #[cfg(not(feature = "ai"))]
        let _ = environment;

        match option {
            #[cfg(feature = "ai")]
            DecisionAgentOption::TrainAgentAI => {
                // NOTE(review): `id` is not passed to `TrainAgentAI::new`;
                // confirm the AI agent does not need its object id.
                let mut train_agent_ai =
                    TrainAgentAI::new(environment.get_graph().clone(), TrainAgentState::default());
                // Training happens synchronously here, before the agent is handed out.
                train_agent_ai.train(10000);
                Box::new(train_agent_ai)
            }
            DecisionAgentOption::ForwardUntilTargetAgent => {
                Box::new(ForwardUntilTargetAgent::new(id))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::railway_model::RailwayGraph;
    use std::collections::HashMap;

    /// Builds an environment with a default (empty) graph and no objects.
    fn empty_environment() -> SimulationEnvironment {
        SimulationEnvironment {
            graph: RailwayGraph::default(),
            objects: HashMap::new(),
        }
    }

    #[test]
    fn default_option_is_forward_until_target_agent() {
        assert_eq!(
            DecisionAgentOption::default(),
            DecisionAgentOption::ForwardUntilTargetAgent
        );
    }

    #[test]
    fn creates_forward_until_target_agent() {
        let environment = empty_environment();
        let id = 1;
        let agent = DecisionAgentFactory::create_decision_agent(
            DecisionAgentOption::ForwardUntilTargetAgent,
            id,
            &environment,
        );
        assert!(agent.as_any().is::<ForwardUntilTargetAgent>());
    }

    /// Only compiled when the "ai" feature is enabled; this runs the full
    /// training performed by the factory, so it is slower than the other tests.
    #[cfg(feature = "ai")]
    #[test]
    fn creates_train_agent_ai() {
        let environment = empty_environment();
        let id = 1;
        let agent = DecisionAgentFactory::create_decision_agent(
            DecisionAgentOption::TrainAgentAI,
            id,
            &environment,
        );
        assert!(agent.as_any().is::<TrainAgentAI>());
    }
}