
dspy.MultiChainComparison

Constructor

The constructor initializes the MultiChainComparison class and sets up its attributes. The class subclasses Module and wraps an internal Predict module, stored as self.predict, whose signature is extended with the fields needed to compare multiple reasoning chains.

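The extended signature keeps the original input fields, adds one field per student reasoning attempt (reasoning_attempt_1 through reasoning_attempt_M), then a rationale field in which the model reasons over those attempts as a whole, and ends with the original final output field.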

from .predict import Predict
from ..primitives.program import Module

import dsp

class MultiChainComparison(Module):
    def __init__(self, signature, M=3, temperature=0.7, **config):
        super().__init__()

        self.M = M
        signature = Predict(signature).signature
        *keys, last_key = signature.kwargs.keys()

        extended_kwargs = {key: signature.kwargs[key] for key in keys}

        for idx in range(M):
            candidate_type = dsp.Type(prefix=f"Student Attempt #{idx+1}:", desc="${reasoning attempt}")
            extended_kwargs.update({f'reasoning_attempt_{idx+1}': candidate_type})

        rationale_type = dsp.Type(prefix="Accurate Reasoning: Thank you everyone. Let's now holistically", desc="${corrected reasoning}")
        extended_kwargs.update({'rationale': rationale_type, last_key: signature.kwargs[last_key]})

        signature = dsp.Template(signature.instructions, **extended_kwargs)
        self.predict = Predict(signature, temperature=temperature, **config)
        self.last_key = last_key

Parameters:
- signature (Any): Signature of predictive model.
- M (int, optional): Number of student reasoning attempts. Defaults to 3.
- temperature (float, optional): Temperature parameter for prediction. Defaults to 0.7.
- **config (dict): Additional configuration parameters for model.
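The field ordering that the constructor builds can be illustrated without DSPy. The sketch below is a simplified stand-in, not the library's code: a plain dict plays the role of signature.kwargs, the prefix strings stand in for dsp.Type objects, and the function name extend_fields and the "Question:"/"Answer:" prefixes are hypothetical.

```python
def extend_fields(fields, M=3):
    """Return a new ordered mapping: inputs, M attempt slots, rationale, final output.

    `fields` is a plain dict standing in for signature.kwargs (an assumption
    made for illustration; the real values are dsp.Type objects).
    """
    # Everything except the last key is kept in place; the last key is
    # treated as the final output and moved to the end.
    *keys, last_key = fields.keys()
    extended = {key: fields[key] for key in keys}

    # One slot per student reasoning attempt, numbered from 1.
    for idx in range(M):
        extended[f"reasoning_attempt_{idx + 1}"] = f"Student Attempt #{idx + 1}:"

    # The rationale comes after all attempts and before the final output.
    extended["rationale"] = "Accurate Reasoning: Thank you everyone. Let's now holistically"
    extended[last_key] = fields[last_key]
    return extended


# Hypothetical two-field signature: one input, one output.
basic_qa = {"question": "Question:", "answer": "Answer:"}
extended = extend_fields(basic_qa, M=3)
for name, prefix in extended.items():
    print(f"{name}: {prefix}")
```

Because Python dicts preserve insertion order, the final output field is guaranteed to come last, after the attempts and the rationale.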

Method

forward(self, completions, **kwargs)

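This method collects the student reasoning attempts from the given completions, passes them together with the remaining keyword arguments to the internal Predict module (self.predict), which uses the extended signature, and returns the resulting prediction.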

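Parameters:
- completions: List of completion objects which include student reasoning attempts. Each completion carries a rationale and a value for the final output field.
- **kwargs: Additional keyword arguments, such as the signature's input fields (for example, question).

Returns:
- The result of the predict method: a prediction containing the final rationale and the final output field.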
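The aggregation step can be sketched as follows. This is an illustrative approximation rather than the library's implementation: the helper name build_attempt_kwargs, the attempt string format, and the use of SimpleNamespace in place of dspy.Prediction are all assumptions. It also assumes exactly one completion per reasoning_attempt field, so M completions in total.

```python
from types import SimpleNamespace


def build_attempt_kwargs(completions, M, last_key="answer", **kwargs):
    """Map each completion to a reasoning_attempt_N keyword argument.

    Hypothetical helper: the attempt format below is an assumption,
    not necessarily the wording used by the library.
    """
    attempts = []
    for c in completions:
        rationale = c.rationale.strip()
        answer = getattr(c, last_key).strip()
        attempts.append(f"{rationale} Prediction: {answer}")

    # The extended signature has exactly M attempt fields to fill.
    if len(attempts) != M:
        raise ValueError(f"expected {M} attempts, got {len(attempts)}")

    attempt_kwargs = {
        f"reasoning_attempt_{idx + 1}": attempt
        for idx, attempt in enumerate(attempts)
    }
    # Caller-supplied inputs (e.g. question) are merged in alongside the attempts.
    return {**attempt_kwargs, **kwargs}


# SimpleNamespace stands in for dspy.Prediction in this sketch.
completions = [
    SimpleNamespace(rationale="Clear skies usually look this color.", answer="blue"),
    SimpleNamespace(rationale="I believe the sky is this color.", answer="green"),
    SimpleNamespace(rationale="Media often depicts the sky this way.", answer="blue"),
]
merged = build_attempt_kwargs(completions, M=3, question="What is the color of the sky?")
print(sorted(merged))
```

The merged dictionary is the kind of input that would then be passed to the internal Predict module as keyword arguments.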

Examples

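import dspy

# Assumes a language model has already been configured for DSPy,
# e.g. with dspy.settings.configure(lm=...).

class BasicQA(dspy.Signature):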
    """Answer questions with short factoid answers."""
    question = dspy.InputField()
    answer = dspy.OutputField(desc="often between 1 and 5 words")

# Example completions generated by a model for reference
completions = [
    dspy.Prediction(rationale="I recall that during clear days, the sky often appears this color.", answer="blue"),
    dspy.Prediction(rationale="Based on common knowledge, I believe the sky is typically seen as this color.", answer="green"),
    dspy.Prediction(rationale="From images and depictions in media, the sky is frequently represented with this hue.", answer="blue"),
]

# Pass signature to MultiChainComparison module (M defaults to 3, matching the three completions above)
compare_answers = dspy.MultiChainComparison(BasicQA)

# Call the MultiChainComparison on the completions
question = 'What is the color of the sky?'
final_pred = compare_answers(completions, question=question)

print(f"Question: {question}")
print(f"Final Predicted Answer (after comparison): {final_pred.answer}")
print(f"Final Rationale: {final_pred.rationale}")