Guidance acceleration

When a Guidance program contains multiple generation or LLM-directed control flow statements, we can significantly improve inference performance by maintaining a session with the LLM inferencer and reusing its Key/Value cache as we progress through the program. This is much faster than letting the model generate all of the structural tokens itself (for example, if the structure were demonstrated with a one-shot example), and also faster than simply re-calling the model without any state at each point in the Guidance program.

We call this “guidance acceleration”, and it is currently supported by local models such as guidance.models.LlamaCpp and guidance.models.Transformers, as demonstrated below.
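
For example, either backend loads a local model in one line. This is a minimal sketch; the GGUF file path is a placeholder, not a file shipped with guidance:

from guidance import models

# Transformers backend (any Hugging Face causal LM)
lm = models.Transformers("mistralai/Mistral-7B-v0.1", device="cuda")

# LlamaCpp backend (pass the path to a local GGUF file; placeholder path)
# lm = models.LlamaCpp("/path/to/mistral-7b-v0.1.Q8_0.gguf")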

[1]:
import time
import torch
import guidance
from guidance import models, gen
[2]:
# Define a trivial string we can extend with small models easily
prefix = "Repeat this. Repeat this. "*5 + "Repeat this. Repeat this."
print(prefix)
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.
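
A prompt this repetitive is trivial for the model to continue, so the timings below mostly measure inference overhead rather than generation difficulty.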

Load the model we will test

[3]:
model = 'mistralai/Mistral-7B-v0.1'
device = 'cuda'
mistral_7b_gpu = guidance.models.Transformers(model, device=device) # run on an A100 for the numbers below

Run with normal guidance acceleration

[4]:
@guidance
def with_acceleration(lm, n_reps):
    # each chunk extends the same lm object, so the KV cache built for the
    # previous chunks is reused on every iteration
    for _ in range(n_reps):
        lm += prefix + gen(name='story', max_tokens=4) + " "
    return lm

start = time.time()
mistral_7b_gpu + with_acceleration(10)
print("With guidance acceleration and token healing:", time.time() - start)

torch.cuda.empty_cache()
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. 
With guidance acceleration and token healing: 3.9129340648651123

Run without guidance acceleration

[5]:
@guidance
def without_acceleration(lm, n_reps):
    prompt = ""
    for _ in range(n_reps):

        # disable KV cache reuse (using non-public member variables, since there is
        # no normal reason to do this)
        lm._cache_state["past_key_values"] = None
        lm._cache_state["logits"] = None
        lm._cache_state["cache_token_ids"] = []

        # generate a chunk from scratch, re-encoding the whole prompt each time
        lm_new = lm + prompt + prefix + gen(name='story', max_tokens=4)

        prompt += prefix + " Repeat this. " # grow the prompt to match the accelerated run
    return lm_new

start = time.time()
mistral_7b_gpu + without_acceleration(10)
print("Without guidance acceleration:", time.time() - start)

torch.cuda.empty_cache()
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.
Without guidance acceleration: 4.819199085235596
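
Because the KV cache is cleared before every chunk, the model re-encodes the entire (growing) prompt on each iteration, which is why this run is slower than the accelerated one above.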

Run as a single generation call

This is how you have to run against most remote endpoints that don’t support guidance acceleration.
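
For reference, a sketch of the same single-call pattern against a remote endpoint; guidance.models.OpenAI is assumed to be available here, and the model name is a placeholder:

# hypothetical remote run: one long completion call, with no way to reuse
# KV-cache state across the steps of the program
remote_lm = models.OpenAI("gpt-3.5-turbo-instruct")  # placeholder model name
remote_lm += prefix + gen(name='story', max_tokens=(13*4) * 10)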

[6]:
@guidance
def single_gen_call(lm, n_reps):
    # one long generation matching the total length of the runs above
    # (each repetition is roughly 13 "Repeat this." segments of ~4 tokens)
    lm += prefix + gen(name='story', max_tokens=(13*4) * n_reps)
    return lm

start = time.time()
mistral_7b_gpu + single_gen_call(10)
print("Single generation call of same length:", time.time() - start)

torch.cuda.empty_cache()
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.
Single generation call of same length: 20.686521768569946
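
For this toy program, guidance acceleration is roughly 5x faster than one long generation call (about 3.9s vs. 20.7s on an A100) and modestly faster than re-calling the model without cached state (about 4.8s).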

Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!