1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
//! Provides a structure for processing input through multiple nodes in parallel.
//!
//! The `ParallelNode` struct represents a node that processes input through
//! multiple nodes in parallel. The output of each node is then combined using
//! a provided function to produce the final output.
//!
//! # Example
//!
//! ```rust,no_run
//! use std::collections::HashMap;
//!
//! use anchor_chain::{
//!     chain::ChainBuilder,
//!     models::openai::OpenAIModel,
//!     parallel_node::{ParallelNode, to_boxed_future},
//!     nodes::prompt::Prompt,
//! };
//!
//! #[tokio::main]
//! async fn main() {
//!     let gpt3 =
//!         Box::new(OpenAIModel::new_gpt3_5_turbo("You are a helpful assistant").await);
//!     let gpt4 = Box::new(OpenAIModel::new_gpt4_turbo("You are a helpful assistant").await);
//!
//!     let concat_fn = to_boxed_future(|outputs: Vec<String>| {
//!         Ok(outputs
//!             .iter()
//!             .enumerate()
//!             .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
//!             .collect::<Vec<String>>()
//!             .concat())
//!     });
//!
//!
//!     let chain = ChainBuilder::new()
//!         .link(Prompt::new("{{ input }}"))
//!         .link(ParallelNode::new(vec![gpt3, gpt4], concat_fn))
//!         .build();
//!
//!     let output = chain
//!         .process(HashMap::from([("input", "Write a hello world program in Rust")]))
//!         .await
//!         .expect("Error processing chain");
//!     println!("{}", output);
//! }
//! ```
use anchor_chain_macros::Stateless;
use async_trait::async_trait;
use futures::future::try_join_all;
use futures::{future::BoxFuture, FutureExt};
use std::fmt;
#[cfg(feature = "tracing")]
use tracing::{instrument, Instrument};
use crate::error::AnchorChainError;
use crate::node::Node;
/// A boxed, thread-safe function that combines the outputs of multiple nodes.
///
/// A [`ParallelNode`] calls it with the outputs of all of its child nodes, in the
/// same order as the nodes were supplied. It returns a `'static` boxed future that
/// resolves to the combined output or an [`AnchorChainError`].
///
/// A synchronous closure `Fn(Vec<I>) -> Result<O, AnchorChainError>` can be
/// converted into a value of this type with the [`to_boxed_future`] helper.
///
/// The alias is public so that callers can name the type that
/// [`ParallelNode::new`] and the `ParallelNode::function` field expect.
pub type CombinationFunction<I, O> =
    Box<dyn Fn(Vec<I>) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>;
