teleprompt.BootstrapFinetune

Constructor

__init__(self, metric=None, teacher_settings={}, multitask=True)

The constructor initializes a BootstrapFinetune instance and sets up its attributes. Internally, it creates a BootstrapFewShot teleprompter, which handles the bootstrapping step of the fine-tuning compilation.

class BootstrapFinetune(Teleprompter):
    def __init__(self, metric=None, teacher_settings={}, multitask=True):

Parameters:
- metric (callable, optional): Metric function used to evaluate examples during bootstrapping. Defaults to None.
- teacher_settings (dict, optional): Settings for the teacher predictor. Defaults to an empty dictionary.
- multitask (bool, optional): Enable multitask fine-tuning. Defaults to True.
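As an illustration of the metric parameter, here is a minimal sketch of a metric callable. It assumes the usual DSPy convention of a function taking (example, pred, trace=None); the answer field and the SimpleNamespace stand-ins are hypothetical and only keep the snippet self-contained.

```python
from types import SimpleNamespace


def exact_match_metric(example, pred, trace=None):
    """Return True when the predicted answer equals the gold answer.

    Comparison ignores case and surrounding whitespace. The `answer`
    field name is an assumption made for this sketch.
    """
    return example.answer.strip().lower() == pred.answer.strip().lower()


# Stand-in objects; in practice these would be a training example
# and the prediction produced by the program being bootstrapped.
gold = SimpleNamespace(question="What is 2 + 2?", answer="4")
good = SimpleNamespace(answer=" 4 ")
bad = SimpleNamespace(answer="5")

print(exact_match_metric(gold, good))
print(exact_match_metric(gold, bad))
```

A function like this could then be passed as BootstrapFinetune(metric=exact_match_metric), so that only bootstrapped examples passing the check are kept.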

Method

compile(self, student, *, teacher=None, trainset, valset=None, target='t5-large', bsize=12, accumsteps=1, lr=5e-5, epochs=1, bf16=False)

This method first runs the BootstrapFewShot teleprompter to bootstrap the program. It then prepares fine-tuning data by generating prompt-completion pairs for training and fine-tunes the target model on them. After compilation, the LMs are set to the fine-tuned models, and the method returns a compiled, fine-tuned predictor.
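To make the idea of prompt-completion pairs concrete, the following is a rough sketch of turning bootstrapped demonstrations into training records. The helper to_prompt_completion, the question/answer field names, and the prompt template are all hypothetical; the actual format is determined by the library and may differ.

```python
def to_prompt_completion(demos):
    """Convert question/answer demos into prompt-completion records.

    Illustrative only: the template below is an assumption, not the
    format BootstrapFinetune produces.
    """
    records = []
    for demo in demos:
        prompt = f"Question: {demo['question']}\nAnswer:"
        # Leading space so the completion continues the prompt naturally.
        completion = " " + demo["answer"]
        records.append({"prompt": prompt, "completion": completion})
    return records


# Hypothetical bootstrapped demonstrations.
demos = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "What is 2 + 2?", "answer": "4"},
]

pairs = to_prompt_completion(demos)
for pair in pairs:
    print(pair)
```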

Parameters:
- student (Predict): Student predictor to be fine-tuned.
- teacher (Predict, optional): Teacher predictor to help with fine-tuning. Defaults to None.
- trainset (list): Training dataset for fine-tuning.
- valset (list, optional): Validation dataset for fine-tuning. Defaults to None.
- target (str, optional): Target model for fine-tuning. Defaults to 't5-large'.
- bsize (int, optional): Batch size for training. Defaults to 12.
- accumsteps (int, optional): Gradient accumulation steps. Defaults to 1.
- lr (float, optional): Learning rate for fine-tuning. Defaults to 5e-5.
- epochs (int, optional): Number of training epochs. Defaults to 1.
- bf16 (bool, optional): Enable mixed-precision training with BF16. Defaults to False.

Returns:
- compiled2 (Predict): A compiled and fine-tuned Predict instance.
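The bsize, accumsteps, and epochs parameters interact through ordinary gradient-accumulation arithmetic. The sketch below assumes a single device and that a final partial batch still counts as a step; the helper names are hypothetical, and the trainer used internally may count steps differently.

```python
import math


def effective_batch_size(bsize=12, accumsteps=1):
    """Examples contributing to each optimizer update."""
    return bsize * accumsteps


def optimizer_steps(num_examples, bsize=12, accumsteps=1, epochs=1):
    """Approximate number of optimizer updates over all epochs."""
    per_epoch = math.ceil(num_examples / effective_batch_size(bsize, accumsteps))
    return per_epoch * epochs


# Defaults from compile(): bsize=12, accumsteps=1, epochs=1.
default_steps = optimizer_steps(100)

# Larger effective batch via accumulation, trained for two epochs.
accumulated_steps = optimizer_steps(100, bsize=12, accumsteps=4, epochs=2)

print(effective_batch_size(), default_steps)
print(effective_batch_size(12, 4), accumulated_steps)
```

Raising accumsteps increases the effective batch size without increasing per-step memory use, at the cost of fewer optimizer updates per epoch.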

Example

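from dspy.teleprompt import BootstrapFinetune

# Assume a defined trainset
# Assume a defined RAG class
# Assume teacher is a defined LM to use for the teacher predictor
...

# Define the teleprompter
teleprompter = BootstrapFinetune(teacher_settings=dict(lm=teacher))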

# Compile!
compiled_rag = teleprompter.compile(student=RAG(), trainset=trainset, target='google/flan-t5-base')