forked from kshitizrimal/applied-deep-learning

Commit 9be7332 (parent 34ac542): add easy-gpt-demo.ipynb
1 changed file, 184 additions, 0 deletions
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "easy-gpt-demo.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "nOV0EV6vArTZ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Fetch OpenAI's GPT-2 reference code\n",
        "!git clone https://github.com/openai/gpt-2.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "tjouOxR6BHdq",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Download the 117M-parameter model weights\n",
        "!cd gpt-2 && sh download_model.sh 117M"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "JpsJG2qZBPWh",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Install the repo's Python dependencies\n",
        "!cd gpt-2 && pip3 install -r requirements.txt"
      ],
      "execution_count": 0,
      "outputs": []
    },
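    {
      "metadata": {
        "id": "tfVersionCheck",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Sanity-check sketch: the GPT-2 code below relies on TF1-only APIs\n",
        "# (tf.Session, tf.placeholder, tf.train.Saver), so confirm a 1.x\n",
        "# TensorFlow runtime is active before going further.\n",
        "import tensorflow as tf\n",
        "print(tf.__version__)\n",
        "assert tf.__version__.startswith('1.'), 'This notebook expects TensorFlow 1.x'"
      ],
      "execution_count": 0,
      "outputs": []
    },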
    {
      "metadata": {
        "id": "H7wLvPqPCfCJ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Move the downloaded weights to ./models, the relative path the\n",
        "# encoder and checkpoint-loading code below expects\n",
        "!mv gpt-2/models models"
      ],
      "execution_count": 0,
      "outputs": []
    },
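    {
      "metadata": {
        "id": "modelFilesCheck",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Optional check: the 117M checkpoint, encoder.json and hparams.json\n",
        "# should now sit under ./models/117M, where the cells below look for them.\n",
        "!ls models/117M"
      ],
      "execution_count": 0,
      "outputs": []
    },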
    {
      "metadata": {
        "id": "Dju6D9MABz1v",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "import sys\n",
        "import fire\n",
        "import json\n",
        "import os\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import textwrap\n",
        "\n",
        "# Make the cloned repo's src/ importable, then pull in its modules\n",
        "sys.path.insert(0, './gpt-2/src')\n",
        "import model, sample, encoder, generate_unconditional_samples"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "H8UpxUHVGPPF",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Adapted from gpt-2/src/interactive_conditional_samples.py: generate\n",
        "# nsamples completions of raw_text and print them, wrapped for readability.\n",
        "def interact_model(\n",
        "    model_name='117M',\n",
        "    seed=None,\n",
        "    nsamples=1,\n",
        "    batch_size=None,\n",
        "    length=None,\n",
        "    temperature=1,\n",
        "    top_k=0,\n",
        "    raw_text='test',\n",
        "):\n",
        "    if batch_size is None:\n",
        "        batch_size = 1\n",
        "    assert nsamples % batch_size == 0\n",
        "    np.random.seed(seed)\n",
        "    tf.set_random_seed(seed)\n",
        "\n",
        "    # Load the byte-pair encoder and the model's hyperparameters\n",
        "    enc = encoder.get_encoder(model_name)\n",
        "    hparams = model.default_hparams()\n",
        "    with open(os.path.join('models', model_name, 'hparams.json')) as f:\n",
        "        hparams.override_from_dict(json.load(f))\n",
        "\n",
        "    if length is None:\n",
        "        length = hparams.n_ctx // 2\n",
        "    elif length > hparams.n_ctx:\n",
        "        raise ValueError(\"Can't get samples longer than window size: %s\" % hparams.n_ctx)\n",
        "\n",
        "    with tf.Session(graph=tf.Graph()) as sess:\n",
        "        # Build the sampling graph, conditioned on a placeholder context\n",
        "        context = tf.placeholder(tf.int32, [batch_size, None])\n",
        "        output = sample.sample_sequence(\n",
        "            hparams=hparams, length=length,\n",
        "            context=context,\n",
        "            batch_size=batch_size,\n",
        "            temperature=temperature, top_k=top_k\n",
        "        )\n",
        "\n",
        "        # Restore the pretrained weights from the downloaded checkpoint\n",
        "        saver = tf.train.Saver()\n",
        "        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))\n",
        "        saver.restore(sess, ckpt)\n",
        "\n",
        "        # Encode the prompt, sample, and keep only the newly generated tokens\n",
        "        context_tokens = enc.encode(raw_text)\n",
        "        generated = 0\n",
        "        for _ in range(nsamples // batch_size):\n",
        "            out = sess.run(output, feed_dict={\n",
        "                context: [context_tokens for _ in range(batch_size)]\n",
        "            })[:, len(context_tokens):]\n",
        "            for i in range(batch_size):\n",
        "                generated += 1\n",
        "                text = enc.decode(out[i])\n",
        "                print(\"=\" * 40 + \" SAMPLE \" + str(generated) + \" \" + \"=\" * 40)\n",
        "                print(textwrap.fill(text, 150))\n",
        "                print(\"=\" * 80)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "jkrryjyWLZvS",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Generate text based on a conditioning text"
      ]
    },
    {
      "metadata": {
        "id": "iBQIGgUgG0wh",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Text below from 'The Onion'\n",
        "interact_model(\n",
        "    length=90,\n",
        "    nsamples=3,\n",
        "    raw_text=\"NEW YORK—At 4:32 p.m. Tuesday, every single resident of New York City decided to evacuate the famed metropolis, having realized it was nothing more than a massive, trash-ridden hellhole that slowly sucks the life out of every one of its inhabitants.\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
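    {
      "metadata": {
        "id": "uncondSamplesText",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "### Generate unconditional samples\n",
        "\n",
        "A minimal sketch using the `generate_unconditional_samples` module imported above (and otherwise unused): its `sample_model` helper samples from the model with no prompt at all. This assumes the cloned repo's `sample_model(model_name='117M', seed=None, nsamples=0, batch_size=1, length=None, temperature=1, top_k=0)` signature; with the default `nsamples=0` it generates forever, so pass a positive count."
      ]
    },
    {
      "metadata": {
        "id": "uncondSamplesCode",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Sketch: unprompted generation via the repo's own sample_model helper.\n",
        "# nsamples must be positive here, or the loop never terminates; top_k=40\n",
        "# is a commonly used truncation for GPT-2 sampling.\n",
        "generate_unconditional_samples.sample_model(\n",
        "    nsamples=2,\n",
        "    length=60,\n",
        "    top_k=40)"
      ],
      "execution_count": 0,
      "outputs": []
    }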
  ]
}