//! Text generation: turns a stream of generated token ids (with their
//! probabilities) into decoded text pieces, via a `TokenOutputStream`.
use super::{
token_generator::{TokenGeneratorResult, TokenGeneratorTrait},
FinishReason,
};
use anyhow::Result;
use candle_examples::token_output_stream::TokenOutputStream;
mod dummy_text_generator;
/// A generated piece of text paired with the probability of the token it was decoded from.
pub type TextProbability = (String, f32);
/// Enumerates possible results from a text generation process.
///
/// This enum encapsulates the outcomes of one generation step: either a newly
/// decoded piece of text or the conclusion of the generation process.
#[derive(Debug, PartialEq)]
pub enum TextGeneratorResult {
    /// A generated piece of text along with its probability.
    ///
    /// The `String` is the decoded text (possibly empty while a multi-byte
    /// character is still being assembled), and the `f32` is the probability
    /// associated with it.
    Token(TextProbability),
    /// Indicates the completion of the text generation process.
    ///
    /// Emitted when generation reaches an end, either due to hitting a
    /// configured limit or encountering a stopping condition (see [`FinishReason`]).
    Finish(FinishReason),
}
/// A trait defining the core functionality for text generation.
///
/// This trait encapsulates the necessary methods for initializing the generation
/// process with a prompt and then producing text iteratively. Call [`Self::init`]
/// once before repeatedly calling [`Self::next`].
pub trait TextGeneratorTrait {
    /// Initializes the text generation process with a given prompt.
    ///
    /// This method sets up the necessary state for text generation based on the
    /// provided prompt.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A `String` that serves as the starting point for text generation.
    ///
    /// # Errors
    ///
    /// Returns an error if the prompt cannot be encoded or the underlying token
    /// generator fails to initialize.
    fn init(&mut self, prompt: String) -> Result<()>;
    /// Generates the next piece of text in the sequence.
    ///
    /// This method should be called iteratively to generate text progressively.
    /// It provides the next piece of text based on the current state of the generator.
    ///
    /// # Returns
    ///
    /// A `Result` wrapping a [`TextGeneratorResult`], which is either a generated
    /// token or an indication that the generation process has finished.
    fn next(&mut self) -> Result<TextGeneratorResult>;
}
/// Handles the text generation process.
///
/// This struct is responsible for managing the token generation and converting
/// the produced token ids back into text.
pub struct TextGenerator {
    /// The tokenizer used to encode the prompt and decode the generated tokens.
    tokenizer: TokenOutputStream,
    /// The token generator that produces tokens based on the model's output.
    token_generator: Box<dyn TokenGeneratorTrait>,
}
impl TextGenerator {
/// Constructs a new `TextGenerator`.
///
/// # Arguments
///
/// * `tokenizer` - Tokenizer for encoding prompts and decoding generated tokens.
/// * `token_generator` - Token generator that provides the logic for generating tokens.
pub fn new(
tokenizer: TokenOutputStream,
token_generator: Box<dyn TokenGeneratorTrait>,
) -> Self {
Self {
tokenizer,
token_generator,
}
}
}
impl TextGeneratorTrait for TextGenerator {
    /// Encodes the prompt into token ids and primes the underlying token generator.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails or the token generator rejects the ids.
    fn init(&mut self, prompt: String) -> Result<()> {
        let prompt_tokens = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(anyhow::Error::msg)?;
        self.token_generator
            .init(prompt_tokens.get_ids().to_vec())?;
        Ok(())
    }

    /// Produces the next decoded text piece, or the reason generation finished.
    ///
    /// `next_token` may return `Ok(None)` while the decoder is buffering an
    /// incomplete multi-byte character; in that case an empty string is
    /// surfaced, but the token's real probability is preserved (previously a
    /// fabricated probability of `1.0` was reported, discarding the actual one).
    fn next(&mut self) -> Result<TextGeneratorResult> {
        match self.token_generator.next()? {
            TokenGeneratorResult::Token((token, probability)) => {
                // Empty text on None, but keep the genuine probability.
                let text = self.tokenizer.next_token(token)?.unwrap_or_default();
                Ok(TextGeneratorResult::Token((text, probability)))
            }
            TokenGeneratorResult::Finish(reason) => Ok(TextGeneratorResult::Finish(reason)),
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::llm::{
        generate_parameter::GenerateParameter, token_generator::dummy::DummyTokenGenerator,
    };

    use super::*;

    /// The dummy generator emits `max_new_tokens` tokens and then finishes with
    /// `FinishReason::Length`; the text generator must mirror that sequence.
    #[test]
    fn test_text_generator() {
        let mut text_generator = TextGenerator::new(
            TokenOutputStream::new(tokenizers::tokenizer::Tokenizer::new(
                tokenizers::models::bpe::BPE::default(),
            )),
            Box::new(DummyTokenGenerator::new(GenerateParameter {
                max_new_tokens: 10,
                ..Default::default()
            })),
        );
        text_generator.init("Hello World".to_string()).unwrap();

        for _ in 0..10 {
            // `matches!` replaces the non-idiomatic `match … => true, _ => false`
            // (clippy: match_like_matches_macro).
            assert!(matches!(
                text_generator.next().unwrap(),
                TextGeneratorResult::Token(_)
            ));
        }
        assert_eq!(
            text_generator.next().unwrap(),
            TextGeneratorResult::Finish(FinishReason::Length)
        );
    }
}