1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//! Module for handling dynamic prompts in processing chains.
//!
//! This module provides the `Prompt` struct, a processor that renders text
//! prompts from a template. `Prompt` uses Tera templating to substitute input
//! values into the prompt text. Tera is a template engine whose variable
//! syntax is similar to Jinja2's. For more information on Tera, see the
//! [Tera documentation](https://keats.github.io/tera/docs/#templates).

use std::collections::HashMap;

use anchor_chain_macros::Stateless;
use async_trait::async_trait;
use tera::{Context, Tera};
#[cfg(feature = "tracing")]
use tracing::instrument;

use crate::error::AnchorChainError;
use crate::node::Node;

/// A processor for handling text prompts within a processing chain.
///
/// The `Prompt` struct is a processor for handling text prompts within a
/// processing chain using Tera templating. It owns a Tera instance holding a
/// single template registered under the name `"prompt"` (see [`Prompt::new`]),
/// which is rendered against the input map when the struct is processed as a
/// [`Node`].
#[derive(Debug, Stateless)]
pub struct Prompt<'a> {
    /// The Tera template used to process the prompt text.
    tera: Tera,
    /// Zero-sized marker that binds the `'a` lifetime to the struct, so the
    /// `Node::Input` type (`HashMap<&'a str, &'a str>`) can borrow for `'a`.
    _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> Prompt<'a> {
    /// Creates a new `Prompt` processor with the specified template.
    ///
    /// Templates need to be specified using the Tera syntax which is based on
    /// Jinja2. For more information on Tera, see the
    /// [Tera Templates documentation](https://keats.github.io/tera/docs/#templates).
    ///
    /// # Examples
    /// ```rust
    /// use anchor_chain::nodes::prompt::Prompt;
    ///
    /// let prompt = Prompt::new("Create a {{ language }} program that prints 'Hello, World!'");
    /// ```
    pub fn new(template: &str) -> Self {
        let mut tera = Tera::default();
        tera.add_raw_template("prompt", template)
            .expect("Error creating template");
        Prompt {
            tera,
            _marker: std::marker::PhantomData,
        }
    }
}

/// Implements the `Node` trait for the `Prompt` struct.
///
/// Processing converts the input map into a [`tera::Context`] and renders the
/// template registered under the name `"prompt"`.
#[async_trait]
impl<'a> Node for Prompt<'a> {
    /// Template variables, keyed by variable name.
    ///
    /// The map is converted into a `tera::Context` before rendering.
    type Input = HashMap<&'a str, &'a str>;
    /// The rendered prompt text.
    type Output = String;

    /// Renders the prompt template with the variables in `input`.
    ///
    /// # Errors
    ///
    /// Returns an [`AnchorChainError`] if `input` cannot be converted into a
    /// Tera context, or if Tera fails to render the template with it.
    #[cfg_attr(feature = "tracing", instrument)]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        let context = Context::from_serialize(input)?;
        // `render` already returns an owned `String`; return it directly
        // instead of cloning it with `to_string()`.
        let rendered = self.tera.render("prompt", &context)?;
        Ok(rendered)
    }
}