use std::fmt;

use anchor_chain_macros::Stateless;
use async_openai::types::{
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
    ChatCompletionRequestUserMessageContent, CreateChatCompletionRequestArgs,
    CreateCompletionRequestArgs, CreateEmbeddingRequestArgs, Prompt,
};
use async_trait::async_trait;
#[cfg(feature = "tracing")]
use tracing::instrument;

use crate::error::AnchorChainError;
use crate::models::embedding_model::EmbeddingModel;
use crate::node::Node;
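
/// OpenAI language models that can be used as [`Node`]s in a chain.
///
/// The chat variants (`GPT3_5Turbo`, `GPT4Turbo`) call the chat-completions
/// API with a configurable system prompt, while `GPT3_5TurboInstruct` calls
/// the legacy completions API with a bare prompt.
///
/// # Example
///
/// A minimal sketch of driving a model through the [`Node`] trait (imports
/// and error handling omitted; assumes `OPENAI_API_KEY` is set and an async
/// context):
///
/// ```ignore
/// let gpt3 = OpenAIModel::<String>::new_gpt3_5_turbo("You are a helpful assistant").await;
/// let reply = gpt3.process("Why is the sky blue?".to_string()).await?;
/// println!("{reply}");
/// ```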
#[derive(Debug, Stateless, Clone)]
pub enum OpenAIModel<T>
where
    T: Send + Sync + fmt::Debug,
    T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
    GPT3_5Turbo(OpenAIChatModel<T>),
    GPT3_5TurboInstruct(OpenAIInstructModel<T>),
    GPT4Turbo(OpenAIChatModel<T>),
}
impl<T> OpenAIModel<T>
where
    T: Send + Sync + fmt::Debug,
    T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
    /// Creates a GPT-4 Turbo chat model (`gpt-4-turbo-preview`) with the
    /// given system prompt.
    pub async fn new_gpt4_turbo(system_prompt: &str) -> Self {
        OpenAIModel::GPT4Turbo(
            OpenAIChatModel::new(system_prompt.to_string(), "gpt-4-turbo-preview".to_string())
                .await,
        )
    }
    /// Creates a GPT-3.5 Turbo chat model (`gpt-3.5-turbo`) with the given
    /// system prompt.
    pub async fn new_gpt3_5_turbo(system_prompt: &str) -> Self {
        OpenAIModel::GPT3_5Turbo(
            OpenAIChatModel::new(system_prompt.to_string(), "gpt-3.5-turbo".to_string()).await,
        )
    }
    /// Creates a GPT-3.5 Turbo Instruct model (`gpt-3.5-turbo-instruct-0914`)
    /// backed by the legacy completions API.
    pub async fn new_gpt3_5_turbo_instruct() -> Self {
        OpenAIModel::GPT3_5TurboInstruct(
            OpenAIInstructModel::new("gpt-3.5-turbo-instruct-0914".to_string()).await,
        )
    }
}
#[async_trait]
impl<T> Node for OpenAIModel<T>
where
    T: Send + Sync + fmt::Debug,
    T: Into<Prompt> + Into<ChatCompletionRequestUserMessageContent>,
{
    type Input = T;
    type Output = String;
    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        match self {
            OpenAIModel::GPT3_5Turbo(model) => model.process(input).await,
            OpenAIModel::GPT4Turbo(model) => model.process(input).await,
            OpenAIModel::GPT3_5TurboInstruct(model) => model.process(input).await,
        }
    }
}
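
/// Chat-completions model (e.g. `gpt-3.5-turbo`, `gpt-4-turbo-preview`) that
/// prepends a fixed system prompt to every request.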
#[derive(Clone)]
pub struct OpenAIChatModel<T> {
    system_prompt: String,
    model: String,
    client: async_openai::Client<async_openai::config::OpenAIConfig>,
    _phantom: std::marker::PhantomData<T>,
}
impl<T> OpenAIChatModel<T> {
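    /// Creates a chat model using the default `OpenAIConfig`, which reads the
    /// API key from the `OPENAI_API_KEY` environment variable.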
    async fn new(system_prompt: String, model: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new();
        let client = async_openai::Client::with_config(config);
        OpenAIChatModel {
            system_prompt,
            client,
            model,
            _phantom: std::marker::PhantomData,
        }
    }
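    /// Creates a chat model authenticated with the provided API key rather
    /// than the `OPENAI_API_KEY` environment variable.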
    pub async fn new_with_key(system_prompt: String, model: String, api_key: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
        let client = async_openai::Client::with_config(config);
        OpenAIChatModel {
            system_prompt,
            client,
            model,
            _phantom: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<T> Node for OpenAIChatModel<T>
where
    T: Into<ChatCompletionRequestUserMessageContent> + fmt::Debug + Send + Sync,
{
    type Input = T;
    type Output = String;
    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str(), system_prompt = self.system_prompt.as_str())))]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        let system_prompt = ChatCompletionRequestSystemMessageArgs::default()
            .content(self.system_prompt.clone())
            .build()?
            .into();
        let input = ChatCompletionRequestUserMessageArgs::default()
            .content(input)
            .build()?
            .into();
        let request = CreateChatCompletionRequestArgs::default()
            .max_tokens(512u16)
            .model(&self.model)
            .messages([system_prompt, input])
            .build()?;
        let response = self.client.chat().create(request).await?;
        // Treat a missing choice or a choice without message content as an
        // empty response.
        let content = response
            .choices
            .first()
            .ok_or(AnchorChainError::EmptyResponseError)?
            .message
            .content
            .clone()
            .ok_or(AnchorChainError::EmptyResponseError)?;
        Ok(content)
    }
}
impl<T> fmt::Debug for OpenAIChatModel<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAIChatModel")
            .field("model", &self.model)
            .field("system_prompt", &self.system_prompt)
            .finish()
    }
}
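
/// Legacy completions-endpoint model (e.g. `gpt-3.5-turbo-instruct`) that
/// takes a bare prompt instead of a chat message list.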
#[derive(Clone)]
pub struct OpenAIInstructModel<T>
where
    T: Into<Prompt>,
{
    model: String,
    client: async_openai::Client<async_openai::config::OpenAIConfig>,
    _phantom: std::marker::PhantomData<T>,
}
impl<T> OpenAIInstructModel<T>
where
    T: Into<Prompt>,
{
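    /// Creates an instruct model using the default `OpenAIConfig`, which
    /// reads the API key from the `OPENAI_API_KEY` environment variable.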
    #[allow(dead_code)]
    async fn new(model: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new();
        let client = async_openai::Client::with_config(config);
        OpenAIInstructModel {
            client,
            model,
            _phantom: std::marker::PhantomData,
        }
    }
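    /// Creates an instruct model authenticated with the provided API key.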
    #[allow(dead_code)]
    pub async fn new_with_key(model: String, api_key: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
        let client = async_openai::Client::with_config(config);
        OpenAIInstructModel {
            client,
            model,
            _phantom: std::marker::PhantomData,
        }
    }
}
#[async_trait]
impl<T> Node for OpenAIInstructModel<T>
where
    T: Into<Prompt> + fmt::Debug + Send + Sync,
{
    type Input = T;
    type Output = String;
    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        let request = CreateCompletionRequestArgs::default()
            .model(&self.model)
            .prompt(input)
            .temperature(0.8)
            .max_tokens(512u16)
            .build()?;
        let response = self.client.completions().create(request).await?;
        let content = response
            .choices
            .first()
            .ok_or(AnchorChainError::EmptyResponseError)?
            .text
            .clone();
        Ok(content)
    }
}
impl<T> fmt::Debug for OpenAIInstructModel<T>
where
    T: Into<Prompt>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAIInstructModel")
            .field("model", &self.model)
            .finish()
    }
}
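
/// OpenAI text-embedding model; defaults to `text-embedding-3-large`.
///
/// # Example
///
/// A sketch of embedding a batch of strings through [`Node::process`]
/// (imports and error handling omitted; assumes `OPENAI_API_KEY` is set and
/// an async context):
///
/// ```ignore
/// let embedder = OpenAIEmbeddingModel::default();
/// let vectors = embedder
///     .process(vec!["anchor".to_string(), "chain".to_string()])
///     .await?;
/// assert_eq!(vectors.len(), 2);
/// ```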
#[derive(Clone)]
pub struct OpenAIEmbeddingModel {
    model: String,
    client: async_openai::Client<async_openai::config::OpenAIConfig>,
}
impl Default for OpenAIEmbeddingModel {
    fn default() -> Self {
        OpenAIEmbeddingModel {
            model: "text-embedding-3-large".to_string(),
            client: async_openai::Client::with_config(async_openai::config::OpenAIConfig::new()),
        }
    }
}
impl OpenAIEmbeddingModel {
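    /// Creates an embedding model using the default `OpenAIConfig`, which
    /// reads the API key from the `OPENAI_API_KEY` environment variable.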
    #[allow(dead_code)]
    async fn new(model: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new();
        let client = async_openai::Client::with_config(config);
        OpenAIEmbeddingModel { client, model }
    }
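    /// Creates an embedding model authenticated with the provided API key.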
    #[allow(dead_code)]
    async fn new_with_key(model: String, api_key: String) -> Self {
        let config = async_openai::config::OpenAIConfig::new().with_api_key(api_key);
        let client = async_openai::Client::with_config(config);
        OpenAIEmbeddingModel { client, model }
    }
}
#[async_trait]
impl Node for OpenAIEmbeddingModel {
    type Input = Vec<String>;
    type Output = Vec<Vec<f32>>;
    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        let request = CreateEmbeddingRequestArgs::default()
            .model(&self.model)
            .input(input)
            .build()?;
        let response = self.client.embeddings().create(request).await?;
        Ok(response
            .data
            .into_iter()
            .map(|data| data.embedding)
            .collect())
    }
}
#[async_trait]
impl EmbeddingModel for OpenAIEmbeddingModel {
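    /// Embeds a single string by delegating to [`Node::process`] with a
    /// one-element batch.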
    #[cfg_attr(feature = "tracing", instrument(skip(self), fields(model = self.model.as_str())))]
    async fn embed(&self, input: String) -> Result<Vec<f32>, AnchorChainError> {
        self.process(vec![input])
            .await?
            .into_iter()
            .next()
            .ok_or(AnchorChainError::EmptyResponseError)
    }
    fn dimensions(&self) -> usize {
        // Output dimensionality of the default `text-embedding-3-large`
        // model; other embedding models will differ.
        3072
    }
}
impl fmt::Debug for OpenAIEmbeddingModel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("OpenAIEmbeddingModel")
            .field("model", &self.model)
            .finish()
    }
}