1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
//! This module contains a factory for creating decision agents.
//!
//! The factory provides a list of options for different decision agent implementations,
//! such as TrainAgentAI and ForwardUntilTargetAgent.

use super::DecisionAgent;
use super::ForwardUntilTargetAgent;
use super::RailMovableAction;
#[cfg(feature = "ai")]
use crate::ai::{TrainAgentAI, TrainAgentState};
use crate::simulation::environment::ObservableEnvironment;
use crate::simulation::SimulationEnvironment;
use crate::types::RailwayObjectId;

/// The decision agent implementations that `DecisionAgentFactory` can build.
///
/// Every variant is a plain tag without data, so the type is `Copy` and can be compared,
/// hashed, and passed around by value. The default option is `ForwardUntilTargetAgent`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DecisionAgentOption {
    /// Selects the `TrainAgentAI` implementation; only present when the "ai" feature is enabled.
    #[cfg(feature = "ai")]
    TrainAgentAI,

    /// Selects the `ForwardUntilTargetAgent` implementation (the default).
    #[default]
    ForwardUntilTargetAgent,
}

/// A factory for creating decision agents.
///
/// The type holds no data; it only serves as a namespace for the associated
/// constructor function(s) defined in its `impl` block.
#[derive(Debug, Clone, Copy, Default)]
pub struct DecisionAgentFactory;

impl DecisionAgentFactory {
    /// Create a decision agent based on the provided option.
    ///
    /// * `DecisionAgentOption::ForwardUntilTargetAgent` builds a `ForwardUntilTargetAgent`
    ///   from `id`.
    /// * `DecisionAgentOption::TrainAgentAI` (only with the "ai" feature) builds a
    ///   `TrainAgentAI` from a clone of the environment's graph and a default
    ///   `TrainAgentState`, then calls `train(10_000)` on it before returning it.
    ///
    /// # Arguments
    ///
    /// * `option` - The selected decision agent implementation.
    /// * `id` - The RailwayObjectId for the agent. Only the `ForwardUntilTargetAgent`
    ///   variant uses it.
    /// * `environment` - A reference to a SimulationEnvironment. Only the `TrainAgentAI`
    ///   variant reads it (to obtain the graph).
    ///
    /// # Returns
    ///
    /// * A boxed `DecisionAgent` trait object with the specified implementation.
    pub fn create_decision_agent(
        option: DecisionAgentOption,
        id: RailwayObjectId,
        environment: &SimulationEnvironment,
    ) -> Box<dyn DecisionAgent<A = RailMovableAction>> {
        // Without the "ai" feature no match arm reads `environment`. Consume it explicitly
        // so that build configuration does not emit an `unused_variables` warning.
        #[cfg(not(feature = "ai"))]
        let _ = environment;

        match option {
            #[cfg(feature = "ai")]
            DecisionAgentOption::TrainAgentAI => {
                let mut train_agent_ai =
                    TrainAgentAI::new(environment.get_graph().clone(), TrainAgentState::default());
                // The agent is trained here, before it is handed to the caller.
                train_agent_ai.train(10_000);
                Box::new(train_agent_ai)
            }
            DecisionAgentOption::ForwardUntilTargetAgent => {
                Box::new(ForwardUntilTargetAgent::new(id))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::railway_model::RailwayGraph;
    use std::collections::HashMap;

    /// Builds a `SimulationEnvironment` with a default graph and no objects.
    fn empty_environment() -> SimulationEnvironment {
        SimulationEnvironment {
            graph: RailwayGraph::default(),
            objects: HashMap::new(),
        }
    }

    /// The factory must hand back the concrete agent type matching the requested option.
    #[test]
    fn test_create_decision_agent() {
        let env = empty_environment();
        let agent_id = 1;

        // Requesting the forward agent yields a `ForwardUntilTargetAgent`.
        let forward_agent = DecisionAgentFactory::create_decision_agent(
            DecisionAgentOption::ForwardUntilTargetAgent,
            agent_id,
            &env,
        );
        assert!(forward_agent.as_any().is::<ForwardUntilTargetAgent>());

        // The AI-backed agent only exists when the "ai" feature is compiled in.
        #[cfg(feature = "ai")]
        {
            let ai_agent = DecisionAgentFactory::create_decision_agent(
                DecisionAgentOption::TrainAgentAI,
                agent_id,
                &env,
            );
            assert!(ai_agent.as_any().is::<TrainAgentAI>());
        }
    }
}