aimx/inference/request.rs

//! Inference request management for AI model providers
//!
//! This module handles sending requests to various AI inference providers and parsing their responses.
//! It supports multiple AI APIs, including OpenAI and Ollama, behind a unified interface.
//!
//! # Overview
//!
//! The module provides:
//! - [`InferenceResponse`] - A unified response structure wrapping provider-specific responses
//! - [`send_request`] - Main function for sending requests to AI providers
//! - Support for both OpenAI and Ollama APIs with automatic response parsing
//! - Token usage tracking and timing information
//! - Configurable timeout and connection settings
//!
//! For prompt construction, see the [`prompt`](crate::inference::prompt) module.
//! For provider configuration, see the [`provider`](crate::inference::provider) module.
//!
//! # Examples
//!
//! ```no_run
//! use aimx::inference::{Provider, Api, Model, Capability, send_request};
//!
//! // Create a provider configuration
//! let provider = Provider {
//!     api: Api::Openai,
//!     url: "https://api.openai.com/v1".to_string(),
//!     key: "your-api-key".to_string(),
//!     model: Model::Standard,
//!     capability: Capability::Standard,
//!     fast: "gpt-3.5-turbo".to_string(),
//!     standard: "gpt-4".to_string(),
//!     planning: "gpt-4".to_string(),
//!     temperature: 0.7,
//!     max_tokens: 1000,
//!     connection_timeout_ms: 30000,
//!     request_timeout_ms: 120000,
//! };
//!
//! // Send a request with system and user prompts
//! let response = send_request(&provider, "You are a helpful assistant", "Hello, how are you?");
//! match response {
//!     Ok(inference_response) => {
//!         println!("Response: {}", inference_response.text());
//!         println!("Tokens used: {}", inference_response.total_tokens());
//!         println!("Response time: {}ms", inference_response.response_time_ms());
//!     }
//!     Err(e) => println!("Error: {}", e),
//! }
//! ```

use serde::{Deserialize, Serialize};
use reqwest::blocking::Client;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use crate::inference::{Provider, Api};

/// Unified response structure for AI inference requests
///
/// This struct provides a consistent interface for responses from different AI providers,
/// normalizing the various response formats into a single, easy-to-use structure.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// The generated text content from the AI model
    text: String,
    /// Number of tokens used in the input prompt
    input_tokens: u32,
    /// Number of tokens generated in the output
    output_tokens: u32,
    /// Total tokens used (input + output)
    total_tokens: u32,
    /// Response time in milliseconds
    response_time_ms: u128,
}

impl InferenceResponse {
    /// Get the generated text content
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Get the number of input tokens used
    pub fn input_tokens(&self) -> u32 {
        self.input_tokens
    }

    /// Get the number of output tokens generated
    pub fn output_tokens(&self) -> u32 {
        self.output_tokens
    }

    /// Get the total tokens used (input + output)
    pub fn total_tokens(&self) -> u32 {
        self.total_tokens
    }

    /// Get the response time in milliseconds
    pub fn response_time_ms(&self) -> u128 {
        self.response_time_ms
    }

    /// Consume the response and return the text content
    ///
    /// This is useful when you want to take ownership of the text content
    /// without cloning it.
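    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `send_request` call succeeds:
    ///
    /// ```no_run
    /// # use aimx::inference::{send_request, Provider};
    /// # fn demo(provider: &Provider) -> anyhow::Result<()> {
    /// let text = send_request(provider, "You are a helpful assistant", "Hello")?.finish();
    /// println!("{}", text);
    /// # Ok(())
    /// # }
    /// ```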
    pub fn finish(self) -> String {
        self.text
    }
}

// OpenAI API structures
/// Message structure for OpenAI API requests and responses
#[derive(Serialize, Deserialize)]
struct OpenAiMessage {
    /// The role of the message sender (system, user, or assistant)
    role: String,
    /// The content of the message
    content: String,
}

/// Request structure for OpenAI chat completions API
#[derive(Serialize)]
struct OpenAiRequest {
    /// The model to use for the completion
    model: String,
    /// The list of messages to send to the model
    messages: Vec<OpenAiMessage>,
    /// Sampling temperature (0.0 to 1.0)
    temperature: f64,
    /// Maximum number of tokens to generate
    max_tokens: u32,
}

/// Token usage information from OpenAI API responses
#[derive(Deserialize)]
struct OpenAiUsage {
    /// Number of tokens in the prompt
    prompt_tokens: u32,
    /// Number of tokens in the completion
    completion_tokens: u32,
    /// Total number of tokens used
    total_tokens: u32,
}

/// Choice structure from OpenAI API responses
#[derive(Deserialize)]
struct OpenAiChoice {
    /// The message content from the model
    message: OpenAiMessage,
}

/// Response structure from OpenAI chat completions API
#[derive(Deserialize)]
struct OpenAiResponse {
    /// List of choices from the model
    choices: Vec<OpenAiChoice>,
    /// Token usage information
    usage: OpenAiUsage,
}

// Ollama API structures - Chat endpoint
/// Request structure for Ollama chat API
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    /// The model to use for the completion
    model: String,
    /// The list of messages to send to the model
    messages: Vec<OllamaMessage>,
    /// Whether to stream the response (always false for this implementation)
    stream: bool,
    /// Model options
    options: OllamaOptions,
}

/// Message structure for Ollama API requests and responses
#[derive(Debug, Serialize)]
struct OllamaMessage {
    /// The role of the message sender (system, user, or assistant)
    role: String,
    /// The content of the message
    content: String,
}

/// Model options for Ollama API requests
#[derive(Debug, Serialize)]
struct OllamaOptions {
    /// Sampling temperature (0.0 to 1.0)
    temperature: f64,
    /// Maximum number of tokens to generate
    num_predict: u32,
}

/// Response structure from Ollama chat API
///
/// This struct represents the response format from Ollama's chat API endpoint.
/// It's made public to allow testing and direct parsing of Ollama responses.
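///
/// # Example
///
/// A minimal parsing sketch; it assumes `serde_json` is available, so the
/// block is marked `ignore`:
///
/// ```ignore
/// let json = r#"{
///     "message": { "content": "Hello!" },
///     "prompt_eval_count": 12,
///     "eval_count": 5,
///     "total_duration": 1500000000
/// }"#;
/// let response: OllamaChatResponse = serde_json::from_str(json).unwrap();
/// assert_eq!(response.message.content, "Hello!");
/// assert_eq!(response.prompt_eval_count, 12);
/// ```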
#[derive(Debug, Deserialize)]
pub struct OllamaChatResponse {
    /// The message content from the model
    pub message: OllamaMessageResponse,
    /// Number of tokens evaluated in the prompt (defaults to 0 if not present)
    #[serde(default)]
    pub prompt_eval_count: u32,
    /// Number of tokens generated in the response (defaults to 0 if not present)
    #[serde(default)]
    pub eval_count: u32,
    /// Total duration of the request in nanoseconds (defaults to 0 if not present)
    #[serde(default)]
    pub total_duration: u64,
}

/// Message response structure from Ollama API
#[derive(Debug, Deserialize)]
pub struct OllamaMessageResponse {
    /// The content of the message
    pub content: String,
}

/// Send a request to an AI inference provider
///
/// This function handles the complete process of sending a request to an AI provider,
/// including:
/// - Building the appropriate HTTP client with timeout configuration
/// - Formatting the request according to the provider's API specification
/// - Sending the request and handling HTTP errors
/// - Parsing the response into a unified [`InferenceResponse`] structure
/// - Tracking response time and token usage
///
/// # Arguments
///
/// * `provider` - The AI provider configuration containing API details, model settings, and timeouts
/// * `system_prompt` - The system prompt that defines the AI's behavior or context
/// * `user_prompt` - The user prompt containing the actual request or question
///
/// # Returns
///
/// Returns a `Result<InferenceResponse>` containing either:
/// - `Ok(InferenceResponse)` with the generated text, token usage, and timing information
/// - `Err(anyhow::Error)` with detailed error context if the request fails
///
/// # Errors
///
/// This function will return an error if:
/// - The HTTP client cannot be created with the specified timeout settings
/// - The HTTP request fails (network errors, invalid URLs, etc.)
/// - The API returns a non-successful HTTP status code
/// - The response cannot be parsed into the expected JSON format
///
/// # Supported Providers
///
/// The function currently supports:
/// - **OpenAI**: Uses the `/chat/completions` endpoint with proper authentication headers
/// - **Ollama**: Uses the `/api/chat` endpoint with local inference support
///
/// # Example
///
/// ```no_run
/// use aimx::inference::{send_request, Provider, Api, Model, Capability};
///
/// let provider = Provider {
///     api: Api::Ollama,
///     url: "http://localhost:11434".to_string(),
///     key: "".to_string(),  // No key needed for local Ollama
///     model: Model::Standard,
///     capability: Capability::Standard,
///     fast: "llama3.2".to_string(),
///     standard: "llama3.2".to_string(),
///     planning: "llama3.2".to_string(),
///     temperature: 0.7,
///     max_tokens: 1000,
///     connection_timeout_ms: 30000,
///     request_timeout_ms: 120000,
/// };
///
/// let response = send_request(&provider, "You are a helpful assistant", "Tell me a joke");
/// ```
pub fn send_request(provider: &Provider, system_prompt: &str, user_prompt: &str) -> Result<InferenceResponse> {
    // Create a client with timeout configuration
    let client = Client::builder()
        .connect_timeout(Duration::from_millis(provider.connection_timeout_ms))
        .timeout(Duration::from_millis(provider.request_timeout_ms))
        .build()
        .with_context(|| format!("Failed to create HTTP client with timeout configuration for {} provider", provider.api))?;

    let start_time = Instant::now();

    match provider.api {
        Api::Openai => {
            let request_body = OpenAiRequest {
                model: provider.model().to_string(),
                messages: vec![
                    OpenAiMessage {
                        role: "system".to_owned(),
                        content: system_prompt.to_string(),
                    },
                    OpenAiMessage {
                        role: "user".to_owned(),
                        content: user_prompt.to_string(),
                    },
                ],
                temperature: provider.temperature,
                max_tokens: provider.max_tokens,
            };

            // Send the request
            let http_response = client
                .post(format!("{}/chat/completions", provider.url))
                .header("Authorization", format!("Bearer {}", provider.key))
                .header("HTTP-Referer", "https://imogen.net")
                .header("X-Title", "Imogen")
                .header("Content-Type", "application/json")
                .json(&request_body)
                .send()
                .with_context(|| format!("Failed to send request to {} provider (model: {}, url: {})", provider.api, provider.model(), provider.url))?;

            // Check if the response is successful
            let status = http_response.status();
            if !status.is_success() {
                let error_text = http_response.text().unwrap_or_else(|_| "Unknown error".to_string());
                anyhow::bail!("HTTP error {} from {} provider (model: {}, url: {}): {}",
                    status, provider.api, provider.model(), provider.url, error_text);
            }

            // Parse the response
            let response = http_response
                .json::<OpenAiResponse>()
                .with_context(|| format!("Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.", provider.api, provider.model(), provider.url))?;

            let response_time = start_time.elapsed().as_millis();

            Ok(InferenceResponse {
                // Guard against an empty `choices` array instead of panicking on index 0
                text: response
                    .choices
                    .first()
                    .map(|choice| choice.message.content.clone())
                    .context("OpenAI response contained no choices")?,
                input_tokens: response.usage.prompt_tokens,
                output_tokens: response.usage.completion_tokens,
                total_tokens: response.usage.total_tokens,
                response_time_ms: response_time,
            })
        }

        Api::Ollama => {
            let request_body = OllamaChatRequest {
                model: provider.model().to_string(),
                messages: vec![
                    OllamaMessage {
                        role: "system".to_owned(),
                        content: system_prompt.to_string(),
                    },
                    OllamaMessage {
                        role: "user".to_owned(),
                        content: user_prompt.to_string(),
                    },
                ],
                stream: false,
                options: OllamaOptions {
                    temperature: provider.temperature,
                    num_predict: provider.max_tokens,
                },
            };

            // Send the request
            let http_response = client
                .post(format!("{}/api/chat", provider.url))
                .header("Content-Type", "application/json")
                .json(&request_body)
                .send()
                .with_context(|| format!("Failed to send request to {} provider (model: {}, url: {})", provider.api, provider.model(), provider.url))?;

            // Check if the response is successful
            let status = http_response.status();
            if !status.is_success() {
                let error_text = http_response.text().unwrap_or_else(|_| "Unknown error".to_string());
                anyhow::bail!("HTTP error {} from {} provider (model: {}, url: {}): {}",
                    status, provider.api, provider.model(), provider.url, error_text);
            }

            // Parse the response
            let response = http_response
                .json::<OllamaChatResponse>()
                .with_context(|| format!("Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.", provider.api, provider.model(), provider.url))?;

            let response_time = start_time.elapsed().as_millis();

            Ok(InferenceResponse {
                text: response.message.content,
                input_tokens: response.prompt_eval_count,
                output_tokens: response.eval_count,
                total_tokens: response.prompt_eval_count + response.eval_count,
                response_time_ms: response_time,
            })
        }
    }
}
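
// A minimal sketch of a deserialization test for the Ollama response structures.
// It assumes `serde_json` is available as a dev-dependency of this crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ollama_chat_response_defaults_missing_counters_to_zero() {
        // `prompt_eval_count`, `eval_count`, and `total_duration` are marked
        // #[serde(default)], so a response without them should still parse.
        let json = r#"{ "message": { "content": "Hi there" } }"#;
        let response: OllamaChatResponse =
            serde_json::from_str(json).expect("valid Ollama chat response JSON");
        assert_eq!(response.message.content, "Hi there");
        assert_eq!(response.prompt_eval_count, 0);
        assert_eq!(response.eval_count, 0);
        assert_eq!(response.total_duration, 0);
    }
}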