1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
use super::{
    token_generator::{TokenGeneratorResult, TokenGeneratorTrait},
    FinishReason,
};
use anyhow::Result;
use candle_examples::token_output_stream::TokenOutputStream;
mod dummy_text_generator;

/// Represents the probability associated with a piece of generated text.
///
/// The `String` holds a decoded text fragment and the `f32` holds the
/// probability reported by the underlying token generator for the token(s)
/// that fragment decodes.
pub type TextProbability = (String, f32);

/// Enumerates possible results from a text generation process.
///
/// This enum is used to encapsulate the outcomes of text generation, including
/// both the generation of a new token and the conclusion of the generation process.
#[derive(Debug, PartialEq)]
pub enum TextGeneratorResult {
    /// Represents a generated piece of text along with its probability.
    ///
    /// The `String` is the generated text, and the `f32` is the probability associated with it.
    /// The text may be empty when the decoder cannot yet emit a complete piece
    /// of text for the current token.
    Token(TextProbability),

    /// Indicates the completion of the text generation process.
    ///
    /// This variant is used when the generation process reaches an end, either due to reaching
    /// a specified limit or encountering a stopping condition.
    Finish(FinishReason),
}

/// A trait defining the core functionality for text generation.
///
/// This trait encapsulates the necessary methods for initializing the generation process with a
/// prompt and then producing text iteratively. Implementations are expected to
/// have [`TextGeneratorTrait::init`] called once before the first call to
/// [`TextGeneratorTrait::next`].
pub trait TextGeneratorTrait {
    /// Initializes the text generation process with a given prompt.
    ///
    /// This method sets up the necessary state for text generation based on the provided prompt.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A `String` that serves as the starting point for text generation.
    ///
    /// # Returns
    ///
    /// A `Result` indicating success or failure of the initialization process.
    fn init(&mut self, prompt: String) -> Result<()>;

    /// Generates the next piece of text in the sequence.
    ///
    /// This method should be called iteratively to generate text progressively.
    /// It provides the next piece of text based on the current state of the generator.
    ///
    /// # Returns
    ///
    /// A `Result` wrapping a `TextGeneratorResult`, which can be either a generated token
    /// or an indication that the generation process has finished.
    fn next(&mut self) -> Result<TextGeneratorResult>;
}

/// Handles the text generation process.
///
/// This struct is responsible for managing the token generation and converting tokens into text.
/// It pairs a [`TokenOutputStream`] (prompt encoding / token decoding) with a
/// boxed [`TokenGeneratorTrait`] implementation that produces raw token ids.
pub struct TextGenerator {
    /// The tokenizer used to encode the prompt and decode the generated tokens.
    tokenizer: TokenOutputStream,

    /// The token generator that produces tokens based on the model's output.
    /// Boxed so any `TokenGeneratorTrait` implementation can be plugged in.
    token_generator: Box<dyn TokenGeneratorTrait>,
}

impl TextGenerator {
    /// Constructs a new `TextGenerator`.
    ///
    /// # Arguments
    ///
    /// * `tokenizer` - Tokenizer for encoding prompts and decoding generated tokens.
    /// * `token_generator` - Token generator that provides the logic for generating tokens.
    pub fn new(
        tokenizer: TokenOutputStream,
        token_generator: Box<dyn TokenGeneratorTrait>,
    ) -> Self {
        Self {
            tokenizer,
            token_generator,
        }
    }
}

impl TextGeneratorTrait for TextGenerator {
    /// Encodes `prompt` into token ids and primes the underlying token generator.
    ///
    /// # Errors
    ///
    /// Returns an error if tokenization fails or if the token generator rejects
    /// the prompt tokens.
    fn init(&mut self, prompt: String) -> Result<()> {
        let prompt_tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(anyhow::Error::msg)?;
        self.token_generator
            .init(prompt_tokens.get_ids().to_vec())?;
        Ok(())
    }

    /// Produces the next piece of generated text, or a finish signal.
    ///
    /// # Errors
    ///
    /// Propagates errors from the token generator or from token decoding.
    fn next(&mut self) -> Result<TextGeneratorResult> {
        match self.token_generator.next()? {
            TokenGeneratorResult::Token((token, probability)) => {
                // `next_token` may return `None` when no decodable text is
                // available yet (presumably a partial multi-byte sequence —
                // TODO confirm against TokenOutputStream). Emit an empty
                // string in that case, but keep the token's real probability
                // instead of fabricating a hard-coded 1.0 as before.
                let text = self.tokenizer.next_token(token)?.unwrap_or_default();
                Ok(TextGeneratorResult::Token((text, probability)))
            }
            TokenGeneratorResult::Finish(reason) => Ok(TextGeneratorResult::Finish(reason)),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::llm::{
        generate_parameter::GenerateParameter, token_generator::dummy::DummyTokenGenerator,
    };

    use super::*;

    /// Drives a `TextGenerator` backed by a dummy token source: expects
    /// exactly `max_new_tokens` `Token` results followed by a `Length` finish.
    #[test]
    fn test_text_generator() {
        let mut text_generator = TextGenerator::new(
            TokenOutputStream::new(tokenizers::tokenizer::Tokenizer::new(
                tokenizers::models::bpe::BPE::default(),
            )),
            Box::new(DummyTokenGenerator::new(GenerateParameter {
                max_new_tokens: 10,
                ..Default::default()
            })),
        );
        text_generator.init("Hello World".to_string()).unwrap();
        for _ in 0..10 {
            // `matches!` replaces the manual match-to-bool pattern
            // (clippy::match_like_matches_macro).
            assert!(matches!(
                text_generator.next().unwrap(),
                TextGeneratorResult::Token(_)
            ));
        }
        assert_eq!(
            text_generator.next().unwrap(),
            TextGeneratorResult::Finish(FinishReason::Length)
        );
    }
}