teleprompt.BootstrapFinetune
Constructor
__init__(self, metric=None, teacher_settings={}, multitask=True)
The constructor initializes a BootstrapFinetune instance and sets up its attributes. It also creates a BootstrapFewShot teleprompter, which is used for the bootstrapping step of the fine-tuning compilation.
class BootstrapFinetune(Teleprompter):
    def __init__(self, metric=None, teacher_settings={}, multitask=True):
Parameters:
metric (callable, optional): Metric function to evaluate examples during bootstrapping. Defaults to None.
teacher_settings (dict, optional): Settings for the teacher predictor. Defaults to an empty dictionary.
multitask (bool, optional): Enable multitask fine-tuning. Defaults to True.
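As a rough illustration of what a metric callable might look like, the sketch below assumes the common DSPy convention of a function that takes (example, prediction, trace=None) and returns a truthy value when a bootstrapped example should be kept. The answer field and the SimpleNamespace stand-ins are hypothetical and are used only so the snippet runs without DSPy installed.

```python
from types import SimpleNamespace


def exact_match_metric(example, prediction, trace=None):
    """Return True when the predicted answer matches the gold answer.

    Assumption: both objects expose an `answer` attribute (hypothetical field name).
    """
    gold = example.answer.strip().lower()
    predicted = prediction.answer.strip().lower()
    return gold == predicted


# Stand-ins for a training example and two model predictions.
example = SimpleNamespace(question="What is 2 + 2?", answer="4")
good_prediction = SimpleNamespace(answer=" 4 ")
bad_prediction = SimpleNamespace(answer="5")

kept = exact_match_metric(example, good_prediction)
dropped = exact_match_metric(example, bad_prediction)
print(kept, dropped)
```

A function of this shape could then be passed as BootstrapFinetune(metric=exact_match_metric).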
Method
compile(self, student, *, teacher=None, trainset, valset=None, target='t5-large', bsize=12, accumsteps=1, lr=5e-5, epochs=1, bf16=False)
This method first compiles the student with the BootstrapFewShot teleprompter to bootstrap demonstrations. It then prepares fine-tuning data by generating prompt-completion pairs for training and performs fine-tuning. After compilation, the LMs are set to the fine-tuned models, and the method returns a compiled and fine-tuned predictor.
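The data-preparation step can be pictured with a minimal sketch. It is not the library's implementation: the build_finetune_data helper, the tuple layout of traces, the predictor names, and the "all" key are hypothetical, and the pooling behaviour under multitask=True is an assumption about how the flag is meant to work.

```python
from collections import defaultdict


def build_finetune_data(traces, multitask=True):
    """Group bootstrapped traces into prompt-completion pairs.

    `traces` is a hypothetical list of (predictor_name, prompt, completion)
    tuples standing in for the bootstrapped demonstrations.
    """
    data = defaultdict(list)
    for predictor_name, prompt, completion in traces:
        # Assumption: with multitask=True, all predictors share one training set.
        key = "all" if multitask else predictor_name
        data[key].append({"prompt": prompt, "completion": completion})
    return dict(data)


traces = [
    ("generate_query", "Question: Who wrote Dune?\nQuery:", " Dune author"),
    ("generate_answer", "Question: Who wrote Dune?\nAnswer:", " Frank Herbert"),
]

pooled = build_finetune_data(traces, multitask=True)
separate = build_finetune_data(traces, multitask=False)
print(len(pooled["all"]), sorted(separate))
```

In this sketch, pooling yields a single training set for one fine-tuned model, while the non-pooled variant yields one training set per predictor.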
Parameters:
student (Predict): Student predictor to be fine-tuned.
teacher (Predict, optional): Teacher predictor to help with fine-tuning. Defaults to None.
trainset (list): Training dataset for fine-tuning.
valset (list, optional): Validation dataset for fine-tuning. Defaults to None.
target (str, optional): Target model for fine-tuning. Defaults to 't5-large'.
bsize (int, optional): Batch size for training. Defaults to 12.
accumsteps (int, optional): Gradient accumulation steps. Defaults to 1.
lr (float, optional): Learning rate for fine-tuning. Defaults to 5e-5.
epochs (int, optional): Number of training epochs. Defaults to 1.
bf16 (bool, optional): Enable mixed-precision training with BF16. Defaults to False.
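To see how bsize, accumsteps, and epochs interact, the sketch below works through the usual gradient-accumulation arithmetic. It assumes a single device and that a final partial batch still triggers an optimizer update; the training_schedule helper and the example count of 1000 are hypothetical, and the actual trainer may behave differently.

```python
import math


def training_schedule(num_examples, bsize=12, accumsteps=1, epochs=1):
    """Estimate the effective batch size and optimizer steps for a run."""
    # Gradients from `accumsteps` mini-batches are summed before each update.
    effective_batch = bsize * accumsteps
    # Assumption: a trailing partial batch still counts as one update.
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch * epochs


effective_batch, total_steps = training_schedule(
    num_examples=1000, bsize=12, accumsteps=4, epochs=2
)
print(effective_batch, total_steps)
```

Raising accumsteps is a common way to reach a larger effective batch size when a larger bsize does not fit in memory.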
Returns:
compiled2 (Predict): A compiled and fine-tuned Predict instance.
Example
from dspy.teleprompt import BootstrapFinetune

# Assume trainset is defined
# Assume the RAG class is defined
# Assume teacher is a defined language model
...

# Define teleprompter
teleprompter = BootstrapFinetune(teacher_settings={'lm': teacher})

# Compile!
compiled_rag = teleprompter.compile(student=RAG(), trainset=trainset, target='google/flan-t5-base')