use std::fmt;
use anchor_chain_macros::Stateless;
use async_openai::types::{
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
CreateCompletionRequestArgs, CreateEmbeddingRequestArgs, Prompt,
};
use async_trait::async_trait;
#[cfg(feature = "tracing")]
use tracing::instrument;
use crate::error::AnchorChainError;
use crate::models::embedding_model::EmbeddingModel;
use crate::node::Node;
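/// Wrapper enum over the supported OpenAI text-generation models, dispatching
/// [`Node::process`] calls to the wrapped chat or instruct client.
///
/// A minimal usage sketch; the `anchor_chain` import paths are assumptions and
/// may need adjusting to this crate's actual module layout. Requires the
/// `OPENAI_API_KEY` environment variable and a Tokio runtime:
///
/// ```no_run
/// # use anchor_chain::{models::openai::OpenAIModel, node::Node};
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let model = OpenAIModel::<String>::new_gpt3_5_turbo("You are a helpful assistant").await;
/// let answer = model.process("Why is the sky blue?".to_string()).await?;
/// println!("{answer}");
/// # Ok(())
/// # }
/// ```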
#[derive(Debug, Stateless, Clone)]
pub enum OpenAIModel<T>
where
T: Send + Sync + fmt::Debug,
T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
GPT3_5Turbo(OpenAIChatModel<T>),
GPT3_5TurboInstruct(OpenAIInstructModel<T>),
GPT4Turbo(OpenAIChatModel<T>),
}
impl<T> OpenAIModel<T>
where
T: Send + Sync + fmt::Debug,
T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
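    /// Creates a GPT-4 Turbo chat model (`gpt-4-turbo-preview`) with the given
    /// system prompt. The API key is read from the `OPENAI_API_KEY`
    /// environment variable.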
pub async fn new_gpt4_turbo(system_prompt: &str) -> Self {
        OpenAIModel::GPT4Turbo(
OpenAIChatModel::new(system_prompt.to_string(), "gpt-4-turbo-preview".to_string())
.await,
)
}
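    /// Creates a GPT-3.5 Turbo chat model (`gpt-3.5-turbo`) with the given
    /// system prompt. The API key is read from the `OPENAI_API_KEY`
    /// environment variable.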
pub async fn new_gpt3_5_turbo(system_prompt: &str) -> Self {
        OpenAIModel::GPT3_5Turbo(
OpenAIChatModel::new(system_prompt.to_string(), "gpt-3.5-turbo".to_string()).await,
)
}
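    /// Creates a GPT-3.5 Turbo Instruct completion model (pinned to the
    /// `gpt-3.5-turbo-instruct-0914` snapshot). Instruct models take a bare
    /// prompt rather than a chat history, so no system prompt is used.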
pub async fn new_gpt3_5_turbo_instruct() -> Self {
OpenAIModel::GPT3_5TurboInstruct(
OpenAIInstructModel::new("gpt-3.5-turbo-instruct-0914".to_string()).await,
)
}
}
#[async_trait]
impl<T> Node for OpenAIModel<T>
where
T: Send + Sync + fmt::Debug,
T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
type Input = T;
type Output = String;
#[cfg_attr(feature = "tracing", instrument(skip(self)))]
async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
match self {
OpenAIModel::GPT3_5Turbo(model) => model.process(input).await,
OpenAIModel::GPT4Turbo(model) => model.process(input).await,
OpenAIModel::GPT3_5TurboInstruct(model) => model.process(input).await,
}
}
}
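/// A chat-completion client for OpenAI chat models. Each [`Node::process`]
/// call sends the configured system prompt plus the input as a user message
/// and returns the content of the first choice.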
#[derive(Clone)]
pub struct OpenAIChatModel<T> {
system_prompt: String,
model: String,
client: async_openai::Client<async_openai::config::OpenAIConfig>,
_phantom: std::marker::PhantomData<T>,
}
impl<T> OpenAIChatModel<T> {
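    /// Builds a client with the default configuration, which reads the API key
    /// from the `OPENAI_API_KEY` environment variable.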
async fn new(system_prompt: String, model: String) -> Self {
let config = async_openai::config::OpenAIConfig::new();
let client = async_openai::Client::with_config(config);
OpenAIChatModel {
system_prompt,
client,
model,
_phantom: std::marker::PhantomData,
}
}
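    /// Builds a client with an explicitly provided API key rather than the
    /// `OPENAI_API_KEY` environment variable.
    ///
    /// A minimal sketch; the import path is an assumption:
    ///
    /// ```no_run
    /// # use anchor_chain::models::openai::OpenAIChatModel;
    /// # #[tokio::main]
    /// # async fn main() {
    /// let model = OpenAIChatModel::<String>::new_with_key(
    ///     "You are a helpful assistant".to_string(),
    ///     "gpt-4-turbo-preview".to_string(),
    ///     "sk-your-key".to_string(), // placeholder, not a real key
    /// )
    /// .await;
    /// # }
    /// ```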
pub async fn new_with_key(system_prompt: String, model: String, api_key: String) -> Self {
let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
let client = async_openai::Client::with_config(config);
OpenAIChatModel {
system_prompt,
client,
model,
_phantom: std::marker::PhantomData,
}
}
}
#[async_trait]
impl<T> Node for OpenAIChatModel<T>
where
T: Into<ChatCompletionRequestUserMessageContent> + fmt::Debug + Send + Sync,
{
type Input = T;
type Output = String;
#[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str(), system_prompt = self.system_prompt.as_str())))]
async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
let system_prompt = ChatCompletionRequestSystemMessageArgs::default()
.content(self.system_prompt.clone())
.build()?
.into();
let input = ChatCompletionRequestUserMessageArgs::default()
.content(input)
.build()?
.into();
let request = CreateChatCompletionRequestArgs::default()
.max_tokens(512u16)
.model(&self.model)
.messages([system_prompt, input])
.build()?;
let response = self.client.chat().create(request).await?;
        let content = response
            .choices
            .first()
            .ok_or(AnchorChainError::EmptyResponseError)?
            .message
            .content
            .clone()
            .ok_or(AnchorChainError::EmptyResponseError)?;
        Ok(content)
}
}
impl<T> fmt::Debug for OpenAIChatModel<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OpenAI")
.field("system_prompt", &self.system_prompt)
.finish()
}
}
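/// A text-completion client for OpenAI instruct models. Unlike
/// [`OpenAIChatModel`], it sends the input as a bare prompt with no system
/// message, sampling with a fixed temperature of 0.8.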
#[derive(Clone)]
pub struct OpenAIInstructModel<T>
where
T: Into<Prompt>,
{
model: String,
client: async_openai::Client<async_openai::config::OpenAIConfig>,
_phantom: std::marker::PhantomData<T>,
}
impl<T> OpenAIInstructModel<T>
where
T: Into<Prompt>,
{
#[allow(dead_code)]
async fn new(model: String) -> Self {
let config = async_openai::config::OpenAIConfig::new();
let client = async_openai::Client::with_config(config);
OpenAIInstructModel {
client,
model,
_phantom: std::marker::PhantomData,
}
}
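    /// Builds a client with an explicitly provided API key rather than the
    /// `OPENAI_API_KEY` environment variable.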
#[allow(dead_code)]
pub async fn new_with_key(model: String, api_key: String) -> Self {
let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
let client = async_openai::Client::with_config(config);
OpenAIInstructModel {
client,
model,
_phantom: std::marker::PhantomData,
}
}
}
#[async_trait]
impl<T> Node for OpenAIInstructModel<T>
where
T: Into<Prompt> + fmt::Debug + Send + Sync,
{
type Input = T;
type Output = String;
#[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
let request = CreateCompletionRequestArgs::default()
.model(&self.model)
.prompt(input)
.temperature(0.8)
.max_tokens(512u16)
.build()?;
let response = self.client.completions().create(request).await?;
let content = response
.choices
.first()
.ok_or(AnchorChainError::EmptyResponseError)?
.text
.clone();
Ok(content)
}
}
impl<T> fmt::Debug for OpenAIInstructModel<T>
where
T: Into<Prompt>,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OpenAI").finish()
}
}
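/// An embedding client for OpenAI embedding models, defaulting to
/// `text-embedding-3-large`.
///
/// A minimal sketch of embedding a batch of texts in one request; the import
/// paths are assumptions:
///
/// ```no_run
/// # use anchor_chain::{models::openai::OpenAIEmbeddingModel, node::Node};
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let embedder = OpenAIEmbeddingModel::default();
/// let vectors = embedder
///     .process(vec!["anchor".to_string(), "chain".to_string()])
///     .await?;
/// assert_eq!(vectors.len(), 2);
/// # Ok(())
/// # }
/// ```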
#[derive(Clone)]
pub struct OpenAIEmbeddingModel {
model: String,
client: async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl Default for OpenAIEmbeddingModel {
fn default() -> Self {
OpenAIEmbeddingModel {
model: "text-embedding-3-large".to_string(),
client: async_openai::Client::with_config(async_openai::config::OpenAIConfig::new()),
}
}
}
impl OpenAIEmbeddingModel {
#[allow(dead_code)]
async fn new(model: String) -> Self {
let config = async_openai::config::OpenAIConfig::new();
let client = async_openai::Client::with_config(config);
OpenAIEmbeddingModel { client, model }
}
#[allow(dead_code)]
    pub async fn new_with_key(model: String, api_key: String) -> Self {
let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
let client = async_openai::Client::with_config(config);
OpenAIEmbeddingModel { client, model }
}
}
#[async_trait]
impl Node for OpenAIEmbeddingModel {
type Input = Vec<String>;
type Output = Vec<Vec<f32>>;
#[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
let request = CreateEmbeddingRequestArgs::default()
.model(&self.model)
.input(input)
.build()?;
let response = self.client.embeddings().create(request).await?;
        Ok(response
            .data
            .into_iter()
            .map(|data| data.embedding)
            .collect())
}
}
#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
#[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
async fn embed(&self, input: String) -> Result<Vec<f32>, AnchorChainError> {
        self.process(vec![input])
            .await?
            .into_iter()
            .next()
            .ok_or(AnchorChainError::EmptyResponseError)
}
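    /// Returns the embedding dimensionality. The hardcoded 3072 matches the
    /// default `text-embedding-3-large` model; other models differ (e.g.
    /// `text-embedding-3-small` produces 1536-dimensional vectors).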
fn dimensions(&self) -> usize {
3072
}
}
impl fmt::Debug for OpenAIEmbeddingModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("OpenAI").finish()
}
}