diff --git "a/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" "b/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" new file mode 100644 index 000000000..498882d2c --- /dev/null +++ "b/docs/\347\254\254\345\215\201\347\253\240/NLP\345\237\272\347\241\200.ipynb" @@ -0,0 +1,720 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "119ec186", + "metadata": {}, + "source": [ + "# Word Embeddings (Concepts)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f8e5639e", + "metadata": {}, + "source": [ + "### Before looking at what word embeddings are, consider how a computer recognizes human input. \n", + "A computer stores every input as binary codes of 0s and 1s, so human language must first be converted into a numeric form that the machine can process. \n", + "We start with the concept of **one-hot encoding** (独热编码): for a given dimension, a vector has exactly one entry equal to 1 and all other entries equal to 0, for example the 5-dimensional vector [0,0,0,0,1] \n", + "As an analogy, when we learn Chinese in kindergarten or primary school, we first learn individual characters and words, which are then stored somewhere in our brain. \n", + "\n", + "Suppose a child has just learned four words: [我] (I), [特别] (really), [喜欢] (like), [学习] (studying) \n", + "\n", + "The computer can then allocate a 4-dimensional one-hot encoding for the child's vocabulary. \n", + "Chinese text is first segmented into words: 我 特别 喜欢 学习 \n", + "We can then assign 我->[1 0 0 0] 特别->[0 1 0 0] 喜欢->[0 0 1 0] 学习->[0 0 0 1] \n", + "Given the sentence 我喜欢学习 (I like studying), the computer represents it by marking every word it contains, which gives the vector [1 0 1 1] \n", + "This raises two problems: \n", + "1. As the child's vocabulary grows to thousands or tens of thousands of words, the vectors built this way become very high-dimensional and sparse. \n", + "2. Near-synonyms such as 能, 会 and 可以 (all roughly meaning can) get mutually orthogonal one-hot vectors, so the encoding carries no notion of similarity between words and cannot relate them to each other. \n", + "For these reasons, one-hot encoding is not a good word-embedding method. \n", + "\n", + "Next, consider the **dense representation**. \n", + "A dense vector has the same form as a one-hot vector, but its entries are real numbers rather than a single 1 among 0s, for example [0.45,0.65,0.14,1.15,0.97] " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "4db86da3", + "metadata": {}, + "source": [ + "# Bag of Words" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "44dc9252", + "metadata": {}, + "source": [ + "As the name suggests, a bag-of-words representation puts our vocabulary into a bag, and we take words out of the bag when we want to express something. The bag can be built as follows." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "823f8f2d", + "metadata": {}, + "outputs": [], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]  # our corpus\n", + "word_list = ' '.join(corpus).split()\n", + "word_list = list(sorted(set(word_list)))\n", + "word_dict = {w: i for i, w in enumerate(word_list)}\n", + "number_dict = {i: w for i, w in enumerate(word_list)}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8eaeb37d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'do': 0,\n", + " 'drinking': 1,\n", + " 'hate': 2,\n", + " 'i': 3,\n", + " 'like': 4,\n", + " 'love': 5,\n", + " 'nlp': 6,\n", + " 'playing': 7,\n", + " 'reading': 8}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "word_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2bf380c8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: 'do',\n", + " 1: 'drinking',\n", + " 2: 'hate',\n", + " 3: 'i',\n", + " 4: 'like',\n", + " 5: 'love',\n", + " 6: 'nlp',\n", + " 7: 'playing',\n", + " 8: 
'reading'}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "number_dict" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "90e0ef43", + "metadata": {}, + "source": [ + "Using the mapping above, we can build 9-dimensional one-hot encodings as follows (besides np.eye, they can also be built with the sklearn library)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "9821ed2a", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "voc_size = len(word_dict)\n", + "bow = []\n", + "for name in word_dict:\n", + "    bow.append(np.eye(voc_size)[word_dict[name]])  # row of the identity matrix = one-hot vector of this word" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "03f1f12f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([1., 0., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 1., 0., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 1., 0., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 1., 0., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 1., 0., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 1., 0., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 1., 0., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 1., 0.]),\n", + " array([0., 0., 0., 0., 0., 0., 0., 0., 1.])]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bow" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "086a5fd2", + "metadata": {}, + "source": [ + "# N-gram: A Statistical Language Model\n", + "An N-gram model is a statistical language model that uses the dependence between neighboring words to predict the next word. It models runs of n consecutive words in a text and uses them to predict what comes next. For example, if a text contains the consecutive words “the cat sat on”, an N-gram model might predict that the following words are “the mat” or “a hat”. \n", + "\n", + "The accuracy of an N-gram model depends on the quality and quantity of its training text. If the training text contains many grammatical and spelling errors, the predictions may be inaccurate. If the training text is small, the model may also fail to capture the complexity of the language. \n", + "\n", + "**Advantages of N-gram models:**\n", + "\n", + "Simple to use: the concept is very simple and the model is easy to implement. \n", + "Captures local dependencies: because it predicts the next word from n consecutive words, it captures the dependence between neighboring words. \n", + "Can be trained on existing corpora: an N-gram model can be trained on large existing corpora, such as Google's N-gram dataset, which can greatly improve its accuracy. \n", + "\n", + "**Disadvantages of N-gram models:**\n", + "\n", + "Unsuitable for small datasets: an N-gram model needs a large amount of text for training, so it may not reach high accuracy on small text datasets. \n", + "Sensitive to noise and errors: the model is estimated from a corpus, so if the corpus contains many grammatical and spelling errors, the predictions may be inaccurate. \n", + "Limited context: an N-gram model assumes that a word depends only on the few words immediately before it, so it cannot capture long-range or more complex dependencies in language. " + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1f5ad65b", + "metadata": {}, + "source": [ + "# NNLM: Feed-Forward Neural Network Language Model\n", + "Below, a feed-forward neural network model is used to demonstrate the sliding window." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7bddfa77", + "metadata": {}, + "outputs": [], + "source": [ + "# import the required libraries\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from tqdm import tqdm\n", + "from torch.autograd import Variable\n", + "dtype = torch.FloatTensor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "29f23588", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i',\n", + " 'like',\n", + " 'reading',\n", + " 'i',\n", + " 'love',\n", + " 'drinking',\n", + " 'i',\n", + " 'hate',\n", + " 'playing',\n", + " 'i',\n", + " 'do',\n", + " 'nlp']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corpus = [\"i like reading\", \"i love drinking\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split()\n", + "word_list" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "12b58886", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 1000 cost = 1.010682\n", + "epoch: 2000 cost = 0.695155\n", + "epoch: 3000 cost = 0.597085\n", + "epoch: 4000 cost = 0.531892\n", + "epoch: 5000 cost = 0.376044\n", + "epoch: 6000 cost = 0.118038\n", + "epoch: 7000 cost = 0.077081\n", + "epoch: 8000 cost = 0.053636\n", + "epoch: 9000 cost = 0.038089\n", + "epoch: 10000 cost = 0.027224\n", + "[['i', 
'like'], ['i', 'love'], ['i', 'hate'], ['i', 'do']] -> ['studying', 'datawhale', 'playing', 'nlp']\n" + ] + } + ], + "source": [ + "# build the corpus\n", + "corpus = [\"i like studying\", \"i love datawhale\", \"i hate playing\", \"i do nlp\"]\n", + "\n", + "word_list = ' '.join(corpus).split()  # split the corpus into words, e.g. ['i', 'like', 'studying', 'i', ..., 'nlp']\n", + "word_list = list(sorted(set(word_list)))  # deduplicate with set, then convert to a sorted list\n", + "# print(word_list)\n", + "\n", + "word_dict = {w: i for i, w in enumerate(word_list)}  # vocabulary as a dict: word -> index\n", + "number_dict = {i: w for i, w in enumerate(word_list)}  # index -> word\n", + "# print(word_dict)\n", + "# print(number_dict)\n", + "\n", + "n_class = len(word_dict)  # vocabulary size, used later to build the word vectors\n", + "\n", + "m = 2  # word-embedding dimension\n", + "n_step = 2  # sliding-window size (number of context words)\n", + "n_hidden = 2  # hidden-layer dimension\n", + "\n", + "\n", + "def make_batch(sentence):  # the corpus is tiny, so the whole training set is treated as one batch\n", + "    input_batch = []\n", + "    target_batch = []\n", + "\n", + "    for sen in sentence:\n", + "        word = sen.split()\n", + "        inputs = [word_dict[n] for n in word[:-1]]  # all words but the last are the context\n", + "        target = word_dict[word[-1]]  # the last word is the prediction target\n", + "\n", + "        input_batch.append(inputs)\n", + "        target_batch.append(target)\n", + "\n", + "    return input_batch, target_batch\n", + "\n", + "\n", + "class NNLM(nn.Module):  # build the NNLM language model\n", + "    def __init__(self):\n", + "        super(NNLM, self).__init__()\n", + "        self.embed = nn.Embedding(n_class, m)\n", + "        self.W = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n", + "        self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n", + "\n", + "        self.U = nn.Parameter(torch.randn(n_hidden, n_class).type(dtype))\n", + "        self.b = nn.Parameter(torch.randn(n_class).type(dtype))\n", + "\n", + "    def forward(self, x):\n", + "        x = self.embed(x)  # 4 x 2 x 2\n", + "        x = x.view(-1, n_step * m)\n", + "        tanh = torch.tanh(self.d + torch.mm(x, self.W))  # 4 x 2\n", + "        output = self.b + torch.mm(tanh, self.U)\n", + "        return output\n", + "\n", + "model = NNLM()\n", + "\n", + "criterion = nn.CrossEntropyLoss()  
# loss function\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)  # optimizer\n", + "\n", + "input_batch, target_batch = make_batch(corpus)  # training inputs and labels\n", + "input_batch = torch.LongTensor(input_batch)\n", + "target_batch = torch.LongTensor(target_batch)\n", + "\n", + "for epoch in range(10000):  # training loop\n", + "    optimizer.zero_grad()\n", + "\n", + "    output = model(input_batch)  # input: 4 x 2\n", + "\n", + "    loss = criterion(output, target_batch)\n", + "\n", + "    if (epoch + 1) % 1000 == 0:\n", + "        print('epoch:', '%04d' % (epoch + 1), 'cost = {:.6f}'.format(loss.item()))\n", + "\n", + "    loss.backward()\n", + "    optimizer.step()\n", + "\n", + "predict = model(input_batch).data.max(1, keepdim=True)[1]  # model prediction\n", + "\n", + "print([sen.split()[:2] for sen in corpus], '->', [number_dict[n.item()] for n in predict.squeeze()])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "93d8cd2f", + "metadata": {}, + "source": [ + "# Word2Vec: the Skip-gram and CBOW Models\n", + "The dense vectors (distributed representation) mentioned earlier can be learned by training a Word2Vec model. \n", + "The skip-gram model uses the center word to predict its surrounding words. \n", + "The CBOW (continuous bag-of-words) model uses the surrounding words to predict the center word. " + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "066f68a0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 11%|█ | 10615/100000 [00:02<00:24, 3657.80it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 10000 cost = 1.955088\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 21%|██ | 20729/100000 [00:05<00:21, 3758.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 20000 cost = 1.673096\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 30438/100000 [00:08<00:18, 3710.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 30000 cost = 2.247422\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 41%|████ | 40638/100000 [00:11<00:15, 3767.87it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 40000 cost = 2.289902\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 50486/100000 [00:13<00:13, 3713.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 50000 cost = 2.396217\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 61%|██████ | 60572/100000 [00:16<00:11, 3450.47it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 60000 cost = 1.539688\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 71%|███████ | 70638/100000 [00:19<00:07, 3809.11it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 70000 cost = 1.638879\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 80403/100000 [00:21<00:05, 3740.33it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 80000 cost = 2.279797\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 90480/100000 [00:24<00:02, 3680.03it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 90000 cost = 1.992100\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 100000/100000 [00:27<00:00, 3677.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 100000 cost = 1.307715\n" + ] + }, + { + "data": { + "image/png": "(base64-encoded PNG omitted: 2-D scatter plot of the first two embedding dimensions, one labeled point per vocabulary word)", + "text/plain": [ + "" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "打印\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import tqdm\n", + "\n", + "dtype = torch.FloatTensor\n", + "# the corpus we use\n", + "sentences = ['i like dog','i like cat','i like animal','dog is animal','cat is animal',\n", + "             'dog like meat','cat like meat','cat like fish','dog like meat','i like apple',\n", + "             'i hate apple','i like movie','i like read','dog like bark','dog like cat']\n", + "\n", + "\n", + "\n", + "word_sequence = ' '.join(sentences).split()  # flatten every sentence of the corpus into one list of words\n", + "#print(word_sequence)\n", + "\n", + "word_list = list(set(word_sequence))  # build the vocabulary\n", + "#print(word_list)\n", + "\n", + "#word_voc = list(set(word_sequence))\n", + "\n", + "# Next, assign an index to every word in the vocabulary; the indices are used to build the one-hot encodings introduced earlier\n", + "\n", + "# dictionary: word -> index\n", + "word_dict = {w:i for i,w in enumerate(word_list)}\n", + "#print(word_dict)\n", + "# index -> word\n", + "index_dict = {i:w for i,w in enumerate(word_list)}\n", + "#print(index_dict)\n", + "\n", + "\n", + "batch_size = 2\n", + "voc_size = len(word_list)\n", + "\n", + "skip_grams = []\n", + "for i in range(1,len(word_sequence)-1,3):  # step 3: the middle word of each three-word sentence\n", + "    target = word_dict[word_sequence[i]]  # id of the current (center) word\n", + "    context = [word_dict[word_sequence[i-1]],word_dict[word_sequence[i+1]]]  # ids of the two context words\n", + "\n", + "    for w in context:\n", + "        skip_grams.append([target,w])\n", + "\n", + "embedding_size = 10\n", + "\n", + "\n", + "class Word2Vec(nn.Module):\n", + "    def __init__(self):\n", + "        super(Word2Vec,self).__init__()\n", + "        self.W1 = nn.Parameter(torch.rand(len(word_dict),embedding_size).type(dtype))\n", + "        # maps a word's one-hot encoding to its word vector\n", + "        self.W2 = nn.Parameter(torch.rand(embedding_size,voc_size).type(dtype))\n", + "        # maps the word vector to the output scores\n", + "    def forward(self,x):\n", + "        hidden_layer = torch.matmul(x,self.W1)\n", + "        output_layer = torch.matmul(hidden_layer,self.W2)\n", + "        return output_layer\n", + "\n", + "\n", + "model = Word2Vec()\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(),lr=1e-5)\n", + "\n", + "#print(len(skip_grams))\n", + "# training helper: sample a random mini-batch\n", + "\n", + "def random_batch(data,size):\n", + "    random_inputs = []\n", + "    random_labels = []\n", + "    random_index = np.random.choice(range(len(data)),size,replace=False)\n", + "\n", + "    for i in random_index:\n", + "        random_inputs.append(np.eye(voc_size)[data[i][0]])  # one-hot vector taken from a row of the identity matrix\n", + "        random_labels.append(data[i][1])\n", + "\n", + "    return random_inputs,random_labels\n", + "\n", + "for epoch in tqdm(range(100000)):\n", + "    input_batch,target_batch = random_batch(skip_grams,batch_size)  # X -> y\n", + "    input_batch = torch.Tensor(np.array(input_batch))\n", + "    target_batch = torch.LongTensor(target_batch)\n", + "\n", + "    optimizer.zero_grad()\n", + "\n", + "    output = model(input_batch)\n", + "\n", + "    loss = criterion(output,target_batch)\n", + "    if((epoch+1)%10000==0):\n", + "        print(\"epoch:\",\"%04d\" %(epoch+1),'cost =' ,'{:.6f}'.format(loss.item()))\n", + "\n", + "    loss.backward()\n", + "    optimizer.step()\n", + "\n", + "W1,_ = model.parameters()  # W1 holds one embedding row per vocabulary word\n", + "for i , label in enumerate(word_list):\n", + "    x,y = float(W1[i][0]),float(W1[i][1])  # plot only the first two embedding dimensions\n", + "    plt.scatter(x,y)\n", + "    plt.annotate(label,xy=(x,y),xytext=(5,2),textcoords='offset points',ha='right',va='bottom')\n", + "plt.show()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1edccf25", + "metadata": {}, + "source": [ + "In natural language processing, commonly used evaluation metrics include the following: \n", + "\n", + "**Accuracy**: \n", + "Accuracy is one of the simplest and most common metrics. It measures the proportion of all samples that the model classifies correctly. \n", + "\n", + "**Precision and Recall**: \n", + "Precision and recall are used to evaluate the performance of binary classification models. Precision is the proportion of samples predicted as positive that are actually positive; recall is the proportion of actual positives that the model predicts as positive. \n", + "\n", + "**F1-Score**: \n", + "The F1 score is the harmonic mean of precision and recall and combines both into one number. A higher F1 score means the model achieves a better balance between precision and recall. \n", + "\n", + "**Confusion Matrix**: \n", + "A confusion matrix is a table for inspecting classifier performance. It cross-tabulates the actual classes against the classes predicted by the model, and accuracy, precision, recall and related metrics can be computed from it. \n", + "\n", + "**ROC Curve and AUC (Receiver Operating Characteristic Curve and Area Under Curve)**: \n", + "The ROC curve plots the true positive rate against the false positive rate as the classification threshold varies.\n", + "The AUC is the area under the ROC curve and measures the model's classification performance across thresholds. \n", + "\n", + "**BLEU (Bilingual Evaluation Understudy)**: \n", + "BLEU is a metric for machine translation quality. It scores a candidate translation by its word overlap with the reference translations. \n", + "\n", + "**Perplexity**: \n", + "Perplexity is commonly used to evaluate language models and expresses how hard it is for the model to predict a given sequence. The lower the perplexity, the more accurately the model predicts the input sequence. \n", + "\n", + "These metrics are not a fixed set; which ones to use depends on the task type and requirements. Different NLP tasks may also use other task-specific metrics. " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9 (default, Aug 31 2020, 12:42:55) \n[GCC 7.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "7648c2b9d25760d0d65f53f9b9a34de48caa24d8265d64b0ff81e2f2641d528d" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
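
The notebook's closing cell describes precision, recall, F1 and perplexity only in words. The sketch below is not part of the notebook; it is a minimal pure-Python illustration under stated assumptions: the helper names `precision_recall_f1` and `perplexity` are hypothetical, labels are binary with 1 as the positive class, and perplexity is computed from the probabilities a model assigned to each observed token.

```python
import math


def precision_recall_f1(y_true, y_pred, positive=1):
    """Hypothetical helper: binary precision, recall and F1 from label lists."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)  # true positives
    fp = sum(1 for t, p in pairs if t != positive and p == positive)  # false positives
    fn = sum(1 for t, p in pairs if t == positive and p != positive)  # false negatives
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    # F1 is the harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def perplexity(token_probs):
    """Hypothetical helper: exp of the mean negative log-probability per token."""
    mean_nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(mean_nll)


if __name__ == "__main__":
    y_true = [1, 1, 1, 0, 0]
    y_pred = [1, 1, 0, 1, 0]
    print(precision_recall_f1(y_true, y_pred))
    # A model that gives every observed token probability 0.25 is as
    # uncertain as a uniform choice among four words.
    print(perplexity([0.25, 0.25, 0.25, 0.25]))
```

For real experiments, an established library implementation is usually preferable to hand-rolled helpers like these; the sketch is only meant to make the definitions concrete.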