use axum::{
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};

use crate::{
    api::model::{CompatGenerateRequest, ErrorResponse, GenerateRequest},
    server::AppState,
};

use super::{generate_stream::generate_stream_handler, generate_text_handler};

/// Handler for generating text tokens.
///
/// This endpoint accepts a `CompatGenerateRequest` and dispatches on its `stream` field:
/// if `stream` is `true`, the request is forwarded to the streaming handler and a stream of
/// `StreamResponse` events is returned; if `stream` is `false`, it is forwarded to the text
/// handler and a single `GenerateResponse` is returned.
///
/// # Arguments
/// * `app_state` - State containing the application configuration and text-generation backend.
/// * `payload` - JSON payload containing the input text and optional parameters.
///
/// # Responses
/// * `200 OK` - Text was generated successfully.
/// * `422 Unprocessable Entity` - Input validation error.
/// * `424 Failed Dependency` - Request failed during generation.
/// * `429 Too Many Requests` - Model is overloaded.
/// * `500 Internal Server Error` - Incomplete generation.
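///
/// # Example
///
/// A minimal request-body sketch, mirroring the payloads used in the tests below; the
/// `parameters` field of `CompatGenerateRequest` is omitted here, as it is in those test
/// payloads:
///
/// ```json
/// {
///     "inputs": "Hello, world!",
///     "stream": false
/// }
/// ```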
#[utoipa::path(
    post,
    tag = "Text Generation Inference",
    path = "/",
    request_body = CompatGenerateRequest,
    responses(
        (status = 200, description = "Generated Text",
         content(
             ("application/json" = GenerateResponse),
             ("text/event-stream" = StreamResponse),
         )
        ),
        (status = 424, description = "Generation Error", body = ErrorResponse,
         example = json!({"error": "Request failed during generation"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
         example = json!({"error": "Model is overloaded"})),
        (status = 422, description = "Input validation error", body = ErrorResponse,
         example = json!({"error": "Input validation error"})),
        (status = 500, description = "Incomplete generation", body = ErrorResponse,
         example = json!({"error": "Incomplete generation"})),
    )
)]
pub async fn generate_handler(
    app_state: State<AppState>,
    Json(payload): Json<CompatGenerateRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    let stream = payload.stream;
    // Both downstream handlers take the same inner `GenerateRequest`, so build it once.
    let request = Json(GenerateRequest {
        inputs: payload.inputs,
        parameters: payload.parameters,
    });

    if stream {
        // Streaming requested: delegate to the server-sent-events handler.
        Ok(generate_stream_handler(app_state, request).await.into_response())
    } else {
        // Single response requested: delegate to the plain text handler.
        Ok(generate_text_handler(app_state, request).await.into_response())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::post,
        Router,
    };
    use serde_json::json;
    use tower::ServiceExt; // for `oneshot` method

    /// Tests `generate_handler` with streaming enabled.
    #[ignore = "Will download model from HuggingFace"]
    #[tokio::test]
    async fn test_generate_handler_stream_enabled() {
        let state = AppState {
            config: Config::default(),
            text_generation: None,
        };
        let app = Router::new()
            .route("/", post(generate_handler))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "inputs": "Hello, world!",
                            "stream": true
                        })
                        .to_string(),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Tests `generate_handler` with streaming disabled.
    #[tokio::test]
    #[ignore = "Will download model from HuggingFace"]
    async fn test_generate_handler_stream_disabled() {
        let state = AppState {
            config: Config::default(),
            text_generation: None,
        };
        let app = Router::new()
            .route("/", post(generate_handler))
            .with_state(state);

        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        json!({
                            "inputs": "Hello, world!",
                            "stream": false
                        })
                        .to_string(),
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}