notebooks/references.ipynb

2020-09-07 06:54:31 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Information Entropy\n",
"\n",
"\n",
"## Information Content\n",
"\n",
"$$\n",
"I_i=\\log_2(\\frac{1}{p_i})=-\\log_2{p_i}\n",
"$$\n",
"\n",
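"where $I_i$ is the information content of outcome $i$ and $p_i$ is the probability that $i$ occurs. The less likely $i$ is, the more information its occurrence carries.\n",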
"\n",
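"For example (illustrative values chosen here), a rare event and a certain event carry\n",
"\n",
"$$\n",
"I = \\log_2 8 = 3 \\text{ bits} \\quad (p_i = 1/8), \\qquad I = \\log_2 1 = 0 \\text{ bits} \\quad (p_i = 1)\n",
"$$\n",
"\n",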
"## Entropy\n",
"\n",
"$$\n",
"H(X)=\\sum_{i=1}^{n}(p_i \\times \\log_2(\\frac{1}{p_i}))\n",
"$$\n",
"\n",
"Entropy is the expected value of the information content.\n",
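"\n",
"As an illustrative example (probabilities chosen here for demonstration), compare a fair coin with a coin that lands heads 90% of the time (values rounded):\n",
"\n",
"$$\n",
"H_{\\text{fair}} = 0.5\\log_2\\frac{1}{0.5} + 0.5\\log_2\\frac{1}{0.5} = 1 \\text{ bit}\\\\\n",
"H_{\\text{biased}} = 0.9\\log_2\\frac{1}{0.9} + 0.1\\log_2\\frac{1}{0.1} \\approx 0.137 + 0.332 = 0.469 \\text{ bits}\n",
"$$\n",
"\n",
"The more predictable the outcome, the lower its entropy.\n",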
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Suppose __Model 1__ and __Model 2__ each predict an election outcome over three classes:\n",
"\n",
"__Model 1__\n",
"\n",
"| Prediction | Ground truth | Correct |\n",
"|-------------|--------------|----------|\n",
"| 0.3 0.3 0.4 | 0 0 1 | yes |\n",
"| 0.3 0.4 0.3 | 0 1 0 | yes |\n",
"| 0.1 0.2 0.7 | 1 0 0 | no |\n",
"\n",
"__Model 2__\n",
"\n",
"| Prediction | Ground truth | Correct |\n",
"|-------------|--------------|----------|\n",
"| 0.1 0.2 0.7 | 0 0 1 | yes |\n",
"| 0.1 0.7 0.2 | 0 1 0 | yes |\n",
"| 0.3 0.4 0.3 | 1 0 0 | no |\n",
"\n",
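"Judging by the Correct column (accuracy), both models get 2 of 3 samples right. The predicted probabilities, however, show that __Model 2__ is clearly better: it puts more probability on the true class when it is right (0.7 vs. 0.4), and it is less confidently wrong when it errs (0.3 vs. 0.1 on the true class).\n",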
"\n",
"\n",
"## Classification Error\n",
"\n",
"$$\n",
"\\text{classification error} = \\frac{\\text{count of error items}}{\\text{count of all items}}\n",
"$$\n",
"\n",
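"Applied to the tables above, each model misclassifies one of its three samples:\n",
"\n",
"$$\n",
"\\text{classification error}_{\\text{Model 1}} = \\text{classification error}_{\\text{Model 2}} = \\frac{1}{3} \\approx 0.33\n",
"$$\n",
"\n",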
"Clearly, accuracy (or classification error) is too coarse to tell the two models apart.\n",
"\n",
"## Mean Squared Error\n",
"\n",
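"$$\n",
"MSE = \\frac{1}{N}\\sum_{i=1}^{N}\\sum_{c=1}^{M}(\\hat{y}_{ic} - y_{ic})^2\n",
"$$\n",
"\n",
"Here $\\hat{y}_{ic}$ is the predicted probability of class $c$ for sample $i$, $y_{ic}$ is the corresponding one-hot label, $M$ is the number of classes and $N$ the number of samples. The squared errors are summed over the classes of each sample and then averaged over the samples, which is how the examples below are computed.\n",
"\n",
"\n",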
"__Model 1__:\n",
"\n",
"$$\n",
"loss_1 = (0.3-0)^2 + (0.3-0)^2 + (0.4-1)^2 = 0.54\\\\\n",
"loss_2 = (0.3-0)^2 + (0.4-1)^2 + (0.3-0)^2 = 0.54\\\\\n",
"loss_3 = (0.1-1)^2 + (0.2-0)^2 + (0.7-0)^2 = 1.34\\\\\n",
"L = \\frac{0.54 + 0.54 + 1.34}{3} \\approx 0.81\n",
"$$\n",
"\n",
"\n",
"__Model 2__:\n",
"\n",
"$$\n",
"loss_1 = (0.1-0)^2 + (0.2-0)^2 + (0.7-1)^2 = 0.14\\\\\n",
"loss_2 = (0.1-0)^2 + (0.7-1)^2 + (0.2-0)^2 = 0.14\\\\\n",
"loss_3 = (0.3-1)^2 + (0.4-0)^2 + (0.3-0)^2 = 0.74\\\\\n",
"L = \\frac{0.14 + 0.14 + 0.74}{3} = 0.34\n",
"$$\n",
"\n",
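"MSE clearly reflects the performance difference between the two models better. Its drawback is that, when paired with a sigmoid or softmax output, its gradients can be very small even when predictions are badly wrong, so gradient descent can make very slow progress early in training.\n",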
"\n",
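"To see why, assume (for illustration; the text above does not specify the output layer) a single sigmoid output $p = \\sigma(z)$ with label $y$. Differentiating each loss with respect to the logit $z$, using $\\sigma'(z) = p(1-p)$:\n",
"\n",
"$$\n",
"\\frac{\\partial}{\\partial z}(p - y)^2 = 2(p - y)\\,p(1-p)\\\\\n",
"\\frac{\\partial}{\\partial z}\\left(-[y\\log p + (1-y)\\log(1-p)]\\right) = p - y\n",
"$$\n",
"\n",
"When the model is confidently wrong (e.g. $y = 1$, $p \\to 0$), the factor $p(1-p) \\to 0$ makes the squared-error gradient vanish, while the cross-entropy gradient stays close to $-1$.\n",
"\n",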
"\n",
"## Cross-Entropy Loss Function\n",
"\n",
"### Binary classification\n",
"\n",
"$$\n",
"L = \\frac{1}{N}\\sum_{i}L_i = \\frac{1}{N}\\sum_i-[y_i\\log(p_i)+(1-y_i)\\log(1-p_i)]\n",
"$$\n",
"\n",
"* $y_i$: label of sample $i$, 1 for the positive class and 0 for the negative class\n",
"* $p_i$: predicted probability that sample $i$ is positive\n",
"\n",
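"For a single sample with an illustrative predicted probability $p = 0.8$, using the natural logarithm:\n",
"\n",
"$$\n",
"y = 1:\\quad L = -\\ln 0.8 \\approx 0.22\\\\\n",
"y = 0:\\quad L = -\\ln(1 - 0.8) = -\\ln 0.2 \\approx 1.61\n",
"$$\n",
"\n",
"A confident wrong prediction is penalized far more than a confident correct one.\n",
"\n",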
"### Multiclass classification\n",
"\n",
"$$\n",
"L = \\frac{1}{N}\\sum_{i}L_i = \\frac{1}{N}\\sum_i -\\sum_{c=1}^{M}y_{ic}\\log(p_{ic})\n",
"$$\n",
"\n",
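"* $M$: number of classes\n",
"* $y_{ic}$: indicator variable (0 or 1), equal to 1 if class $c$ is the true class of sample $i$ and 0 otherwise\n",
"* $p_{ic}$: predicted probability that observed sample $i$ belongs to class $c$\n",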
"\n",
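"With $M = 2$, setting $y_{i1} = y_i$, $y_{i2} = 1 - y_i$, $p_{i1} = p_i$ and $p_{i2} = 1 - p_i$ recovers the binary formula:\n",
"\n",
"$$\n",
"-\\sum_{c=1}^{2}y_{ic}\\log(p_{ic}) = -[y_i\\log(p_i) + (1-y_i)\\log(1-p_i)]\n",
"$$\n",
"\n",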
"__Model 1__ (natural logarithm):\n",
"\n",
"$$\n",
"loss_1 = -(0 \\times \\log 0.3 + 0 \\times \\log 0.3 + 1 \\times \\log 0.4) = 0.92\\\\\n",
"loss_2 = -(0 \\times \\log 0.3 + 1 \\times \\log 0.4 + 0 \\times \\log 0.3) = 0.92\\\\\n",
"loss_3 = -(1 \\times \\log 0.1 + 0 \\times \\log 0.2 + 0 \\times \\log 0.7) = 2.30\\\\\n",
"L = \\frac{0.92 + 0.92 + 2.30}{3}=1.38\n",
"$$\n",
"\n",
"\n",
"__Model 2__:\n",
"\n",
"$$\n",
"loss_1 = -(0 \\times \\log 0.1 + 0 \\times \\log 0.2 + 1 \\times \\log 0.7) = 0.36\\\\\n",
"loss_2 = -(0 \\times \\log 0.1 + 1 \\times \\log 0.7 + 0 \\times \\log 0.2) = 0.36\\\\\n",
"loss_3 = -(1 \\times \\log 0.3 + 0 \\times \\log 0.4 + 0 \\times \\log 0.3) = 1.20\\\\\n",
"L = \\frac{0.36 + 0.36 + 1.20}{3}=0.64\n",
"$$\n",
"\n",
"The cross-entropy loss likewise shows that __Model 2__ is better than __Model 1__.\n",
"\n",
"### Properties of the loss function"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"\n",
"# Start just above 0: -log(0) is infinite and NumPy warns about division by zero\n",
"x = np.linspace(0.01, 1, 100)\n",
"y = -np.log(x) # binary cross entropy for a positive sample (y = 1) reduces to -log(p)\n",
"\n",
"plt.plot(x, y)\n",
"plt.xlabel('predicted probability')\n",
"plt.ylabel('loss')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
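"For $f(p) = -\\log p$ on $0 < p \\le 1$ (natural logarithm), the first two derivatives are\n",
"\n",
"$$\n",
"f'(p) = -\\frac{1}{p} < 0, \\qquad f''(p) = \\frac{1}{p^2} > 0\n",
"$$\n",
"\n",
"so the loss falls monotonically as the predicted probability of the true class rises, and the positive second derivative means the curve bends upward everywhere on its domain.\n",
"\n",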
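"As the plot shows, the function is __convex__ in $p$, so it has no spurious local minima. For a linear model such as logistic regression, the cross-entropy loss is also convex in the model parameters, so gradient-based optimization can reach the global optimum; this guarantee does not extend to deep networks in general."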
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}