aimx/inference/request.rs
//! Inference request management for AI model providers
//!
//! This module handles sending requests to AI inference providers and parsing their responses.
//! It supports the OpenAI and Ollama APIs, providing a unified interface for both.
//!
//! # Overview
//!
//! The module provides:
//! - [`InferenceResponse`] - A unified response structure wrapping provider-specific responses
//! - [`send_request`] - Main function for sending requests to AI providers
//! - Support for both OpenAI and Ollama APIs with automatic response parsing
//! - Token usage tracking and timing information
//! - Configurable timeout and connection settings
//!
//! For prompt construction, see the [`prompt`](crate::inference::prompt) module.
//! For provider configuration, see the [`provider`](crate::inference::provider) module.
//!
//! # Examples
//!
//! ```rust,no_run
//! use aimx::inference::{Provider, Api, Model, Capability, send_request};
//!
//! // Create a provider configuration
//! let provider = Provider {
//!     api: Api::Openai,
//!     url: "https://api.openai.com/v1".to_string(),
//!     key: "your-api-key".to_string(),
//!     model: Model::Standard,
//!     capability: Capability::Standard,
//!     fast: "gpt-3.5-turbo".to_string(),
//!     standard: "gpt-4".to_string(),
//!     planning: "gpt-4".to_string(),
//!     temperature: 0.7,
//!     max_tokens: 1000,
//!     connection_timeout_ms: 30000,
//!     request_timeout_ms: 120000,
//! };
//!
//! // Send a request with system and user prompts
//! let response = send_request(&provider, "You are a helpful assistant", "Hello, how are you?");
//! match response {
//!     Ok(inference_response) => {
//!         println!("Response: {}", inference_response.text());
//!         println!("Tokens used: {}", inference_response.total_tokens());
//!         println!("Response time: {}ms", inference_response.response_time_ms());
//!     }
//!     Err(e) => println!("Error: {}", e),
//! }
//! ```

use serde::{Deserialize, Serialize};
use reqwest::blocking::Client;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use crate::inference::{Provider, Api};

/// Unified response structure for AI inference requests
///
/// This struct provides a consistent interface for responses from different AI providers,
/// normalizing the various response formats into a single, easy-to-use structure.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// The generated text content from the AI model
    text: String,
    /// Number of tokens used in the input prompt
    input_tokens: u32,
    /// Number of tokens generated in the output
    output_tokens: u32,
    /// Total tokens used (input + output)
    total_tokens: u32,
    /// Response time in milliseconds
    response_time_ms: u128,
}

impl InferenceResponse {
    /// Get the generated text content
    pub fn text(&self) -> &str {
        &self.text
    }

    /// Get the number of input tokens used
    pub fn input_tokens(&self) -> u32 {
        self.input_tokens
    }

    /// Get the number of output tokens generated
    pub fn output_tokens(&self) -> u32 {
        self.output_tokens
    }

    /// Get the total tokens used (input + output)
    pub fn total_tokens(&self) -> u32 {
        self.total_tokens
    }

    /// Get the response time in milliseconds
    pub fn response_time_ms(&self) -> u128 {
        self.response_time_ms
    }

    /// Consume the response and return the text content
    ///
    /// This is useful when you want to take ownership of the text content
    /// without cloning it.
    pub fn finish(self) -> String {
        self.text
    }
}

// OpenAI API structures
/// Message structure for OpenAI API requests and responses
#[derive(Serialize, Deserialize)]
struct OpenAiMessage {
    /// The role of the message sender (system, user, or assistant)
    role: String,
    /// The content of the message
    content: String,
}

/// Request structure for OpenAI chat completions API
#[derive(Serialize)]
struct OpenAiRequest {
    /// The model to use for the completion
    model: String,
    /// The list of messages to send to the model
    messages: Vec<OpenAiMessage>,
    /// Sampling temperature (0.0 to 1.0)
    temperature: f64,
    /// Maximum number of tokens to generate
    max_tokens: u32,
}

/// Token usage information from OpenAI API responses
#[derive(Deserialize)]
struct OpenAiUsage {
    /// Number of tokens in the prompt
    prompt_tokens: u32,
    /// Number of tokens in the completion
    completion_tokens: u32,
    /// Total number of tokens used
    total_tokens: u32,
}

/// Choice structure from OpenAI API responses
#[derive(Deserialize)]
struct OpenAiChoice {
    /// The message content from the model
    message: OpenAiMessage,
}

/// Response structure from OpenAI chat completions API
#[derive(Deserialize)]
struct OpenAiResponse {
    /// List of choices from the model
    choices: Vec<OpenAiChoice>,
    /// Token usage information
    usage: OpenAiUsage,
}

// Ollama API structures - Chat endpoint
/// Request structure for Ollama chat API
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    /// The model to use for the completion
    model: String,
    /// The list of messages to send to the model
    messages: Vec<OllamaMessage>,
    /// Whether to stream the response (always false for this implementation)
    stream: bool,
    /// Model options
    options: OllamaOptions,
}

/// Message structure for Ollama API requests
#[derive(Debug, Serialize)]
struct OllamaMessage {
    /// The role of the message sender (system, user, or assistant)
    role: String,
    /// The content of the message
    content: String,
}

/// Model options for Ollama API requests
#[derive(Debug, Serialize)]
struct OllamaOptions {
    /// Sampling temperature (0.0 to 1.0)
    temperature: f64,
    /// Maximum number of tokens to generate
    num_predict: u32,
}

/// Response structure from Ollama chat API
///
/// This struct represents the response format from Ollama's chat API endpoint.
/// It's made public to allow testing and direct parsing of Ollama responses.
#[derive(Debug, Deserialize)]
pub struct OllamaChatResponse {
    /// The message content from the model
    pub message: OllamaMessageResponse,
    /// Number of tokens evaluated in the prompt (defaults to 0 if not present)
    #[serde(default)]
    pub prompt_eval_count: u32,
    /// Number of tokens generated in the response (defaults to 0 if not present)
    #[serde(default)]
    pub eval_count: u32,
    /// Total duration of the request in nanoseconds (defaults to 0 if not present)
    #[serde(default)]
    pub total_duration: u64,
}

/// Message response structure from Ollama API
#[derive(Debug, Deserialize)]
pub struct OllamaMessageResponse {
    /// The content of the message
    pub content: String,
}

/// Send a request to an AI inference provider
///
/// This function handles the complete process of sending a request to an AI provider,
/// including:
/// - Building the appropriate HTTP client with timeout configuration
/// - Formatting the request according to the provider's API specification
/// - Sending the request and handling HTTP errors
/// - Parsing the response into a unified [`InferenceResponse`] structure
/// - Tracking response time and token usage
///
/// # Arguments
///
/// * `provider` - The AI provider configuration containing API details, model settings, and timeouts
/// * `system_prompt` - The system prompt that defines the AI's behavior or context
/// * `user_prompt` - The user prompt containing the actual request or question
///
/// # Returns
///
/// Returns a `Result<InferenceResponse>` containing either:
/// - `Ok(InferenceResponse)` with the generated text, token usage, and timing information
/// - `Err(anyhow::Error)` with detailed error context if the request fails
///
/// # Errors
///
/// This function will return an error if:
/// - The HTTP client cannot be created with the specified timeout settings
/// - The HTTP request fails (network errors, invalid URLs, etc.)
/// - The API returns a non-successful HTTP status code
/// - The response cannot be parsed into the expected JSON format
///
/// # Supported Providers
///
/// The function currently supports:
/// - **OpenAI**: Uses the `/chat/completions` endpoint with proper authentication headers
/// - **Ollama**: Uses the `/api/chat` endpoint with local inference support
///
/// # Examples
///
/// ```rust,no_run
/// use aimx::inference::{send_request, Provider, Api, Model, Capability};
///
/// let provider = Provider {
///     api: Api::Ollama,
///     url: "http://localhost:11434".to_string(),
///     key: "".to_string(), // No key needed for local Ollama
///     model: Model::Standard,
///     capability: Capability::Standard,
///     fast: "llama3.2".to_string(),
///     standard: "llama3.2".to_string(),
///     planning: "llama3.2".to_string(),
///     temperature: 0.7,
///     max_tokens: 1000,
///     connection_timeout_ms: 30000,
///     request_timeout_ms: 120000,
/// };
///
/// let response = send_request(&provider, "You are a helpful assistant", "Tell me a joke");
/// ```
pub fn send_request(provider: &Provider, system_prompt: &str, user_prompt: &str) -> Result<InferenceResponse> {
    // Create a client with timeout configuration
    let client = Client::builder()
        .connect_timeout(Duration::from_millis(provider.connection_timeout_ms))
        .timeout(Duration::from_millis(provider.request_timeout_ms))
        .build()
        .with_context(|| format!("Failed to create HTTP client with timeout configuration for {} provider", provider.api))?;

    let start_time = Instant::now();

    match provider.api {
        Api::Openai => {
            let request_body = OpenAiRequest {
                model: provider.model().to_string(),
                messages: vec![
                    OpenAiMessage {
                        role: "system".to_owned(),
                        content: system_prompt.to_string(),
                    },
                    OpenAiMessage {
                        role: "user".to_owned(),
                        content: user_prompt.to_string(),
                    },
                ],
                temperature: provider.temperature,
                max_tokens: provider.max_tokens,
            };

            // Send the request
            let http_response = client
                .post(format!("{}/chat/completions", provider.url))
                .header("Authorization", format!("Bearer {}", provider.key))
                .header("HTTP-Referer", "https://imogen.net")
                .header("X-Title", "Imogen")
                .header("Content-Type", "application/json")
                .json(&request_body)
                .send()
                .with_context(|| format!("Failed to send request to {} provider (model: {}, url: {})", provider.api, provider.model(), provider.url))?;

            // Check if the response is successful
            let status = http_response.status();
            if !status.is_success() {
                let error_text = http_response.text().unwrap_or_else(|_| "Unknown error".to_string());
                anyhow::bail!("HTTP error {} from {} provider (model: {}, url: {}): {}",
                    status, provider.api, provider.model(), provider.url, error_text);
            }

            // Parse the response
            let response = http_response
                .json::<OpenAiResponse>()
                .with_context(|| format!("Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.", provider.api, provider.model(), provider.url))?;

            let response_time = start_time.elapsed().as_millis();

            // Guard against an empty choices array instead of panicking on index 0
            let choice = response
                .choices
                .first()
                .with_context(|| format!("Empty choices in response from {} provider (model: {}, url: {})", provider.api, provider.model(), provider.url))?;

            Ok(InferenceResponse {
                text: choice.message.content.clone(),
                input_tokens: response.usage.prompt_tokens,
                output_tokens: response.usage.completion_tokens,
                total_tokens: response.usage.total_tokens,
                response_time_ms: response_time,
            })
        }

        Api::Ollama => {
            let request_body = OllamaChatRequest {
                model: provider.model().to_string(),
                messages: vec![
                    OllamaMessage {
                        role: "system".to_owned(),
                        content: system_prompt.to_string(),
                    },
                    OllamaMessage {
                        role: "user".to_owned(),
                        content: user_prompt.to_string(),
                    },
                ],
                stream: false,
                options: OllamaOptions {
                    temperature: provider.temperature,
                    num_predict: provider.max_tokens,
                },
            };

            // Send the request
            let http_response = client
                .post(format!("{}/api/chat", provider.url))
                .header("Content-Type", "application/json")
                .json(&request_body)
                .send()
                .with_context(|| format!("Failed to send request to {} provider (model: {}, url: {})", provider.api, provider.model(), provider.url))?;

            // Check if the response is successful
            let status = http_response.status();
            if !status.is_success() {
                let error_text = http_response.text().unwrap_or_else(|_| "Unknown error".to_string());
                anyhow::bail!("HTTP error {} from {} provider (model: {}, url: {}): {}",
                    status, provider.api, provider.model(), provider.url, error_text);
            }

            // Parse the response
            let response = http_response
                .json::<OllamaChatResponse>()
                .with_context(|| format!("Failed to parse JSON response from {} provider (model: {}, url: {}). This may indicate the model doesn't support chat completions or returned an unexpected response format.", provider.api, provider.model(), provider.url))?;

            let response_time = start_time.elapsed().as_millis();

            Ok(InferenceResponse {
                text: response.message.content,
                input_tokens: response.prompt_eval_count,
                output_tokens: response.eval_count,
                total_tokens: response.prompt_eval_count + response.eval_count,
                response_time_ms: response_time,
            })
        }
    }
}
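
// A minimal test sketch (not part of the original module): the provider response
// types above are plain serde structures, so their parsing can be exercised without
// a live endpoint. This assumes `serde_json` is available to the crate as a
// (dev-)dependency; the JSON payloads below are hand-written illustrations of the
// documented response shapes, not captured provider responses.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn accessors_return_stored_values() {
        // Fields are private to this module, but a child test module can construct
        // the struct directly.
        let response = InferenceResponse {
            text: "hello".to_string(),
            input_tokens: 3,
            output_tokens: 4,
            total_tokens: 7,
            response_time_ms: 42,
        };
        assert_eq!(response.text(), "hello");
        assert_eq!(response.total_tokens(), 7);
        assert_eq!(response.finish(), "hello");
    }

    #[test]
    fn parses_minimal_ollama_chat_response() {
        // Counters omitted from the payload fall back to 0 via #[serde(default)].
        let json = r#"{
            "message": { "content": "Hello there!" },
            "prompt_eval_count": 12,
            "eval_count": 34
        }"#;
        let response: OllamaChatResponse =
            serde_json::from_str(json).expect("valid Ollama chat JSON");
        assert_eq!(response.message.content, "Hello there!");
        assert_eq!(response.prompt_eval_count, 12);
        assert_eq!(response.eval_count, 34);
        assert_eq!(response.total_duration, 0);
    }

    #[test]
    fn parses_minimal_openai_chat_response() {
        // Minimal /chat/completions payload with a single choice and usage block.
        let json = r#"{
            "choices": [
                { "message": { "role": "assistant", "content": "Hi!" } }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }"#;
        let response: OpenAiResponse =
            serde_json::from_str(json).expect("valid OpenAI chat completion JSON");
        assert_eq!(response.choices[0].message.content, "Hi!");
        assert_eq!(response.usage.total_tokens, 15);
    }
}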