Functions and structs for working with the OpenAI API.

Read input with a prompt message on stdout
Martin Berg Alstad 2024-07-31 19:37:46 +02:00
parent 865cc6ddb9
commit 695605977f
15 changed files with 2784 additions and 11 deletions

.idea/lib.iml (generated, 2 changes)

@@ -9,11 +9,13 @@
       <sourceFolder url="file://$MODULE_DIR$/crates/read_files/src" isTestSource="false" />
       <sourceFolder url="file://$MODULE_DIR$/crates/read_files/tests" isTestSource="true" />
       <sourceFolder url="file://$MODULE_DIR$/crates/read_files/tests" isTestSource="true" />
+      <sourceFolder url="file://$MODULE_DIR$/examples/openai-assistant/src" isTestSource="false" />
       <sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
       <excludeFolder url="file://$MODULE_DIR$/target" />
       <excludeFolder url="file://$MODULE_DIR$/examples/multipart_file/target" />
       <excludeFolder url="file://$MODULE_DIR$/crates/into_response_derive/target" />
       <excludeFolder url="file://$MODULE_DIR$/crates/read_files/target" />
+      <excludeFolder url="file://$MODULE_DIR$/examples/openai-assistant/target" />
     </content>
     <orderEntry type="inheritedJdk" />
     <orderEntry type="sourceFolder" forTests="false" />

Cargo.lock (generated, 838 changes)
File diff suppressed because it is too large.

Cargo.toml

@@ -1,9 +1,10 @@
 [workspace]
 members = ["crates/*"]
+exclude = ["examples"]

 [workspace.package]
 edition = "2021"
-rust-version = "1.79.0"
+rust-version = "1.80.0"
 authors = ["Martin Berg Alstad"]
 homepage = "emberal.github.io"

@@ -26,8 +27,12 @@ tower-http = { version = "0.5", optional = true, features = ["trace", "cors", "n
 # Async
 tokio = { version = "1.38", optional = true, features = ["fs"] }
 tokio-util = { version = "0.7", optional = true, features = ["io"] }
+async-stream = { version = "0.3", optional = true }
+futures = { version = "0.3", optional = true }
 # Error handling
 thiserror = { version = "1.0", optional = true }
+# LLM
+async-openai = { version = "0.23", optional = true }
 # Logging
 tracing = { version = "0.1", optional = true }
 tracing-subscriber = { version = "0.3", optional = true }

@@ -38,6 +43,8 @@ into-response-derive = { path = "crates/into_response_derive", optional = true }
 read-files = { path = "crates/read_files", optional = true }
 # Serialization / Deserialization
 serde = { version = "1.0", optional = true, features = ["derive"] }
+# Utils
+cfg-if = "1.0.0"

 [workspace.dependencies]
 syn = "2.0"

@@ -51,3 +58,4 @@ nom = ["dep:nom"]
 serde = ["dep:serde"]
 derive = ["dep:into-response-derive", "axum", "serde"]
 read-files = ["dep:read-files"]
+openai = ["dep:async-openai", "dep:async-stream", "dep:futures"]

examples/multipart_file/Cargo.lock (generated)

@@ -286,9 +286,10 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"

 [[package]]
 name = "lib"
-version = "1.3.5"
+version = "1.4.1-hotfix-2"
 dependencies = [
  "axum",
+ "cfg-if",
  "thiserror",
  "tokio",
  "tower",

examples/openai-assistant/Cargo.lock (generated, new file, 1554 lines)
File diff suppressed because it is too large.

examples/openai-assistant/Cargo.toml (new file)

@@ -0,0 +1,10 @@
[package]
name = "openai-assistant"
version = "0.1.0"
edition = "2021"

[dependencies]
lib = { path = "../..", features = ["openai", "io"] }
tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } # "macros" is required for #[tokio::main]
futures = "0.3.0"
async-openai = "0.23.0"

examples/openai-assistant/src/main.rs (new file)

@@ -0,0 +1,32 @@
use futures::StreamExt;
use lib::{
    openai::{assistants::Assistant, streams::TokenStream},
    prompt_read_line,
};

/// Expects the OPENAI_API_KEY environment variable to be set
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let assistant = Assistant::new("gpt-4o-mini", "Be a helpful assistant").await?;
    let thread = assistant.create_thread().await?;

    while let Some(input) = get_user_input() {
        let mut stream: TokenStream = thread.run_stream(&input).await?.into();

        while let Some(result) = stream.next().await {
            if let Ok(text) = result {
                print!("{}", text);
            }
        }
        println!();
    }

    assistant.delete().await?;
    Ok(())
}

fn get_user_input() -> Option<String> {
    prompt_read_line!("> ")
        .ok()
        .take_if(|input| !input.is_empty())
}

src/io/console.rs (new file, 31 lines)

@@ -0,0 +1,31 @@
#[macro_export]
macro_rules! _read_line {
    () => {
        match std::io::Write::flush(&mut std::io::stdout()) {
            Ok(_) => {
                let mut input = String::new();
                match std::io::Stdin::read_line(&mut std::io::stdin(), &mut input) {
                    Ok(_) => Ok::<String, std::io::Error>(input),
                    Err(error) => Err(error),
                }
            }
            Err(error) => Err(error),
        }
    };
}

#[macro_export]
macro_rules! prompt_read_line {
    ($($expr:expr),*) => {{
        print!($($expr),*);
        $crate::_read_line!()
    }};
}

#[macro_export]
macro_rules! promptln_read_line {
    ($($expr:expr),*) => {{
        println!($($expr),*);
        $crate::_read_line!()
    }};
}
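Note: a quick sketch of how these macros are meant to be used from a downstream crate; the function name and prompt text here are illustrative, not part of the commit.

// Illustrative usage only. `prompt_read_line!` prints the prompt, flushes
// stdout so the prompt appears before blocking on stdin, and returns the
// line including its trailing newline; an empty String means EOF.
fn echo_until_eof() -> std::io::Result<()> {
    loop {
        let line = lib::prompt_read_line!("echo> ")?;
        if line.is_empty() {
            break; // read_line read 0 bytes: the user sent EOF (Ctrl-D)
        }
        print!("you said: {line}");
    }
    Ok(())
}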

src/io/mod.rs

@@ -1 +1,2 @@
+pub mod console;
 pub mod file;

src/lib.rs

@@ -11,6 +11,8 @@ pub mod axum;
 pub mod io;
 #[cfg(feature = "nom")]
 pub mod nom;
+#[cfg(feature = "openai")]
+pub mod openai;
 #[cfg(feature = "serde")]
 pub mod serde;
 pub mod traits;

src/openai/assistants.rs (new file, 125 lines)

@@ -0,0 +1,125 @@
use async_openai::{
    types::{
        AssistantEventStream, AssistantObject, CreateAssistantRequest, CreateMessageRequest,
        CreateRunRequest, CreateThreadRequest, DeleteAssistantResponse, DeleteThreadResponse,
        MessageObject, MessageRole, ThreadObject,
    },
    Client,
};

use crate::openai::types::{OpenAIClient, OpenAIResult};

#[derive(Clone, Debug)]
pub struct Assistant {
    client: OpenAIClient,
    assistant_object: AssistantObject,
}

#[derive(Clone, Debug)]
pub struct Thread<'client> {
    client: &'client OpenAIClient,
    assistant_id: String,
    thread_object: ThreadObject,
}

impl Assistant {
    /// Creates the assistant using a client configured from the environment,
    /// matching `from_id` below and the two-argument call in the example binary.
    pub async fn new(
        model: impl Into<String>,
        instructions: impl Into<String>,
    ) -> OpenAIResult<Self> {
        let client = Client::new();
        let assistant_object = client
            .assistants()
            .create(CreateAssistantRequest {
                model: model.into(),
                instructions: Some(instructions.into()),
                ..Default::default()
            })
            .await?;

        Ok(Self {
            client,
            assistant_object,
        })
    }

    pub async fn from_id(id: impl AsRef<str>) -> OpenAIResult<Self> {
        let client = Client::new();
        let assistant_object = client.assistants().retrieve(id.as_ref()).await?;

        Ok(Self {
            client,
            assistant_object,
        })
    }

    pub async fn create_thread(&self) -> OpenAIResult<Thread> {
        Thread::new(&self.client, self.id()).await
    }

    pub async fn delete(self) -> OpenAIResult<DeleteAssistantResponse> {
        self.client.assistants().delete(self.id()).await
    }

    pub fn id(&self) -> &str {
        &self.assistant_object.id
    }
}

impl<'client> Thread<'client> {
    pub async fn new(
        client: &'client OpenAIClient,
        assistant_id: impl Into<String>,
    ) -> OpenAIResult<Self> {
        Ok(Self {
            client,
            assistant_id: assistant_id.into(),
            thread_object: client
                .threads()
                .create(CreateThreadRequest::default())
                .await?,
        })
    }

    pub async fn from_id(
        client: &'client OpenAIClient,
        assistant_id: impl Into<String>,
        thread_id: impl AsRef<str>,
    ) -> OpenAIResult<Self> {
        Ok(Self {
            client,
            assistant_id: assistant_id.into(),
            thread_object: client.threads().retrieve(thread_id.as_ref()).await?,
        })
    }

    pub async fn run_stream(&self, prompt: impl AsRef<str>) -> OpenAIResult<AssistantEventStream> {
        self.create_message(prompt.as_ref()).await?;

        self.client
            .threads()
            .runs(self.id())
            .create_stream(CreateRunRequest {
                assistant_id: self.assistant_id.clone(),
                ..Default::default()
            })
            .await
    }

    pub fn id(&self) -> &str {
        &self.thread_object.id
    }

    async fn create_message(&self, prompt: &str) -> OpenAIResult<MessageObject> {
        self.client
            .threads()
            .messages(&self.thread_object.id)
            .create(CreateMessageRequest {
                role: MessageRole::User,
                content: prompt.into(),
                ..Default::default()
            })
            .await
    }

    /// Deletes the thread on the server, consuming the handle like `Assistant::delete`.
    pub async fn delete(self) -> OpenAIResult<DeleteThreadResponse> {
        self.client.threads().delete(self.id()).await
    }
}
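Note: beyond the create-and-run flow in the example binary, the two `from_id` constructors allow reattaching to objects that already exist server-side. A minimal sketch, assuming previously stored ids; the id parameters and function name are placeholders, not part of the commit.

use async_openai::Client;
use futures::StreamExt;
use lib::openai::{
    assistants::{Assistant, Thread},
    streams::TokenStream,
};

// Placeholder ids; real values look like "asst_..." and "thread_...".
async fn resume(assistant_id: &str, thread_id: &str) -> Result<(), Box<dyn std::error::Error>> {
    let assistant = Assistant::from_id(assistant_id).await?;
    // Thread borrows a client rather than owning one, so build one here.
    let client = Client::new();
    let thread = Thread::from_id(&client, assistant.id(), thread_id).await?;

    let mut stream: TokenStream = thread.run_stream("Where were we?").await?.into();
    while let Some(token) = stream.next().await {
        print!("{}", token?);
    }
    Ok(())
}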

src/openai/chat.rs (new file, 50 lines)

@@ -0,0 +1,50 @@
use async_openai::types::{
    ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ChatCompletionRequestUserMessageContent, ChatCompletionResponseStream,
    CreateChatCompletionRequest,
};

use crate::openai::types::{OpenAIClient, OpenAIResult};

pub async fn chat(
    client: &OpenAIClient,
    model: impl Into<String>,
    prompt: impl Into<String>,
) -> OpenAIResult<String> {
    Ok(client
        .chat()
        .create(CreateChatCompletionRequest {
            model: model.into(),
            messages: vec![create_user_message(prompt)],
            ..Default::default()
        })
        .await?
        .choices[0]
        .message
        .content
        .clone()
        .unwrap_or_default())
}

pub async fn chat_stream(
    client: &OpenAIClient,
    model: impl Into<String>,
    prompt: impl Into<String>,
) -> OpenAIResult<ChatCompletionResponseStream> {
    client
        .chat()
        .create_stream(CreateChatCompletionRequest {
            model: model.into(),
            stream: Some(true),
            messages: vec![create_user_message(prompt)],
            ..Default::default()
        })
        .await
}

fn create_user_message(prompt: impl Into<String>) -> ChatCompletionRequestMessage {
    ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
        content: ChatCompletionRequestUserMessageContent::from(prompt.into()),
        name: None,
    })
}
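Note: a sketch of both entry points from a consumer's perspective; the model name, prompts, and function name are arbitrary, and the TokenStream conversion comes from streams.rs below.

use futures::StreamExt;
use lib::openai::{chat, streams::TokenStream, types::OpenAIClient};

// Illustrative only: one-shot completion, then the streaming variant.
async fn demo(client: &OpenAIClient) -> Result<(), Box<dyn std::error::Error>> {
    // `chat` resolves to the first choice's content, or "" when content is None.
    let answer = chat::chat(client, "gpt-4o-mini", "Say hello").await?;
    println!("{answer}");

    // `chat_stream` yields raw chunks; TokenStream reduces them to text deltas.
    let mut tokens: TokenStream = chat::chat_stream(client, "gpt-4o-mini", "Count to five")
        .await?
        .into();
    while let Some(token) = tokens.next().await {
        print!("{}", token?);
    }
    Ok(())
}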

src/openai/mod.rs (new file, 4 lines)

@@ -0,0 +1,4 @@
pub mod assistants;
pub mod chat;
pub mod streams;
pub mod types;

src/openai/streams.rs (new file, 129 lines)

@@ -0,0 +1,129 @@
use std::{
    pin::Pin,
    task::{Context, Poll},
};

use async_openai::types::ChatCompletionResponseStream;
use async_openai::{
    error::OpenAIError,
    types::{AssistantEventStream, AssistantStreamEvent, MessageDeltaContent, MessageDeltaObject},
};
use async_stream::try_stream;
use futures::{Stream, StreamExt};

use crate::openai::types::OpenAIResult;

pub struct TokenStream(Pin<Box<dyn Stream<Item = OpenAIResult<String>> + Send + 'static>>);

impl TokenStream {
    pub fn new(stream: impl Stream<Item = OpenAIResult<String>> + Send + 'static) -> Self {
        Self(Box::pin(stream))
    }
}

impl Stream for TokenStream {
    type Item = OpenAIResult<String>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.0.as_mut().poll_next(cx)
    }
}

impl From<AssistantEventStream> for TokenStream {
    fn from(mut value: AssistantEventStream) -> Self {
        Self::new(try_stream! {
            while let Some(event) = value.next().await {
                if let Ok(AssistantStreamEvent::ThreadMessageDelta(message)) = event {
                    if let Ok(text) = get_message(message) {
                        yield text;
                    }
                }
            }
        })
    }
}

impl From<ChatCompletionResponseStream> for TokenStream {
    fn from(mut value: ChatCompletionResponseStream) -> Self {
        Self::new(try_stream! {
            while let Some(event) = value.next().await {
                if let Ok(event) = event {
                    // Chunks can arrive with no choices, so avoid indexing an empty Vec
                    if let Some(text) = event.choices.first().and_then(|choice| choice.delta.content.clone()) {
                        yield text;
                    }
                }
            }
        })
    }
}

cfg_if::cfg_if! {
    if #[cfg(feature = "axum")] {
        use axum::response::sse::Event;

        pub struct EventStream<E>(Pin<Box<dyn Stream<Item = Result<Event, E>> + Send + 'static>>);

        impl<E> EventStream<E> {
            pub fn new(stream: impl Stream<Item = Result<Event, E>> + Send + 'static) -> Self {
                Self(Box::pin(stream))
            }
        }

        impl<E> Stream for EventStream<E> {
            type Item = Result<Event, E>;

            fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                self.0.as_mut().poll_next(cx)
            }
        }

        impl<E> From<AssistantEventStream> for EventStream<E>
        where
            E: Send + 'static,
        {
            fn from(mut value: AssistantEventStream) -> Self {
                Self::new(try_stream! {
                    while let Some(event) = value.next().await {
                        if let Ok(AssistantStreamEvent::ThreadMessageDelta(message)) = event {
                            if let Ok(text) = get_message(message) {
                                yield Event::default().data(text);
                            }
                        }
                    }
                })
            }
        }

        impl<E> From<ChatCompletionResponseStream> for EventStream<E>
        where
            E: Send + 'static,
        {
            fn from(mut value: ChatCompletionResponseStream) -> Self {
                Self::new(try_stream! {
                    while let Some(event) = value.next().await {
                        if let Ok(event) = event {
                            // Chunks can arrive with no choices, so avoid indexing an empty Vec
                            if let Some(text) = event.choices.first().and_then(|choice| choice.delta.content.clone()) {
                                yield Event::default().data(text);
                            }
                        }
                    }
                })
            }
        }
    }
}

fn get_message(message: MessageDeltaObject) -> OpenAIResult<String> {
    let content = message
        .delta
        .content
        .and_then(|content| content.first().cloned())
        .ok_or(OpenAIError::StreamError("Expected content".into()))?;

    if let MessageDeltaContent::Text(content) = content {
        content
            .text
            .and_then(|text| text.value)
            .ok_or(OpenAIError::StreamError("Expected text message".into()))
    } else {
        Err(OpenAIError::StreamError("Expected text message".into()))
    }
}
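Note: with both the openai and axum features enabled, EventStream is the SSE-facing counterpart of TokenStream. A hedged sketch of an axum handler built on it; the handler name, model, and use of State are assumptions, not part of this commit.

use axum::{extract::State, response::sse::Sse};
use lib::openai::{chat::chat_stream, streams::EventStream, types::OpenAIClient};

// Illustrative only. The chat stream converts into EventStream, whose items
// are Result<Event, E>, which is what axum's Sse response expects.
async fn sse_handler(State(client): State<OpenAIClient>) -> Sse<EventStream<axum::Error>> {
    let stream = chat_stream(&client, "gpt-4o-mini", "Stream a short reply")
        .await
        .expect("chat request failed");
    Sse::new(stream.into())
}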

src/openai/types.rs (new file, 4 lines)

@@ -0,0 +1,4 @@
use async_openai::{config::OpenAIConfig, error::OpenAIError, Client};

pub type OpenAIClient = Client<OpenAIConfig>;
pub type OpenAIResult<T> = Result<T, OpenAIError>;
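Note: these aliases pin async-openai's generic client to its default config. Construction follows async-openai's own API: Client::new() reads OPENAI_API_KEY from the environment, while with_config takes an explicit configuration. The function names below are illustrative.

use async_openai::{config::OpenAIConfig, Client};
use lib::openai::types::OpenAIClient;

fn make_client() -> OpenAIClient {
    // Reads OPENAI_API_KEY (and related settings) from the environment.
    Client::new()
}

fn make_client_with_key(api_key: &str) -> OpenAIClient {
    // Explicit configuration, e.g. for tests or a custom deployment.
    Client::with_config(OpenAIConfig::new().with_api_key(api_key))
}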