use crate::api::model::GenerateRequest;
use crate::llm::generate_parameter::GenerateParameter;
use crate::llm::text_generation::create_text_generation;
use crate::server::AppState;
use axum::{
    extract::State,
    http::StatusCode,
    response::{sse::Event, IntoResponse, Sse},
    Json,
};
use futures::stream::StreamExt;
use log::debug;

/// Asynchronous handler for generating text through a streaming API.
///
/// This function handles POST requests to the `/generate_stream` endpoint. It takes a JSON payload
/// representing a `GenerateRequest` and generates text according to the supplied parameters.
/// The response is a stream of Server-Sent Events (SSE), allowing clients to receive generated
/// text in real time as it is produced.
///
/// # Parameters
/// - `app_state`: Application state holding the global configuration and, optionally, a preloaded text generator.
/// - `Json(payload)`: JSON payload containing the input text and generation parameters.
///
/// # Responses
/// - `200 OK`: Stream of generated text as `StreamResponse` events.
/// - `500 Internal Server Error`: Returned when the text-generation backend cannot be initialized.
///
/// # Usage
/// This endpoint is suitable for scenarios where real-time text generation is required,
/// such as interactive chatbots or live content creation tools.
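///
/// # Example
///
/// A minimal sketch of how this handler might be mounted on an axum `Router`;
/// the real router setup lives elsewhere in the crate, and `app_state` here is
/// an assumed, already-constructed `AppState`:
///
/// ```rust,ignore
/// use axum::{routing::post, Router};
///
/// let app: Router = Router::new()
///     // POST /generate_stream streams generated text back as SSE events.
///     .route("/generate_stream", post(generate_stream_handler))
///     .with_state(app_state);
/// ```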
#[utoipa::path(
    post,
    path = "/generate_stream",
    request_body = GenerateRequest,
    responses(
        (status = 200, description = "Generated Text", body = StreamResponse),
    ),
    tag = "Text Generation Inference"
)]
pub async fn generate_stream_handler(
    app_state: State<AppState>,
    Json(payload): Json<GenerateRequest>,
) -> impl IntoResponse {
    debug!("Received request: {:?}", payload);
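    // Sampling parameters are optional in the request; each one falls back to
    // a default when the `parameters` block (or the individual field) is absent.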
    let params = payload.parameters.as_ref();
    let temperature = params.and_then(|p| p.temperature);
    let top_p: Option<f64> = params.and_then(|p| p.top_p);
    let repeat_penalty: f32 = params.and_then(|p| p.repetition_penalty).unwrap_or(1.1);
    // Note that the repeat window is mapped from the request's `top_n_tokens` field.
    let repeat_last_n = params.and_then(|p| p.top_n_tokens).unwrap_or(64) as usize;
    let sample_len = params.and_then(|p| p.max_new_tokens).unwrap_or(50) as usize;

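    // Stop tokens default to common end-of-sequence markers when the request supplies none.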
    let stop_tokens = match params {
        Some(parameters) => parameters.stop.clone(),
        None => vec!["<|endoftext|>".to_string(), "</s>".to_string()],
    };

    let config = app_state.config.clone();

    let mut generator = match &app_state.text_generation {
        Some(text_generation) => text_generation.clone(),
        // Build a generator on demand; surface failure as a 500 instead of panicking.
        None => match create_text_generation(config.model, &config.cache_dir) {
            Ok(generator) => generator,
            Err(e) => return Err((StatusCode::INTERNAL_SERVER_ERROR, e.to_string())),
        },
    };

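    // The seed is currently fixed, so identical requests produce identical output.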
    let parameter = GenerateParameter {
        temperature: temperature.unwrap_or_default(),
        top_p: top_p.unwrap_or_default(),
        max_new_tokens: sample_len,
        seed: 42,
        repeat_penalty,
        repeat_last_n,
    };

    let stream = generator.run_stream(&payload.inputs, parameter, Some(stop_tokens));

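    // Map each generated chunk to an SSE event. Serialization failures are
    // reported in-band as event data, so the stream itself never errors
    // (hence `Infallible`).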
    let event_stream = stream.map(|response| -> Result<Event, std::convert::Infallible> {
        let data = serde_json::to_string(&response)
            .unwrap_or_else(|_| "Error serializing response".to_string());
        Ok(Event::default().data(data))
    });
    Ok(Sse::new(event_stream))
}