1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
//! Model Processor module for text generation.
//!
//! This module contains the `ModelProcessor` trait and its implementations
//! which are used for processing input tensors and generating output tensors
//! representing logits from a language model.

use super::Model;
use candle_core::{Result, Tensor};

/// A trait for processing model inputs and generating outputs.
///
/// This trait defines a single method, [`ModelProcessor::forward`], which takes
/// an input tensor and produces an output tensor. The method takes `&mut self`,
/// so an implementation is allowed to update its own state on each call.
pub trait ModelProcessor {
    /// Processes an input tensor and generates an output tensor.
    ///
    /// # Arguments
    ///
    /// * `x` - A reference to the input tensor.
    /// * `index_pos` - The position index for processing. Not every
    ///   implementation uses this value; some ignore it.
    ///
    /// # Returns
    ///
    /// Returns a `Result` containing the output tensor.
    ///
    /// # Errors
    ///
    /// Returns the `Err` variant when the implementation cannot produce the
    /// output tensor.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor>;
}

impl ModelProcessor for Model {
    /// Dispatches the forward pass to whichever concrete model this enum holds.
    ///
    /// `index_pos` is passed through to the `Llama` variant only; the
    /// `MixFormer` variant's `forward` is called with the input tensor alone.
    fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        match self {
            // MixFormer's forward takes no position argument.
            Model::MixFormer(mixformer) => mixformer.forward(x),
            Model::Llama(llama) => llama.forward(x, index_pos),
        }
    }
}

/// A dummy implementation of `ModelProcessor` for testing purposes.
///
/// This processor simulates model outputs by returning incrementing tensors:
/// every call to `forward` advances an internal counter by one.
///
/// `Debug` is derived so the type can be printed in assertions and logs.
#[derive(Debug)]
pub struct DummyModelProcessor {
    /// Number of `forward` calls made so far.
    index: usize,
}

impl DummyModelProcessor {
    /// Creates a new `DummyModelProcessor`.
    pub fn new() -> Self {
        Self { index: 0 }
    }
}

/// Provides a default instance of `DummyModelProcessor`.
impl Default for DummyModelProcessor {
    fn default() -> Self {
        Self::new()
    }
}

/// `ModelProcessor` implementation backed by a simple call counter.
impl ModelProcessor for DummyModelProcessor {
    /// Returns a one-element `f32` tensor on the same device as `x`.
    ///
    /// The counter is advanced first and the emitted value is the counter
    /// minus one, so successive calls on a fresh processor yield
    /// `0.0, 1.0, 2.0, ...`. The contents of `x` and the value of
    /// `_index_pos` are not used.
    fn forward(&mut self, x: &Tensor, _index_pos: usize) -> Result<Tensor> {
        self.index += 1;
        let value = self.index as f32 - 1.0;
        Tensor::new(&[value], x.device())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Checks that ten consecutive `forward` calls on a fresh
    /// `DummyModelProcessor` produce the values `0.0` through `9.0` in order.
    #[test]
    fn test_dummy_model_processor() {
        let input = Tensor::new(&[0.0], &Device::Cpu).unwrap();
        let mut processor = DummyModelProcessor::new();

        for step in 0..10 {
            let output = processor.forward(&input, step).unwrap();
            let values = output.to_vec1::<f32>().unwrap();
            assert_eq!(values, vec![step as f32]);
        }
    }
}