/// A node that processes input through multiple nodes in parallel.
///
/// When processed, every child node in `nodes` receives its own clone of the
/// input. The resulting child futures are driven concurrently with
/// `futures::future::try_join_all`. `ParallelNode` does not spawn tasks itself:
/// the child futures are polled together on the task that awaits `process`.
///
/// The child outputs, in the same order as `nodes`, are then passed to `function`
/// to produce the final output. If any child node fails, its error is returned,
/// the remaining child futures are dropped, and `function` is not called.
///
/// # Type parameters
///
/// * `I` - Input type, cloned once per child node.
/// * `O` - Output type of each child node.
/// * `C` - Final output type produced by `function`.
#[derive(Stateless)]
pub struct ParallelNode<I, O, C>
where
    I: Clone + Send + Sync + fmt::Debug,
    O: Send + Sync + fmt::Debug,
    C: Send + Sync + fmt::Debug,
{
    /// The child nodes; each processes its own clone of the input concurrently.
    pub nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
    /// Combines the child outputs (ordered like `nodes`) into the final output.
    pub function: CombinationFunction<O, C>,
}
impl<I, O, C> ParallelNode<I, O, C>
where
    I: Clone + Send + Sync + fmt::Debug,
    O: Send + Sync + fmt::Debug,
    C: Send + Sync + fmt::Debug,
{
    /// Creates a new `ParallelNode` from the given child nodes and combination
    /// function.
    ///
    /// # Arguments
    ///
    /// * `nodes` - The child nodes. When the node is processed, each child
    ///   receives its own clone of the input.
    /// * `function` - Combines the child outputs, in the order of `nodes`, into
    ///   the final output. A synchronous closure can be adapted into this type
    ///   with the helper function [`to_boxed_future`].
    ///
    /// # Example
    ///
    /// Using `NoOpNode` as the child node type:
    ///
    /// ```rust
    /// use anchor_chain::{
    ///     node::NoOpNode,
    ///     parallel_node::ParallelNode,
    ///     parallel_node::to_boxed_future
    /// };
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let node1 = Box::new(NoOpNode::new());
    ///     let node2 = Box::new(NoOpNode::new());
    ///     let concat_fn = to_boxed_future(|outputs: Vec<String>| {
    ///         Ok(outputs
    ///            .iter()
    ///            .enumerate()
    ///            .map(|(i, output)| format!("Output {}:\n```\n{}\n```\n", i + 1, output))
    ///            .collect::<Vec<String>>()
    ///            .concat())
    ///     });
    ///     let parallel_node = ParallelNode::new(vec![node1, node2], concat_fn);
    /// }
    /// ```
    pub fn new(
        nodes: Vec<Box<dyn Node<Input = I, Output = O> + Send + Sync>>,
        function: CombinationFunction<O, C>,
    ) -> Self {
        Self { nodes, function }
    }
}
#[async_trait]
impl<I, O, C> Node for ParallelNode<I, O, C>
where
    I: Clone + Send + Sync + fmt::Debug,
    O: Send + Sync + fmt::Debug,
    C: Send + Sync + fmt::Debug,
{
    type Input = I;
    type Output = C;

    /// Runs every child node on its own clone of `input` and merges the results.
    ///
    /// The child futures are awaited together. Their outputs reach the
    /// combination function in the same order as `self.nodes`. The first child
    /// error short-circuits the join and is returned. Otherwise the combination
    /// function's result is returned as-is.
    #[cfg_attr(feature = "tracing", instrument)]
    async fn process(&self, input: Self::Input) -> Result<Self::Output, AnchorChainError> {
        // One future per child. `try_join_all` consumes this iterator (cloning
        // the input per child); each child's `process` is only invoked when its
        // wrapping async block is first polled.
        let branches = self.nodes.iter().map(|child| {
            let owned_input = input.clone();
            async move { child.process(owned_input).await }
        });

        // Stage 1: drive all child futures together, failing fast on an error.
        let joined = try_join_all(branches);
        #[cfg(feature = "tracing")]
        let joined = joined.instrument(tracing::info_span!("Joining parallel node futures"));
        let outputs = joined.await?;

        // Stage 2: merge the ordered child outputs into the final value.
        let merged = (self.function)(outputs);
        #[cfg(feature = "tracing")]
        let merged = merged.instrument(tracing::info_span!("Combining parallel node outputs"));
        merged.await
    }
}
impl<I, O, C> fmt::Debug for ParallelNode<I, O, C>
where
    I: Clone + Send + Sync + fmt::Debug,
    O: Send + Sync + fmt::Debug,
    C: Send + Sync + fmt::Debug,
{
    /// Formats as `ParallelNode { nodes: [...], function: <function/closure> }`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("ParallelNode");
        builder.field("nodes", &self.nodes);
        // The boxed combination closure has no `Debug` impl, so an unquoted
        // placeholder is printed in its place.
        builder.field("function", &format_args!("<function/closure>"));
        builder.finish()
    }
}
/// Converts a synchronous combination closure into the boxed, future-returning
/// function type used by `ParallelNode`.
///
/// Each call of the returned function yields a `'static` boxed future. The
/// wrapped closure `f` runs when that future is first polled, and the `Result`
/// it returns becomes the future's output.
///
/// `f` is shared between calls through an `Arc` rather than cloned. As a result,
/// `F` does not need to implement `Clone`, and any state it captures is never
/// copied. Every closure accepted by earlier versions, which required
/// `F: Clone`, is still accepted.
///
/// # Example
///
/// ```rust
/// use anchor_chain::parallel_node::to_boxed_future;
///
/// let _join_lines = to_boxed_future(|outputs: Vec<String>| Ok(outputs.join("\n")));
/// ```
pub fn to_boxed_future<F, I, O>(
    f: F,
) -> Box<dyn Fn(I) -> BoxFuture<'static, Result<O, AnchorChainError>> + Send + Sync>
where
    F: Fn(I) -> Result<O, AnchorChainError> + Send + Sync + 'static,
    I: Send + 'static,
{
    // `Arc<F>` is `Send + Sync` because `F: Send + Sync`. This keeps both the
    // boxed function and every future it returns `Send`.
    let f = std::sync::Arc::new(f);
    Box::new(move |input| {
        let f = std::sync::Arc::clone(&f);
        // Defer the call into the future so the closure runs when it is polled.
        async move { f(input) }.boxed()
    })
}