Merge pull request #472 from kuaashish/main
Added a Python conversion example for the Gemma 7B model.
woodyhoko authored Nov 15, 2024
2 parents dccb79d + ed73b58 commit 99a3a6c
Showing 1 changed file with 49 additions and 12 deletions.
61 changes: 49 additions & 12 deletions examples/llm_inference/conversion/llm_conversion.ipynb
@@ -757,7 +757,7 @@
"# @title Run { display-mode: \"form\"}\n",
"\n",
"model = widgets.Dropdown(\n",
" options=[\"Gemma 2B\", \"Falcon 1B\", \"StableLM 3B\", \"Phi 2\"],\n",
" options=[\"Gemma 2B\",\"Gemma 7B\", \"Falcon 1B\", \"StableLM 3B\", \"Phi 2\"],\n",
" value='Gemma 2B',\n",
" description='model',\n",
" disabled=False,\n",
@@ -777,16 +777,34 @@
" disabled=False\n",
")\n",
"\n",
"def on_change_model(change):\n",
" if change[\"new\"] != 'Gemma 2b':\n",
" token_description.layout.display = \"none\"\n",
" token.layout.display = \"none\"\n",
"options_mapping = {\n",
" 'Gemma 2B': ['cpu', 'gpu'],\n",
" 'Gemma 7B': ['gpu'],\n",
" 'Falcon 1B': ['cpu', 'gpu'],\n",
" 'StableLM 3B': ['cpu', 'gpu'],\n",
" 'Phi 2': ['cpu', 'gpu']\n",
"}\n",
"\n",
"def on_use_gpu(change):\n",
" selected_value = change['new']\n",
"\n",
" if selected_value in options_mapping:\n",
" backend.options = options_mapping[selected_value]\n",
" backend.value = options_mapping[selected_value][0]\n",
" else:\n",
" token_description.layout.display = \"flex\"\n",
" token.layout.display = \"flex\"\n",
" token.options = []\n",
" token.value = None\n",
"\n",
"model.observe(on_change_model, names=['value'])\n",
"def on_change_model(change):\n",
" selected_values = ['Gemma 2B','Gemma 7B']\n",
"\n",
" if change['new'] in selected_values:\n",
" token.layout.display = 'flex'\n",
" else:\n",
" token.layout.display = 'none'\n",
"\n",
"model.observe(on_change_model, names='value')\n",
"model.observe(on_use_gpu, names='value')\n",
"\n",
"\n",
"display(model)\n",
@@ -802,14 +820,22 @@
"\n",
"\n",
"\n",
"def gemma_download(token):\n",
"def gemma2b_download(token):\n",
" REPO_ID = \"google/gemma-2b-it\"\n",
" FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"model-00001-of-00002.safetensors\", \"model-00002-of-00002.safetensors\"]\n",
" os.environ['HF_TOKEN'] = token\n",
" with out:\n",
" for filename in FILENAMES:\n",
" hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./gemma-2b-it\")\n",
"\n",
"def gemma7b_download(token):\n",
" REPO_ID = \"google/gemma-1.1-7b-it\"\n",
" FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"model-00001-of-00004.safetensors\", \"model-00002-of-00004.safetensors\", \"model-00003-of-00004.safetensors\", \"model-00004-of-00004.safetensors\"]\n",
" os.environ['HF_TOKEN'] = token\n",
" with out:\n",
" for filename in FILENAMES:\n",
" hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./gemma-1.1-7b-it\")\n",
"\n",
"def falcon_download():\n",
" REPO_ID = \"tiiuae/falcon-rw-1b\"\n",
" FILENAMES = [\"tokenizer.json\", \"tokenizer_config.json\", \"pytorch_model.bin\"]\n",
@@ -832,13 +858,20 @@
" hf_hub_download(repo_id=REPO_ID, filename=filename, local_dir=\"./phi-2\")\n",
"\n",
"\n",
"def gemma_convert_config(backend):\n",
"def gemma2b_convert_config(backend):\n",
" input_ckpt = '/content/gemma-2b-it/'\n",
" vocab_model_file = '/content/gemma-2b-it/'\n",
" output_dir = '/content/intermediate/gemma-2b-it/'\n",
" output_tflite_file = f'/content/converted_models/gemma_{backend}.bin'\n",
" return converter.ConversionConfig(input_ckpt=input_ckpt, ckpt_format='safetensors', model_type='GEMMA_2B', backend=backend, output_dir=output_dir, combine_file_only=False, vocab_model_file=vocab_model_file, output_tflite_file=output_tflite_file)\n",
"\n",
"def gemma7b_convert_config(backend):\n",
" input_ckpt = '/content//gemma-1.1-7b-it/'\n",
" vocab_model_file = '/content//gemma-1.1-7b-it/'\n",
" output_dir = '/content/intermediate//gemma-1.1-7b-it/'\n",
" output_tflite_file = f'/content/converted_models/gemma_{backend}.bin'\n",
" return converter.ConversionConfig(input_ckpt=input_ckpt, ckpt_format='safetensors', model_type='GEMMA_7B', backend=backend, output_dir=output_dir, combine_file_only=False, vocab_model_file=vocab_model_file, output_tflite_file=output_tflite_file)\n",
"\n",
"def falcon_convert_config(backend):\n",
" input_ckpt = '/content/falcon-rw-1b/pytorch_model.bin'\n",
" vocab_model_file = '/content/falcon-rw-1b/'\n",
@@ -874,7 +907,9 @@
" backend.disabled = True\n",
"\n",
" if model.value == 'Gemma 2B':\n",
" gemma_download(token.value)\n",
" gemma2b_download(token.value)\n",
" elif model.value == 'Gemma 7B':\n",
" gemma7b_download(token.value)\n",
" elif model.value == 'Falcon 1B':\n",
" falcon_download()\n",
" elif model.value == 'StableLM 3B':\n",
@@ -891,7 +926,9 @@
" button.description = \"Converting ...\"\n",
"\n",
" if model.value == 'Gemma 2B':\n",
" config = gemma_convert_config(backend.value)\n",
" config = gemma2b_convert_config(backend.value)\n",
" elif model.value == 'Gemma 7B':\n",
" config = gemma7b_convert_config(backend.value)\n",
" elif model.value == 'Falcon 1B':\n",
" config = falcon_convert_config(backend.value)\n",
" elif model.value == 'StableLM 3B':\n",
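For readers who want the new Gemma 7B path outside the notebook's widget plumbing, here is a minimal standalone sketch. It reuses the exact repo ID, file list, and ConversionConfig arguments from the diff above; the convert_checkpoint call is an assumption based on MediaPipe's genai converter API (it does not appear in these hunks), and the HF_TOKEN placeholder must be replaced with a Hugging Face token that has access to google/gemma-1.1-7b-it.

# Minimal sketch: download and convert Gemma 1.1 7B for the GPU backend,
# mirroring gemma7b_download() and gemma7b_convert_config() above.
# Assumption: mediapipe's genai converter entry point convert_checkpoint,
# which is not shown in this diff.
import os

from huggingface_hub import hf_hub_download
from mediapipe.tasks.python.genai import converter

os.environ['HF_TOKEN'] = '<your-hf-token>'  # hypothetical placeholder

REPO_ID = 'google/gemma-1.1-7b-it'
FILENAMES = [
    'tokenizer.json',
    'tokenizer_config.json',
    'model-00001-of-00004.safetensors',
    'model-00002-of-00004.safetensors',
    'model-00003-of-00004.safetensors',
    'model-00004-of-00004.safetensors',
]

# Fetch the tokenizer files and the four safetensors shards.
for filename in FILENAMES:
    hf_hub_download(repo_id=REPO_ID, filename=filename,
                    local_dir='./gemma-1.1-7b-it')

# Gemma 7B is GPU-only in the notebook's options_mapping, so backend='gpu'.
config = converter.ConversionConfig(
    input_ckpt='/content/gemma-1.1-7b-it/',
    ckpt_format='safetensors',
    model_type='GEMMA_7B',
    backend='gpu',
    output_dir='/content/intermediate/gemma-1.1-7b-it/',
    combine_file_only=False,
    vocab_model_file='/content/gemma-1.1-7b-it/',
    output_tflite_file='/content/converted_models/gemma_gpu.bin',
)
converter.convert_checkpoint(config)

Note that gemma7b_convert_config() writes to the same gemma_{backend}.bin name the 2B path uses, so converting both models in one session would overwrite one output with the other; give the files distinct names if you need both.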
