diff --git a/examples/llm_inference/conversion/llm_conversion.ipynb b/examples/llm_inference/conversion/llm_conversion.ipynb
index f9966831..bbb21074 100644
--- a/examples/llm_inference/conversion/llm_conversion.ipynb
+++ b/examples/llm_inference/conversion/llm_conversion.ipynb
@@ -757,7 +757,7 @@
     "# @title Run { display-mode: \"form\"}\n",
     "\n",
     "model = widgets.Dropdown(\n",
-    "    options=[\"Gemma 2B\", \"Falcon 1B\", \"StableLM 3B\", \"Phi 2\"],\n",
+    "    options=[\"Gemma 2B\", \"Gemma 7B\", \"Falcon 1B\", \"StableLM 3B\", \"Phi 2\"],\n",
     "    value='Gemma 2B',\n",
     "    description='model',\n",
     "    disabled=False,\n",
@@ -777,16 +777,34 @@
     "    disabled=False\n",
     ")\n",
     "\n",
-    "def on_change_model(change):\n",
-    "  if change[\"new\"] != 'Gemma 2b':\n",
-    "    token_description.layout.display = \"none\"\n",
-    "    token.layout.display = \"none\"\n",
+    "options_mapping = {\n",
+    "    'Gemma 2B': ['cpu', 'gpu'],\n",
+    "    'Gemma 7B': ['gpu'],\n",
+    "    'Falcon 1B': ['cpu', 'gpu'],\n",
+    "    'StableLM 3B': ['cpu', 'gpu'],\n",
+    "    'Phi 2': ['cpu', 'gpu']\n",
+    "}\n",
+    "\n",
+    "def on_use_gpu(change):\n",
+    "  selected_value = change['new']\n",
+    "\n",
+    "  if selected_value in options_mapping:\n",
+    "    backend.options = options_mapping[selected_value]\n",
+    "    backend.value = options_mapping[selected_value][0]\n",
     "  else:\n",
-    "    token_description.layout.display = \"flex\"\n",
-    "    token.layout.display = \"flex\"\n",
+    "    backend.options = []\n",
+    "    backend.value = None\n",
     "\n",
-    "model.observe(on_change_model, names=['value'])\n",
+    "def on_change_model(change):\n",
+    "  selected_values = ['Gemma 2B', 'Gemma 7B']\n",
+    "\n",
+    "  if change['new'] in selected_values:\n",
+    "    token.layout.display = 'flex'\n",
+    "  else:\n",
+    "    token.layout.display = 'none'\n",
     "\n",
+    "model.observe(on_change_model, names='value')\n",
+    "model.observe(on_use_gpu, names='value')\n",
     "\n",
     "\n",
     "display(model)\n",
@@ -802,7 +820,7 @@
     "\n",
     "\n",
     "\n",
-    "def gemma_download(token):\n",
+    "def gemma2b_download(token):\n",
     "  REPO_ID = \"google/gemma-2b-it\"\n",
     "  FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"model-00001-of-00002.safetensors\", \"model-00002-of-00002.safetensors\"]\n",
     "  os.environ['HF_TOKEN'] = token\n",
@@ -810,6 +828,14 @@
     "    for filename in FILENAMES:\n",
     "      hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./gemma-2b-it\")\n",
     "\n",
+    "def gemma7b_download(token):\n",
+    "  REPO_ID = \"google/gemma-1.1-7b-it\"\n",
+    "  FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"model-00001-of-00004.safetensors\", \"model-00002-of-00004.safetensors\", \"model-00003-of-00004.safetensors\", \"model-00004-of-00004.safetensors\"]\n",
+    "  os.environ['HF_TOKEN'] = token\n",
+    "  with out:\n",
+    "    for filename in FILENAMES:\n",
+    "      hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./gemma-1.1-7b-it\")\n",
+    "\n",
     "def falcon_download():\n",
     "  REPO_ID = \"tiiuae/falcon-rw-1b\"\n",
     "  FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"pytorch_model.bin\"]\n",
@@ -832,13 +858,20 @@
     "      hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./phi-2\")\n",
     "\n",
     "\n",
-    "def gemma_convert_config(backend):\n",
+    "def gemma2b_convert_config(backend):\n",
     "  input_ckpt = '/content/gemma-2b-it/'\n",
     "  vocab_model_file = '/content/gemma-2b-it/'\n",
     "  output_dir = '/content/intermediate/gemma-2b-it/'\n",
     "  output_tflite_file = f'/content/converted_models/gemma_{backend}.bin'\n",
     "  return converter.ConversionConfig(input_ckpt=input_ckpt, ckpt_format='safetensors', model_type='GEMMA_2B', backend=backend, output_dir=output_dir, combine_file_only=False, vocab_model_file=vocab_model_file, output_tflite_file=output_tflite_file)\n",
     "\n",
+    "def gemma7b_convert_config(backend):\n",
+    "  input_ckpt = '/content/gemma-1.1-7b-it/'\n",
+    "  vocab_model_file = '/content/gemma-1.1-7b-it/'\n",
+    "  output_dir = '/content/intermediate/gemma-1.1-7b-it/'\n",
+    "  output_tflite_file = f'/content/converted_models/gemma7b_{backend}.bin'\n",
+    "  return converter.ConversionConfig(input_ckpt=input_ckpt, ckpt_format='safetensors', model_type='GEMMA_7B', backend=backend, output_dir=output_dir, combine_file_only=False, vocab_model_file=vocab_model_file, output_tflite_file=output_tflite_file)\n",
+    "\n",
     "def falcon_convert_config(backend):\n",
     "  input_ckpt = '/content/falcon-rw-1b/pytorch_model.bin'\n",
     "  vocab_model_file = '/content/falcon-rw-1b/'\n",
@@ -874,7 +907,9 @@
     "  backend.disabled = True\n",
     "\n",
     "  if model.value == 'Gemma 2B':\n",
-    "    gemma_download(token.value)\n",
+    "    gemma2b_download(token.value)\n",
+    "  elif model.value == 'Gemma 7B':\n",
+    "    gemma7b_download(token.value)\n",
     "  elif model.value == 'Falcon 1B':\n",
     "    falcon_download()\n",
     "  elif model.value == 'StableLM 3B':\n",
@@ -891,7 +926,9 @@
     "  button.description = \"Converting ...\"\n",
     "\n",
     "  if model.value == 'Gemma 2B':\n",
-    "    config = gemma_convert_config(backend.value)\n",
+    "    config = gemma2b_convert_config(backend.value)\n",
+    "  elif model.value == 'Gemma 7B':\n",
+    "    config = gemma7b_convert_config(backend.value)\n",
     "  elif model.value == 'Falcon 1B':\n",
     "    config = falcon_convert_config(backend.value)\n",
     "  elif model.value == 'StableLM 3B':\n",
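
For readability, here is the widget wiring the patch introduces, restated as plain Python outside the JSON-escaped notebook source. This is a minimal sketch, assuming ipywidgets in a notebook environment; `widgets.Text` stands in for the notebook's token widget, and the `out` output area, download helpers, and convert button from the surrounding cell are omitted.

    import ipywidgets as widgets
    from IPython.display import display

    # Backends each model can be converted for; Gemma 7B is GPU-only.
    options_mapping = {
        'Gemma 2B': ['cpu', 'gpu'],
        'Gemma 7B': ['gpu'],
        'Falcon 1B': ['cpu', 'gpu'],
        'StableLM 3B': ['cpu', 'gpu'],
        'Phi 2': ['cpu', 'gpu'],
    }

    model = widgets.Dropdown(options=list(options_mapping), value='Gemma 2B', description='model')
    backend = widgets.Dropdown(options=options_mapping[model.value], description='backend')
    token = widgets.Text(description='token')

    def on_use_gpu(change):
        # Restrict the backend choices to what the newly selected model supports.
        backend.options = options_mapping[change['new']]
        backend.value = backend.options[0]

    def on_change_model(change):
        # Only the gated Gemma checkpoints need a Hugging Face access token.
        token.layout.display = 'flex' if change['new'] in ('Gemma 2B', 'Gemma 7B') else 'none'

    model.observe(on_use_gpu, names='value')
    model.observe(on_change_model, names='value')
    display(model, backend, token)

Registering two observers on the same `value` trait keeps the backend restriction separate from the token-visibility toggle, which now applies to both gated Gemma checkpoints rather than only Gemma 2B.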
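And a sketch of how the new `gemma7b_convert_config` output is consumed, assuming the notebook's convert step passes the config to `converter.convert_checkpoint` as in the MediaPipe LLM conversion documentation; the paths and parameters mirror the patch, and `gemma7b_gpu.bin` reflects the GPU-only backend for this model.

    from mediapipe.tasks.python.genai import converter

    # Equivalent to gemma7b_convert_config('gpu'): 'gpu' is the only
    # backend options_mapping allows for Gemma 7B.
    config = converter.ConversionConfig(
        input_ckpt='/content/gemma-1.1-7b-it/',
        ckpt_format='safetensors',
        model_type='GEMMA_7B',
        backend='gpu',
        output_dir='/content/intermediate/gemma-1.1-7b-it/',
        combine_file_only=False,
        vocab_model_file='/content/gemma-1.1-7b-it/',
        output_tflite_file='/content/converted_models/gemma7b_gpu.bin',
    )
    converter.convert_checkpoint(config)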