{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Guidance acceleration\n", "\n", "When multiple generation or LLM-directed control flow statements are used in a single Guidance program then we can significantly improve inference performance by maintaining a session state with the LLM inferencer and so reusing the Key/Value caches as we progress through the program. This is much faster than letting the model generate all of the structural tokens itself (for example if the structure was demonstrated using a one-shot example), and also faster that simply recalling the model without any state at each point in the Guidance program.\n", "\n", "We call this \"guidance acceleration\" and it is supported currently by local models such as `guidance.models.LlamaCpp` or `guidance.models.Transformers` as demonstrated below." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import time\n", "import torch\n", "import guidance\n", "from guidance import models, gen" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.\n" ] } ], "source": [ "# Define a trivial string we can extend with small models easily\n", "prefix = \"Repeat this. Repeat this. \"*5 + \"Repeat this. Repeat this.\"\n", "print(prefix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the model we will test" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "560ce990c6484163aa19a2f37788dfeb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "With guidance acceleration and token healing: 3.9129340648651123\n" ] } ], "source": [ "@guidance\n", "def with_acceleration(lm, n_reps) -> float:\n", " for _ in range(n_reps):\n", " lm += prefix + gen(name='story', max_tokens=4) + \" \"\n", " return lm\n", "\n", "start = time.time()\n", "llm_gpt2_large_gpu + with_acceleration(10)\n", "print(f\"With guidance acceleration and token healing:\", time.time() - start)\n", "\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run without token acceleration" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Without guidance acceleration: 4.819199085235596\n" ] } ], "source": [ "@guidance\n", "def without_acceleration(lm, n_reps) -> float:\n", " prompt = \"\"\n", " for _ in range(n_reps):\n", "\n", " # disable KV cache reuse (using non-public member variables since there is not normal reason to do this)\n", " lm._cache_state[\"past_key_values\"] = None\n", " lm._cache_state[\"logits\"] = None\n", " lm._cache_state[\"cache_token_ids\"] = []\n", "\n", " # generate a chunk\n", " lm_new = lm + prompt + prefix + gen(name='story', max_tokens=4)\n", "\n", " prompt += prefix + \" Repeat this. \" # update the prompt\n", " return lm_new\n", "\n", "start = time.time()\n", "llm_gpt2_large_gpu + without_acceleration(10)\n", "print(f\"Without guidance acceleration:\", time.time() - start)\n", "\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run as a single generation call\n", "\n", "This is how you have to run most remote endpoints that don't support guidance." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this. Repeat this.
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Single generation call of same length: 20.686521768569946\n" ] } ], "source": [ "@guidance\n", "def single_gen_call(lm, n_reps) -> float:\n", " lm += prefix + gen(name='story', max_tokens=(13*4) * n_reps)\n", " return lm\n", "\n", "start = time.time()\n", "llm_gpt2_large_gpu + single_gen_call(10)\n", "print(f\"Single generation call of same length:\", time.time() - start)\n", "\n", "torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
\n", "
Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 4 }