[关闭]
@vivounicorn 2020-03-29T17:08:16.000000Z 字数 58218 阅读 3401

机器学习与人工智能技术分享-第二章 建模方法回顾

第二章 机器学习 LR SVM ATM

回到目录


2. 建模方法回顾

以通用的监督学习为例,基本包含4个部分:

2.0 偏差与方差

  • 在机器学习算法中,偏差是由先验假设的不合理带来的模型误差,高偏差会导致欠拟合: 所谓欠拟合是指对特征和标注之间的因果关系学习不到位,导致模型本身没有较好的学到历史经验的现象;
  • 方差表征的是模型误差对样本发生一定变化时的敏感度,高方差会导致过拟合:模型对训练样本中的随机噪声也做了拟合学习,导致在未知样本上应用时出现效果较差的现象;
  • 机器学习模型的核心之一在于其推广能力,即在未知样本上的表现。

对方差和偏差的一种直观解释:

一个例子,假如我们有预测模型:

我们希望用 估计 ,如果使用基于square loss 的线性回归,则误差分析如下:


所以大家可以清楚的看到模型学习过程其实就是对偏差和方差的折中过程。

2.1 线性回归-Linear Regression



简单线性回归

2.1.1 模型原理

标准线性回归通过对自变量的线性组合来预测因变量,组合自变量的权重通过最小化训练集中所有样本的预测平方误差和来得到,原理如下。

  • 预测函数
  • 参数学习-采用最小二乘法

所有机器学习模型的成立都会有一定的先验假设,线性回归也不例外,它对数据做了以下强假设:

  • 自变量相互独立,无多重共线性
  • 因变量是自变量的线性加权组合:
  • 所有样本独立同分布(iid),且误差项服从以下分布:

最小二乘法与以上假设的关系推导如下:

使用MLE(极大似然法)估计参数如下:

线性回归有两个重要变体:

关于正则化及最优化后续会做介绍。

2.1.2 损失函数

损失函数1 —— Least Square Loss

进一步阅读可参考:Least Squares

Q: 模型和损失的关系是什么?

2.2 支持向量机-Support Vector Machine

支持向量机通过寻找一个分类超平面使得(相对于其它超平面)它与训练集中任何一类样本中最接近于超平面的样本的距离最大。虽然从实用角度讲(尤其是针对大规模数据和使用核函数)并非最优选择,但它是大家理解机器学习的最好模型之一,涵盖了类似偏差和方差关系的泛化理论、最优化原理、核方法原理、正则化等方面知识。

2.2.1 模型原理

SVM原理可以从最简单的解析几何问题中得到:

超平面的定义如下:

从几何关系上来看,超平面与数据点的关系如下(以正样本点为例):

定义几何距离和函数距离分别如下:

由于超平面的大小对于SVM求解并不重要,重要的是其方向,所以根据SVM的定义,得到约束最优化问题:

现实当中我们无法保证数据是线性可分的,强制要求所有样本能正确分类是不太可能的,即使做了核变换也只是增加了这种可能性,因此我们又需要做折中,允许误分的情况出现,对误分的样本根据其严重性做惩罚,所以引入松弛变量,将上述问题变成软间隔优化问题。

新的优化问题:

如果选择:

那么优化问题变成:

2.2.2 损失函数

损失函数2 —— Hinge Loss

使用hinge loss将SVM套入机器学习框架,让它更容易理解。此时原始约束最优化问题变成损失函数是hinge loss且正则项是L2正则的无约束最优化问题:

下面我证明以上问题(1)和问题(2)是等价的(反之亦然):

到此为止,SVM和普通的判别模型没什么两样,也没有support vector的概念,它之所以叫SVM就得说它的对偶形式了,通过拉格朗日乘数法对原始问题做对偶变换:

从互补松弛条件可以得到以下信息:
时,松弛变量不为零,此时其几何间隔小于,对应样本点就是误分点;当时,松弛变量为零,此时其几何间隔大于,对应样本点就是内部点,即分类正确而又远离最大间隔分类超平面的那些样本点;而时,松弛变量为零,此时其几何间隔等于,对应样本点就是支持向量的取值一定是,这意味着向量被限制在了一个边长为的盒子里。详细说明可参考SVM学习——软间隔优化


越大表明你越不想放弃离群点,分类超平面越向离群点移动

当以上问题求得最优解后,几何间隔变成如下形式:


它只与有限个样本有关系,这些样本被称作支持向量,从这儿也能看出此时模型参数个数与样本个数有关系,这是典型的非参学习过程。

2.2.3 核方法

上面对将内积用一个核函数做了代替,实际上这种替换不限于SVM,所有出现样本间内积的地方都可以考虑这种核变换,本质上它就是通过某种隐式的空间变换在新空间(有限维或无限维兼可)做样本相似度衡量,采用核方法后的模型都可以看做是无固定参数的基于样本的学习器,属于非参学习,核方法与SVM这类模型的发展是互相独立的。


from Kernel Trick

这里不对原理做展开,可参考:
1、Kernel Methods for Pattern Analysis
2、the kernel trick for distances

一些可以应用核方法的模型:

  • SVM
  • Perceptron
  • PCA
  • Gaussian processes
  • Canonical correlation analysis
  • Ridge regression
  • Spectral clustering

在我看来核方法的意义在于:
1、对样本进行空间映射,以较低成本隐式的表达样本之间的相似度,改善样本线性可分的状况,但不能保证线性可分;
2、将线性模型变成非线性模型从而提升其表达能力,但这种升维的方式往往会造成计算复杂度的上升。

一些关于SVM的参考资料
SVM学习——线性学习器
SVM学习——求解二次规划问题
SVM学习——核函数
SVM学习——统计学习理论
SVM学习——软间隔优化
SVM学习——Coordinate Desent Method
SVM学习——Sequential Minimal Optimization
SVM学习——Improvements to Platt’s SMO Algorithm

2.3 逻辑回归-Logistic Regression

逻辑回归恐怕是互联网领域用的最多的模型之一了,很多公司做算法的同学都会拿它做为算法系统进入模型阶段的baseline。

2.3.1 模型原理

逻辑回归是一种判别模型,与线性回归类似,它有比较强的先验假设 :

  • 假设因变量服从贝努利分布
  • 假设训练样本服从钟形分布,例如高斯分布:
  • 是样本标注,布尔类型,取值为0或1;
  • 是样本的特征向量。

逻辑回归是判别模型,所以我们直接学习,以高斯分布为例:


整个原理部分的推导过程如下:

采用 MLE 或者 MAP 做参数求解:

2.3.2 损失函数

损失函数3 —— Cross Entropy Loss

简单理解,从概率角度:Cross Entropy损失函数衡量的是两个概率分布之间的相似性,对真实分布估计的越准损失越小;从信息论角度:用编码方式对由编码方式产生的信息做编码,如果两种编码方式越接近,产生的信息损失越小。与Cross Entropy相关的一个概念是Kullback–Leibler divergence,后者是衡量两个概率分布接近程度的标量值,定义如下:


当两个分布完全一致时其值为0,显然Cross Entropy与Kullback–Leibler divergence的关系是:

关于交叉熵及其周边原理,有一篇文章写得特别好:Visual Information Theory

2.4 Bagging and Boosting框架

Bagging和Boosting是两类最常用以及好用的模型融合框架,殊途而同归。

2.4.1 Bagging框架

Bagging(Breiman, 1996) 方法是通过对训练样本和特征做有放回的抽样,并拟合若干个基础模型进而通过投票方式做最终分类决策的框架。每个基础分类器(可以是树形结构、神经网络等等任何分类模型)的特点是低偏差、高方差,框架通过(加权)投票方式降低方差,使得整体趋于低偏差、低方差

分析如下:
假设任务是学习一个模型 ,我们通过抽样生成生成 个数据集,并训练得到个基础分类器

从结论可以发现多分类器投票机制的引入可以降低模型方差从而降低分类错误率,大家可以多理解理解这一系列推导。

2.4.2 Boosting框架

Boosting(Freund & Shapire, 1996) 通过迭代方式训练若干基础分类器,每个分类器依据上一轮分类器产生的残差做权重调整,每轮的分类器需要够“简单”,具有高偏差、低方差的特点,框架再辅以(加权)投票方式降低偏差,使得整体趋于低偏差、低方差

一个简单的总结:


AnyBoost Algorithm
Boost算法是个框架,很多模型都能往进来套。

Q: boosting 和 margin的关系是什么(机器学习中margin的定义为)?
Q: 类似bagging,为什么boosting能够通过reweight及投票方式降低整体偏差?

2.5 Additive Tree 模型

Additive tree models (ATMs)是指基础模型是树形结构的一类融合模型,可做分类、回归,很多经典的模型可以被看做ATM模型,比如Random forest 、Adaboost with trees、GBDT等。

ATM 对N棵决策树做加权融合,其判别函数为:

2.5.1 Random Forests

Random Forest 属于bagging类模型,每棵树会使用各自随机抽样样本和特征被独立的训练。

2.5.2 AdaBoost with trees

AdaBoost with trees通过训练多个弱分类器来组合得到一个强分类器,每次迭代会生成一棵高偏差、低方差的树形弱分类器,每一轮的训练会更关注上一轮被分类器分错的样本,为其加大权重,训练过程如下:


From Bishop(2006)

2.5.3 Gradient Boosting Decision Tree

Gradient boosted 是一类boosting的技术,不同于Adaboost加大误分样本权重的策略,它每次迭代加的是上一轮梯度更新值:


其训练过程如下:

GBDT是基础分类器为决策树的可做分类和回归的模型。

目前我认为最好的GBDT的实现是XGBoost:
其回归过程的示例图如下,通过对样本落到每棵树的叶子节点的权重值做累加来实现回归(或分类):


Regression Tree Ensemble from chentianqi

其原理推导如下:

对GBDT来说依然避免不了过拟合,所以与传统机器学习一样,通过正则化策略可以降低这种风险:

  • 提前终止(Early Stopping)
    通过观察模型在验证集上的错误率,如果它变化不大则可以提前结束训练,控制迭代轮数(即树的个数);
  • 收缩(Shrinkage)

    从迭代的角度可以看成是学习率(learning rate),从融合(ensemble)的角度可以看成每棵树的权重,的大小经验上可以取0.1,它是对模型泛化性和训练时长的折中;
  • 抽样(Subsampling)
    借鉴Bagging的思想,GBDT可以在每一轮树的构建中使用训练集中无放回抽样的样本,也可以对特征做抽样,模拟真实场景下的样本分布波动;
  • 目标函数中显式的正则化项

    通过对树的叶子节点个数、叶子节点权重做显式的正则化达到缓解过拟合的效果;
  • 参数放弃(Dropout)
    模拟深度学习里随机放弃更新权重的方法,可以在每新增一棵树的时候拟合随机抽取的一些树的残差,相关方法可以参考:DART: Dropouts meet Multiple Additive Regression Trees,文中对该方法和Shrinkage的方法做了比较:

XGBoost源码在: https://github.com/dmlc中,其包含非常棒的设计思想和实现,建议大家都去学习一下,一起添砖加瓦。原理部分我就不再多写了,看懂一篇论文即可,但特别需要注意的是文中提到的weighted quantile sketch算法,它用来解决当样本集权重分布不一致时如何选择分裂节点的问题:XGBoost: A Scalable Tree Boosting System

2.5.4 简单的例子

下面是关于几个常用机器学习模型的对比,从中能直观地体会到不同模型的运作区别,数据集采用libsvm作者整理好的fourclass_scale数据集,机器学习工具采用sklearn,代码中模型未做任何调参,仅使用默认参数设置。

  1. import urllib
  2. import matplotlib
  3. import os
  4. matplotlib.use('Agg')
  5. from matplotlib import pyplot as plt
  6. from mpl_toolkits.mplot3d import proj3d
  7. import numpy as np
  8. from mpl_toolkits.mplot3d import Axes3D
  9. from sklearn.externals.joblib import Memory
  10. from sklearn.datasets import load_svmlight_file
  11. from sklearn import metrics
  12. from sklearn.metrics import roc_auc_score
  13. from sklearn import svm
  14. from sklearn.linear_model import LogisticRegression
  15. from sklearn.linear_model import Ridge
  16. from sklearn.ensemble import GradientBoostingClassifier
  17. from mpl_toolkits.mplot3d import Axes3D
  18. from matplotlib import cm
  19. from matplotlib.ticker import LinearLocator, FormatStrFormatter
  20. from sklearn.tree import DecisionTreeClassifier
  21. import keras
  22. from keras.models import Sequential
  23. from keras.layers.core import Dense,Dropout,Activation
  24. def download(outpath):
  25. filename=outpath+"/fourclass_scale"
  26. if os.path.exists(filename) == False:
  27. urllib.urlretrieve("https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/fourclass_scale",filename)
  28. def data_building():
  29. dtrain = load_svmlight_file('fourclass_scale')
  30. train_d=dtrain[0].toarray()
  31. train_l=dtrain[1]
  32. x1 = train_d[:,0]
  33. x2 = train_d[:,1]
  34. y = train_l
  35. px1 = []
  36. px2 = []
  37. pl = []
  38. nx1 = []
  39. nx2 = []
  40. nl = []
  41. idx = 0
  42. for i in y:
  43. if i == 1:
  44. px1.append(x1[idx]-0.5)
  45. px2.append(x2[idx]+0.5)
  46. pl.append(i)
  47. else:
  48. nx1.append(x1[idx]+0.8)
  49. nx2.append(x2[idx]-0.8)
  50. nl.append(i)
  51. idx = idx + 1
  52. x_axis, y_axis = np.meshgrid(np.linspace(x1.min(), x1.max(), 100), np.linspace(x2.min(), x2.max(), 100))
  53. return x_axis, y_axis, px1, px2, nx1, nx2, train_d, train_l
  54. def paint(name, x_axis, y_axis, px1, px2, nx1, nx2, z):
  55. fig = plt.figure()
  56. ax = Axes3D(fig)
  57. ax=plt.subplot(projection='3d')
  58. ax.scatter(px1,px2,c='r')
  59. ax.scatter(nx1,nx2,c='g')
  60. ax.plot_surface(x_axis, y_axis,z.reshape(x_axis.shape), rstride=8, cstride=8, alpha=0.3)
  61. ax.contourf(x_axis, y_axis, z.reshape(x_axis.shape), zdir='z', offset=-100, cmap=cm.coolwarm)
  62. ax.contourf(x_axis, y_axis, z.reshape(x_axis.shape), levels=[0,max(z)], cmap=cm.hot)
  63. ax.set_xlabel('X')
  64. ax.set_ylabel('Y')
  65. ax.set_zlabel('Z')
  66. fig.savefig(name+".png", format='png')
  67. def svc(x_axis, y_axis, x,y):
  68. clf = svm.SVC()
  69. clf.fit(x, y)
  70. y = clf.predict(np.c_[x_axis.ravel(), y_axis.ravel()])
  71. return y
  72. def lr(x_axis, y_axis, x,y):
  73. clf = LogisticRegression()
  74. clf.fit(x, y)
  75. y = clf.predict(np.c_[x_axis.ravel(), y_axis.ravel()])
  76. return y
  77. def ridge(x_axis, y_axis, x,y):
  78. clf = Ridge()
  79. clf.fit(x, y)
  80. y = clf.predict(np.c_[x_axis.ravel(), y_axis.ravel()])
  81. return y
  82. def dt(x_axis, y_axis, x,y):
  83. clf = GradientBoostingClassifier()
  84. clf.fit(x, y)
  85. y = clf.predict(np.c_[x_axis.ravel(), y_axis.ravel()])
  86. return y
  87. def nn(x_axis, y_axis, x,y):
  88. model = Sequential()
  89. model.add(Dense(20, input_dim=2))
  90. model.add(Activation('relu'))
  91. model.add(Dense(20))
  92. model.add(Activation('relu'))
  93. model.add(Dense(1, activation='tanh'))
  94. model.compile(loss='mse',
  95. optimizer='adam',
  96. metrics=['accuracy'])
  97. model.fit(x,y,batch_size=20, nb_epoch=50, validation_split=0.2)
  98. y = model.predict(np.c_[x_axis.ravel(), y_axis.ravel()],batch_size=20)
  99. return y
  100. if __name__ == '__main__':
  101. download("/root")
  102. x_axis, y_axis, px1, px2, nx1, nx2, train_d, train_l = data_building()
  103. z = svc(x_axis, y_axis, train_d, train_l)
  104. paint("svc", x_axis, y_axis, px1, px2, nx1, nx2, z)
  105. z = lr(x_axis, y_axis, train_d, train_l)
  106. paint("lr", x_axis, y_axis, px1, px2, nx1, nx2, z)
  107. z = ridge(x_axis, y_axis, train_d, train_l)
  108. paint("ridge", x_axis, y_axis, px1, px2, nx1, nx2, z)
  109. z = dt(x_axis, y_axis, train_d, train_l)
  110. paint("gbdt", x_axis, y_axis, px1, px2, nx1, nx2, z)
  111. z = nn(x_axis, y_axis, train_d, train_l)
  112. paint("nn", x_axis, y_axis, px1, px2, nx1, nx2, z)



2.5.5 CatBoost

基于不同的工程实现方式,目前常用的GBDT框架有3种,分别为:XGBoost(天奇)、LightGBM(微软)、CatBoost(Yandex)。三种实现在效果上并没有特别大的区别,区别主要在特征的处理尤其是类别特征上和建树过程上。

  1. # coding=UTF-8
  2. """
  3. class CatBoostModel
  4. """
  5. import random
  6. from sklearn.model_selection import train_test_split
  7. from sklearn.preprocessing import LabelEncoder
  8. from sklearn.model_selection import StratifiedShuffleSplit
  9. from sklearn.metrics import roc_auc_score
  10. import gc
  11. import catboost as cb
  12. from catboost import *
  13. import numpy as np
  14. import pandas as pd
  15. class CatBoostModel():
  16. def __init__(self, **kwargs):
  17. assert kwargs['catboost_params'] # 定义模型参数
  18. assert kwargs['eval_ratio'] # 定义训练集划分比例
  19. assert kwargs['early_stopping_rounds'] # 缓解过拟合,提前结束迭代参数
  20. assert kwargs['num_boost_round'] # 控制树的个数
  21. assert kwargs['cat_features'] # 类别特征列表
  22. assert kwargs['all_features'] # 所有特征列表
  23. self.catboost_params = kwargs['catboost_params']
  24. self.eval_ratio = kwargs['eval_ratio']
  25. self.early_stopping_rounds = kwargs['early_stopping_rounds']
  26. self.num_boost_round = kwargs['num_boost_round']
  27. self.cat_features = kwargs['cat_features']
  28. self.all_features = kwargs['all_features']
  29. self.selected_features_ = None
  30. self.X = None
  31. self.y = None
  32. self.model = None
  33. def get_selected_features(self, topk):
  34. """
  35. Fit the training data to FeatureSelector
  36. Returns
  37. -------
  38. list :
  39. Return the index of imprtant feature.
  40. """
  41. assert topk > 0
  42. self.selected_features_ = self.feature_importance.argsort()[-topk:][::-1]
  43. return self.selected_features_
  44. def predict(self, X, num_iteration=None):
  45. return self.model.predict(X, num_iteration)
  46. def get_fea_importance(self, clf, columns):
  47. importances = clf.feature_importances_
  48. indices = np.argsort(importances)[::-1]
  49. importance_list = []
  50. for f in range(len(columns)):
  51. importance_list.append((columns[indices[f]], importances[indices[f]]))
  52. print("%2d) %-*s %f" % (f + 1, 30, columns[indices[f]], importances[indices[f]]))
  53. print("another feature importances with prettified=True\n")
  54. print(clf.get_feature_importance(prettified=True))
  55. importance_df = pd.DataFrame(importance_list, columns=['Features', 'Importance'])
  56. return importance_df
  57. def train_test_split(self, X, y, test_size, random_state=2020):
  58. sss = list(StratifiedShuffleSplit(
  59. n_splits=1, test_size=test_size, random_state=random_state).split(X, y))
  60. X_train = np.take(X, sss[0][0], axis=0)
  61. X_eval = np.take(X, sss[0][1], axis=0)
  62. y_train = np.take(y, sss[0][0], axis=0)
  63. y_eval = np.take(y, sss[0][1], axis=0)
  64. return [X_train, X_eval, y_train, y_eval]
  65. def catboost_model_train(self,
  66. df,
  67. finetune=None,
  68. target_name='Label',
  69. id_index='Id'):
  70. df = df.loc[df[target_name].isnull() == False]
  71. feature_name = [i for i in df.columns if i not in [target_name, id_index]]
  72. for i in feature_name:
  73. if i in self.cat_features:
  74. #df[i].fillna(-999, inplace=True)
  75. if df[i].fillna('na').nunique() < 12:
  76. df.loc[:, i] = df.loc[:, i].fillna('na').astype('category')
  77. else:
  78. df.loc[:, i] = LabelEncoder().fit_transform(df.loc[:, i].fillna('na').astype(str))
  79. if type(df.loc[0,i])!=str or type(df.loc[0,i])!=int or type(df.loc[0,i])!=long:
  80. df.loc[:, i] = df.loc[:, i].astype(str)
  81. X_train, X_eval, y_train, y_eval = self.train_test_split(df[feature_name],
  82. df[target_name].values,
  83. self.eval_ratio,
  84. random.seed(41))
  85. del df
  86. gc.collect()
  87. catboost_train = Pool(data=X_train, label=y_train, cat_features=self.cat_features, feature_names=self.all_features)
  88. catboost_eval = Pool(data=X_eval, label=y_eval, cat_features=self.cat_features, feature_names=self.all_features)
  89. self.model = cb.train(params=self.catboost_params,
  90. init_model=finetune,
  91. pool=catboost_train,
  92. num_boost_round=self.num_boost_round,
  93. eval_set=catboost_eval,
  94. verbose_eval=50,
  95. plot=True,
  96. early_stopping_rounds=self.early_stopping_rounds)
  97. self.feature_importance = self.get_fea_importance(self.model, self.all_features)
  98. metrics = self.model.eval_metrics(data=catboost_eval,metrics=['AUC'],plot=True)
  99. print('AUC values:{}'.format(np.array(metrics['AUC'])))
  100. return self.feature_importance, metrics, self.model

2、定义catboost trainer(里面涉及nni的部分先忽略,后面再讲)

  1. # coding=UTF-8
  2. import bz2
  3. import urllib.request
  4. import logging
  5. import os
  6. import os.path
  7. from sklearn.datasets import load_svmlight_file
  8. from sklearn.preprocessing import LabelEncoder
  9. import nni
  10. from sklearn.metrics import roc_auc_score
  11. import gc
  12. import pandas as pd
  13. import numpy as np
  14. from model import CatBoostModel
  15. from catboost.datasets import adult
  16. logger = logging.getLogger('auto_gbdt')
  17. TARGET_NAME = 'income'
  18. def get_default_parameters():
  19. params_cb = {
  20. 'boosting_type': 'Ordered',
  21. 'objective': 'CrossEntropy',
  22. 'eval_metric': 'AUC',
  23. 'custom_metric': ['AUC', 'Accuracy'],
  24. 'learning_rate': 0.05,
  25. 'random_seed': 2020,
  26. 'depth': 2,
  27. 'l2_leaf_reg': 3.8,
  28. 'thread_count': 16,
  29. 'use_best_model': True,
  30. 'verbose': 0
  31. }
  32. return params_cb
  33. def get_features_name():
  34. alls = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']
  35. cats = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
  36. return alls, cats
  37. def data_prepare_cleaner(target):
  38. train_df, test_df = adult()
  39. le = LabelEncoder()
  40. le.fit(train_df[target])
  41. train_df[target] = le.transform(train_df[target])
  42. le.fit(test_df[target])
  43. test_df[target] = le.transform(test_df[target])
  44. return train_df, test_df
  45. def cat_fea_cleaner(df, target_name, id_index, cat_features):
  46. df = df.loc[df[target_name].isnull() == False]
  47. feature_name = [i for i in df.columns if i not in [target_name, id_index]]
  48. for i in feature_name:
  49. if i in cat_features:
  50. if df[i].fillna('na').nunique() < 12:
  51. df.loc[:, i] = df.loc[:, i].fillna('na').astype('category')
  52. else:
  53. df.loc[:, i] = LabelEncoder().fit_transform(df.loc[:, i].fillna('na').astype(str))
  54. if type(df.loc[0,i])!=str or type(df.loc[0,i])!=int or type(df.loc[0,i])!=long:
  55. df.loc[:, i] = df.loc[:, i].astype(str)
  56. return df
  57. def trainer_and_tester_run(train_df, test_df, all_features, cat_features):
  58. # get parameters from tuner
  59. RECEIVED_PARAMS = nni.get_next_parameter()
  60. logger.debug(RECEIVED_PARAMS)
  61. PARAMS = get_default_parameters()
  62. PARAMS.update(RECEIVED_PARAMS)
  63. logger.debug(PARAMS)
  64. cb = CatBoostModel(catboost_params=PARAMS,
  65. eval_ratio=0.33,
  66. early_stopping_rounds=20,
  67. cat_features=cat_features,
  68. all_features=all_features,
  69. num_boost_round=1000)
  70. logger.debug("The trainning process is starting...")
  71. train_df = cat_fea_cleaner(train_df, TARGET_NAME, 'index', cat_features)
  72. test_df = cat_fea_cleaner(test_df, TARGET_NAME, 'index', cat_features)
  73. feature_imp, val_score, clf = \
  74. cb.catboost_model_train(df=train_df,
  75. target_name=TARGET_NAME,
  76. id_index='id')
  77. logger.info(feature_imp)
  78. logger.info(val_score)
  79. del train_df
  80. gc.collect()
  81. logger.debug("The trainning process is ended.")
  82. av_auc = inference(clf, test_df, all_features, cat_features)
  83. nni.report_final_result(av_auc)
  84. del test_df
  85. gc.collect()
  86. def inference(clf, test_df, fea, cat_fea):
  87. logger.debug("The testing process is starting...")
  88. try:
  89. y_pred = clf.predict(test_df[fea])
  90. auc = roc_auc_score(test_df[TARGET_NAME].values, y_pred)
  91. print("auc of prediction:{0}".format(auc))
  92. del test_df
  93. gc.collect()
  94. logger.debug("The inference process is ended.")
  95. return auc
  96. except ValueError:
  97. return 0
  98. def run_offline():
  99. alls, cats = get_features_name()
  100. print(alls, cats)
  101. train_df, test_df = data_prepare_cleaner(TARGET_NAME)
  102. print(train_df)
  103. print(test_df)
  104. trainer_and_tester_run(train_df, test_df, alls, cats)
  105. if __name__ == '__main__':
  106. run_offline()
  1. [root@GPU-AI01 cat_test]# python3 catboost_trainer.py
  2. ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country'] ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country']
  3. age workclass fnlwgt education education-num marital-status occupation ... race sex capital-gain capital-loss hours-per-week native-country income
  4. 0 39.0 State-gov 77516.0 Bachelors 13.0 Never-married Adm-clerical ... White Male 2174.0 0.0 40.0 United-States 0
  5. 1 50.0 Self-emp-not-inc 83311.0 Bachelors 13.0 Married-civ-spouse Exec-managerial ... White Male 0.0 0.0 13.0 United-States 0
  6. 2 38.0 Private 215646.0 HS-grad 9.0 Divorced Handlers-cleaners ... White Male 0.0 0.0 40.0 United-States 0
  7. 3 53.0 Private 234721.0 11th 7.0 Married-civ-spouse Handlers-cleaners ... Black Male 0.0 0.0 40.0 United-States 0
  8. 4 28.0 Private 338409.0 Bachelors 13.0 Married-civ-spouse Prof-specialty ... Black Female 0.0 0.0 40.0 Cuba 0
  9. ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
  10. 32556 27.0 Private 257302.0 Assoc-acdm 12.0 Married-civ-spouse Tech-support ... White Female 0.0 0.0 38.0 United-States 0
  11. 32557 40.0 Private 154374.0 HS-grad 9.0 Married-civ-spouse Machine-op-inspct ... White Male 0.0 0.0 40.0 United-States 1
  12. 32558 58.0 Private 151910.0 HS-grad 9.0 Widowed Adm-clerical ... White Female 0.0 0.0 40.0 United-States 0
  13. 32559 22.0 Private 201490.0 HS-grad 9.0 Never-married Adm-clerical ... White Male 0.0 0.0 20.0 United-States 0
  14. 32560 52.0 Self-emp-inc 287927.0 HS-grad 9.0 Married-civ-spouse Exec-managerial ... White Female 15024.0 0.0 40.0 United-States 1
  15. [32561 rows x 15 columns]
  16. age workclass fnlwgt education education-num marital-status ... sex capital-gain capital-loss hours-per-week native-country income
  17. 0 25.0 Private 226802.0 11th 7.0 Never-married ... Male 0.0 0.0 40.0 United-States 0
  18. 1 38.0 Private 89814.0 HS-grad 9.0 Married-civ-spouse ... Male 0.0 0.0 50.0 United-States 0
  19. 2 28.0 Local-gov 336951.0 Assoc-acdm 12.0 Married-civ-spouse ... Male 0.0 0.0 40.0 United-States 1
  20. 3 44.0 Private 160323.0 Some-college 10.0 Married-civ-spouse ... Male 7688.0 0.0 40.0 United-States 1
  21. 4 18.0 NaN 103497.0 Some-college 10.0 Never-married ... Female 0.0 0.0 30.0 United-States 0
  22. ... ... ... ... ... ... ... ... ... ... ... ... ... ...
  23. 16276 39.0 Private 215419.0 Bachelors 13.0 Divorced ... Female 0.0 0.0 36.0 United-States 0
  24. 16277 64.0 NaN 321403.0 HS-grad 9.0 Widowed ... Male 0.0 0.0 40.0 United-States 0
  25. 16278 38.0 Private 374983.0 Bachelors 13.0 Married-civ-spouse ... Male 0.0 0.0 50.0 United-States 0
  26. 16279 44.0 Private 83891.0 Bachelors 13.0 Divorced ... Male 5455.0 0.0 40.0 United-States 0
  27. 16280 35.0 Self-emp-inc 182148.0 Bachelors 13.0 Married-civ-spouse ... Male 0.0 0.0 60.0 United-States 1
  28. [16281 rows x 15 columns]
  29. [03/29/2020, 04:36:35 PM] WARNING (nni) Requesting parameter without NNI framework, returning empty dict
  30. ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country'] ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country']
  31. <IPython.core.display.HTML object>
  32. MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))
  33. 0: test: 0.8316184 best: 0.8316184 (0) total: 79ms remaining: 1m 18s
  34. 50: test: 0.9020779 best: 0.9020779 (50) total: 942ms remaining: 17.5s
  35. 100: test: 0.9081992 best: 0.9081992 (100) total: 1.71s remaining: 15.2s
  36. 150: test: 0.9108707 best: 0.9108707 (150) total: 2.48s remaining: 13.9s
  37. 200: test: 0.9133666 best: 0.9133666 (200) total: 3.27s remaining: 13s
  38. 250: test: 0.9161119 best: 0.9161119 (250) total: 4.05s remaining: 12.1s
  39. 300: test: 0.9181722 best: 0.9181726 (299) total: 4.8s remaining: 11.2s
  40. 350: test: 0.9191926 best: 0.9191926 (350) total: 5.56s remaining: 10.3s
  41. 400: test: 0.9205207 best: 0.9205207 (400) total: 6.3s remaining: 9.41s
  42. 450: test: 0.9210106 best: 0.9210106 (450) total: 7.05s remaining: 8.59s
  43. 500: test: 0.9215160 best: 0.9215160 (500) total: 7.81s remaining: 7.78s
  44. 550: test: 0.9221193 best: 0.9221193 (550) total: 8.59s remaining: 7s
  45. 600: test: 0.9227742 best: 0.9227742 (600) total: 9.37s remaining: 6.22s
  46. 650: test: 0.9232136 best: 0.9232136 (650) total: 10.1s remaining: 5.41s
  47. 700: test: 0.9234202 best: 0.9234202 (700) total: 10.8s remaining: 4.61s
  48. 750: test: 0.9238425 best: 0.9238428 (749) total: 11.6s remaining: 3.83s
  49. 800: test: 0.9242466 best: 0.9242466 (800) total: 12.4s remaining: 3.07s
  50. 850: test: 0.9244801 best: 0.9244826 (849) total: 13.1s remaining: 2.3s
  51. 900: test: 0.9246592 best: 0.9246633 (899) total: 13.8s remaining: 1.52s
  52. 950: test: 0.9248365 best: 0.9248397 (949) total: 14.6s remaining: 752ms
  53. 999: test: 0.9250032 best: 0.9250045 (998) total: 15.3s remaining: 0us
  54. bestTest = 0.9250045138
  55. bestIteration = 998
  56. Shrink model to first 999 iterations.
  57. 1) capital-gain 30.102152
  58. 2) relationship 21.504190
  59. 3) capital-loss 11.414022
  60. 4) age 10.377241
  61. 5) education-num 7.234561
  62. 6) occupation 6.331733
  63. 7) hours-per-week 4.476134
  64. 8) marital-status 4.354055
  65. 9) sex 1.146389
  66. 10) education 1.092454
  67. 11) workclass 0.776362
  68. 12) fnlwgt 0.577348
  69. 13) native-country 0.382720
  70. 14) race 0.230639
  71. another feature importances with prettified=True
  72. Feature Id Importances
  73. 0 capital-gain 30.102152
  74. 1 relationship 21.504190
  75. 2 capital-loss 11.414022
  76. 3 age 10.377241
  77. 4 education-num 7.234561
  78. 5 occupation 6.331733
  79. 6 hours-per-week 4.476134
  80. 7 marital-status 4.354055
  81. 8 sex 1.146389
  82. 9 education 1.092454
  83. 10 workclass 0.776362
  84. 11 fnlwgt 0.577348
  85. 12 native-country 0.382720
  86. 13 race 0.230639
  87. <IPython.core.display.HTML object>
  88. MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))
  89. AUC values:[0.83161836 0.85856642 0.85813673 0.86103314 0.86071689 0.86121893
  90. 0.86441522 0.86505078 0.8784926 0.87975055 0.88290187 0.88250307
  91. 0.88799037 0.8881507 0.89029036 0.88999382 0.88988625 0.89090279
  92. 0.89101198 0.89160979 0.89201995 0.89295319 0.89335747 0.89347394
  93. 0.89391874 0.89388563 0.89537425 0.89577286 0.89558464 0.89714728
  94. 0.89720971 0.89743348 0.89766614 0.89780103 0.89835231 0.89859723
  95. 0.89845935 0.89914869 0.89911805 0.89974913 0.89990595 0.90033138
  96. 0.90038467 0.90098889 0.90072697 0.90107988 0.90084311 0.90124622
  97. 0.90129837 0.90169675 0.90207794 0.90213554 0.90252241 0.90278218
  98. 0.90291333 0.90347889 0.90362643 0.90369515 0.90384743 0.90364836
  99. 0.90427544 0.90441459 0.90463238 0.90473693 0.90474323 0.9047752
  100. 0.90479827 0.90473968 0.90509671 0.9052363 0.90524989 0.90552186
  101. 0.90563292 0.90562407 0.90597002 0.90596054 0.90612973 0.90613532
  102. 0.90626529 0.90624885 0.90629925 0.9065094 0.90687207 0.90713255
  103. 0.90713464 0.90714229 0.90716147 0.90720291 0.90736582 0.9074965
  104. 0.90754282 0.90756655 0.907661 0.90770332 0.9077639 0.90790968
  105. 0.9079417 0.90798504 0.90802753 0.90811397 0.90819922 0.90826913
  106. 0.90827529 0.90832273 0.90834522 0.90836268 0.90852687 0.90865518
  107. 0.90867734 0.90877048 0.9088641 0.90892809 0.90899198 0.90898232
  108. 0.90901235 0.90917898 0.90931463 0.90933462 0.90946639 0.90955688
  109. 0.90957899 0.90955768 0.9096148 0.90963834 0.90960947 0.9097182
  110. 0.90968601 0.90968109 0.90973949 0.90986055 0.90987831 0.90990626
  111. 0.90988658 0.90996061 0.9100206 0.91016861 0.9102051 0.91031329
  112. 0.91035212 0.91039234 0.91038859 0.91055811 0.91059551 0.91060905
  113. 0.91068851 0.91069095 0.91070089 0.91075133 0.9107646 0.91083661
  114. 0.91087072 0.91088294 0.91089601 0.9109779 0.91109368 0.91112637
  115. 0.91117719 0.91131483 0.91133695 0.91135528 0.91138891 0.91139388
  116. 0.91141477 0.91143149 0.91147063 0.91155253 0.91162016 0.91166083
  117. 0.91172069 0.91181031 0.91196481 0.91200488 0.91207325 0.91211531
  118. 0.91215074 0.91216651 0.91216438 0.91216644 0.91217042 0.91222321
  119. 0.91224604 0.91226224 0.91238903 0.91245311 0.91250443 0.9125005
  120. 0.91256733 0.9126408 0.91267537 0.9127926 0.91283471 0.91296865
  121. 0.91309914 0.91321599 0.9132311 0.91326122 0.91331664 0.9133272
  122. 0.91333757 0.91335631 0.91336663 0.91339323 0.91351306 0.91353504
  123. 0.91364762 0.91371905 0.91374356 0.91382455 0.91383544 0.91387963
  124. 0.9138874 0.9139162 0.9139656 0.91401361 0.91409119 0.91417138
  125. 0.91415125 0.91426151 0.91427733 0.91430731 0.91443747 0.91448261
  126. 0.91453059 0.91459567 0.91462631 0.91465691 0.91466103 0.91475955
  127. 0.91475329 0.91484831 0.91488644 0.9149258 0.91522573 0.91525917
  128. 0.91535504 0.91542878 0.91546454 0.91549223 0.91553097 0.91555825
  129. 0.91559304 0.91562558 0.91568611 0.91569478 0.91571794 0.9157357
  130. 0.91576043 0.91577251 0.9158588 0.91609757 0.91611192 0.91617368
  131. 0.91615905 0.91616662 0.91621344 0.91624148 0.91624617 0.916279
  132. 0.91631196 0.91642947 0.91647835 0.91652667 0.91651392 0.91652581
  133. 0.91653713 0.91655006 0.91665427 0.91667051 0.916788 0.91681452
  134. 0.91685526 0.91696372 0.91706001 0.91709748 0.91709928 0.91717672
  135. 0.91719126 0.91729565 0.91739161 0.91742034 0.9174929 0.91750469
  136. 0.91754306 0.9175746 0.91757541 0.91759738 0.91761036 0.91759407
  137. 0.91763139 0.91772489 0.91776103 0.9179556 0.9179898 0.91801864
  138. 0.91801737 0.91803603 0.91805739 0.91813814 0.91814534 0.91817263
  139. 0.9181722 0.91824609 0.91827275 0.91829577 0.91829899 0.91831358
  140. 0.91837373 0.91842195 0.91842025 0.91844374 0.91845804 0.91856584
  141. 0.91857612 0.91863774 0.91862941 0.91863533 0.91871476 0.91873683
  142. 0.91873685 0.91873453 0.91876589 0.91877062 0.91877522 0.91875817
  143. 0.91876122 0.9187661 0.91875142 0.91882763 0.91882948 0.91885122
  144. 0.91885425 0.91883814 0.91885245 0.91885491 0.91889654 0.91892941
  145. 0.91899658 0.91898947 0.91903944 0.91905251 0.91906383 0.91905976
  146. 0.91905803 0.91906125 0.91911411 0.91912297 0.91911189 0.9191422
  147. 0.91916427 0.91918881 0.91919262 0.9192038 0.91922066 0.91922383
  148. 0.91924121 0.91924974 0.91929625 0.91931439 0.9193681 0.91936356
  149. 0.9194005 0.91940401 0.91942551 0.91941168 0.91949196 0.91950447
  150. 0.91950522 0.91950731 0.91951408 0.91954548 0.919564 0.91956135
  151. 0.91961966 0.91962117 0.91961662 0.91963093 0.91963782 0.9197918
  152. 0.91981146 0.91984902 0.91987796 0.91986474 0.91986896 0.91988572
  153. 0.91999293 0.91999407 0.9200143 0.92006417 0.920051 0.92006654
  154. 0.92015715 0.92027774 0.92028039 0.92028351 0.92029304 0.9203403
  155. 0.92041614 0.9204284 0.920495 0.92051366 0.92052067 0.92052967
  156. 0.92054984 0.92056874 0.92059948 0.9206144 0.92060015 0.92060062
  157. 0.9206 0.92060142 0.92060772 0.92063425 0.92063752 0.92064007
  158. 0.92064424 0.92065741 0.92066167 0.92067399 0.92068185 0.92068403
  159. 0.92072576 0.92073774 0.92074077 0.92074873 0.92073982 0.92076398
  160. 0.92077724 0.92078624 0.92078373 0.92078809 0.92077483 0.92077691
  161. 0.92080732 0.92081864 0.92082224 0.92081307 0.92081577 0.92081407
  162. 0.9208495 0.92085911 0.92088312 0.92088725 0.92088677 0.92088559
  163. 0.92089568 0.92089648 0.92090354 0.92092367 0.92092732 0.92094607
  164. 0.92101058 0.92102503 0.92103919 0.92102465 0.92103474 0.92103725
  165. 0.92102598 0.92102195 0.92102607 0.92103857 0.92103682 0.92103957
  166. 0.92106103 0.92107088 0.92107722 0.92108452 0.92108708 0.92114306
  167. 0.92115324 0.92119644 0.9212017 0.92121074 0.92122964 0.92124253
  168. 0.92124826 0.92127222 0.92129136 0.92131528 0.92131712 0.92132233
  169. 0.9213274 0.92132897 0.92131902 0.92132323 0.92133725 0.92138855
  170. 0.92138997 0.92144297 0.92145268 0.92144453 0.92144766 0.92148863
  171. 0.92148243 0.92148764 0.92149299 0.92150881 0.92149493 0.92149905
  172. 0.92150045 0.92151196 0.92151603 0.92152768 0.92153251 0.92151722
  173. 0.92152228 0.92152569 0.92153124 0.92162734 0.92162582 0.92164117
  174. 0.9216549 0.92165462 0.92171068 0.92171977 0.92172706 0.92172754
  175. 0.92173966 0.92175084 0.92175719 0.92175633 0.92175221 0.92176424
  176. 0.92176529 0.92177016 0.92177178 0.92178025 0.92176628 0.92184211
  177. 0.9218522 0.92185713 0.92185623 0.92186669 0.92191529 0.92194717
  178. 0.92195285 0.92197383 0.92197056 0.92197052 0.92197658 0.92198013
  179. 0.92196483 0.92197075 0.92197885 0.92198245 0.92199188 0.92204175
  180. 0.92204227 0.92209778 0.92210214 0.92210863 0.92211934 0.92216049
  181. 0.92216462 0.92216983 0.92217963 0.92217537 0.92221117 0.92241212
  182. 0.92241638 0.92242576 0.92243073 0.92243177 0.92244111 0.92245001
  183. 0.92246045 0.9224605 0.92249186 0.92249266 0.9225063 0.92253676
  184. 0.92254249 0.92254519 0.92255234 0.92255675 0.9225657 0.92257275
  185. 0.92256906 0.92257072 0.92255637 0.92255892 0.922562 0.92256087
  186. 0.92257162 0.92257129 0.92257129 0.92257015 0.92255466 0.92255267
  187. 0.92254007 0.92253922 0.92254197 0.92254647 0.92254566 0.9225891
  188. 0.9225908 0.92258696 0.92259497 0.92259644 0.92275269 0.9227705
  189. 0.92277424 0.92276742 0.92277045 0.92277491 0.92278784 0.92278968
  190. 0.9227921 0.92279944 0.92280498 0.92281464 0.92282772 0.92282279
  191. 0.92295612 0.92295963 0.92295584 0.92297024 0.92295882 0.92296086
  192. 0.92295636 0.92296015 0.92296214 0.92297611 0.92297597 0.92297367
  193. 0.92297282 0.9229821 0.92299015 0.92299285 0.92299162 0.92299494
  194. 0.92304206 0.92307967 0.92308431 0.92310553 0.92310695 0.92310871
  195. 0.92310894 0.9231188 0.92311927 0.92312131 0.92313343 0.92313457
  196. 0.92313679 0.92313154 0.92313566 0.92314866 0.92315825 0.92315967
  197. 0.92320145 0.92320557 0.92321357 0.92322352 0.92323133 0.92323763
  198. 0.92323815 0.92324663 0.92324834 0.92325137 0.92325492 0.92325402
  199. 0.92325374 0.92326515 0.92327003 0.92327538 0.92326444 0.92326927
  200. 0.92326937 0.92325672 0.92327946 0.92327837 0.92327979 0.92328097
  201. 0.92328386 0.92329125 0.92329248 0.92329977 0.92330783 0.92330854
  202. 0.92331124 0.92331299 0.92334018 0.92333804 0.92333752 0.92334894
  203. 0.92334851 0.92334979 0.92335064 0.92335164 0.92336736 0.92336869
  204. 0.92337068 0.9233623 0.92336433 0.92336395 0.92336964 0.92337428
  205. 0.92338446 0.92339521 0.92339592 0.92340004 0.92342017 0.92341999
  206. 0.92342126 0.92342264 0.92341984 0.92341984 0.92341833 0.92340795
  207. 0.92341307 0.92342212 0.92346853 0.92346801 0.92349373 0.92351225
  208. 0.92351931 0.92352476 0.92356085 0.92356151 0.92356483 0.92357894
  209. 0.92361887 0.92361683 0.92362503 0.92362484 0.92364317 0.92366462
  210. 0.92366391 0.92366415 0.92366709 0.92367481 0.92368707 0.92368769
  211. 0.92368816 0.92368991 0.92368925 0.92369735 0.92372847 0.92372814
  212. 0.92372667 0.92373089 0.92373169 0.92373538 0.92376077 0.92376134
  213. 0.92377977 0.9238052 0.92380307 0.92382703 0.92383556 0.92384276
  214. 0.92384248 0.92384338 0.92384195 0.92384991 0.92384778 0.92385673
  215. 0.92385702 0.92385579 0.92388179 0.92390358 0.92390452 0.92392219
  216. 0.92393938 0.9239624 0.92396657 0.92397055 0.92397022 0.92398855
  217. 0.92398784 0.92399916 0.92401161 0.92402559 0.92402535 0.92412785
  218. 0.92413405 0.92415006 0.92416323 0.92416356 0.92416346 0.9241593
  219. 0.92415797 0.92415541 0.92415636 0.92415508 0.92415357 0.9241574
  220. 0.92417905 0.92419297 0.92420628 0.92421429 0.92421362 0.92421457
  221. 0.92422599 0.92422475 0.92422423 0.92422291 0.92422414 0.9242293
  222. 0.92423915 0.92424469 0.92424659 0.92424834 0.9242482 0.92426392
  223. 0.92426804 0.9242796 0.92428093 0.92429978 0.92430196 0.92430196
  224. 0.92432313 0.92432455 0.9243191 0.92431726 0.9243156 0.9243146
  225. 0.92432185 0.924339 0.92433824 0.92435922 0.92436339 0.92436391
  226. 0.92437964 0.92437731 0.92438267 0.92438025 0.92437954 0.92437973
  227. 0.92441307 0.92441795 0.92441701 0.92442023 0.9244314 0.92442975
  228. 0.92442918 0.92442804 0.92442899 0.92443202 0.92443264 0.92443145
  229. 0.92444026 0.92445457 0.92445348 0.92445267 0.92445163 0.92445144
  230. 0.92446295 0.92447209 0.92446982 0.92448261 0.92448014 0.92448966
  231. 0.92450439 0.92450236 0.92450188 0.92450184 0.92449956 0.92451074
  232. 0.92451017 0.9245222 0.92452661 0.92452457 0.92453395 0.92453172
  233. 0.92454541 0.92454967 0.92454584 0.92454304 0.92454058 0.92454039
  234. 0.92453722 0.92455214 0.92455304 0.92454953 0.92456061 0.92457629
  235. 0.9245798 0.92457681 0.92459439 0.92460656 0.92462025 0.92461693
  236. 0.92462337 0.92461906 0.92463341 0.9246309 0.92463697 0.92463474
  237. 0.92463332 0.92463242 0.92462896 0.92462574 0.92463891 0.92464047
  238. 0.92464241 0.92464023 0.92464147 0.92465189 0.92465809 0.9246633
  239. 0.92465923 0.92466605 0.92466898 0.92467012 0.92466695 0.92466572
  240. 0.92467879 0.92468964 0.92471355 0.92472397 0.92472525 0.9247262
  241. 0.92473302 0.92472966 0.92473449 0.92473468 0.92473392 0.92473316
  242. 0.92473132 0.92472805 0.92472582 0.92472492 0.92472293 0.9247263
  243. 0.92473454 0.92473511 0.92473743 0.92473648 0.92475415 0.92476172
  244. 0.92475898 0.92477077 0.92479057 0.92479355 0.92479284 0.92480303
  245. 0.9248143 0.92481302 0.92482448 0.92482211 0.92482301 0.92482368
  246. 0.92482808 0.92482638 0.92482931 0.92483178 0.92483741 0.92483784
  247. 0.9248386 0.92483969 0.92483646 0.9248404 0.92485238 0.92485067
  248. 0.92486796 0.92488321 0.92488009 0.92487786 0.92487805 0.92488516
  249. 0.92488994 0.92489098 0.92490003 0.92490704 0.92490391 0.92491362
  250. 0.92492802 0.92493299 0.9249463 0.92494289 0.92493906 0.92493816
  251. 0.92494223 0.92493996 0.9249382 0.92493849 0.9249491 0.92494697
  252. 0.92494711 0.92495838 0.92495838 0.92497013 0.92498078 0.92497794
  253. 0.92497695 0.92497827 0.92497619 0.924976 0.92498211 0.92498495
  254. 0.92499078 0.92500338 0.92500087 0.92500049 0.92499874 0.92499874
  255. 0.92499883 0.92499935 0.92500451]
  256. AUC of prediction:0.8603217542453206

2.6 人工神经网络-Neural Network

神经网络在维基百科上的定义是:NN is a network inspired by biological neural networks (the central nervous systems of animals, in particular the brain) which are used to estimate or approximate functions that can depend on a large number of inputs that are generally unknown.(from wikipedia)


2.6.1 神经元

神经元是神经网络和SVM这类模型的基础模型和来源,它是一个具有如下结构的线性模型:


其输出模式为:

示意图如下:


2.6.2 神经网络的常用结构

神经网络由一系列神经元组成,典型的神经网络结构如下:


其中最左边是输入层,包含若干输入神经元,最右边是输出层,包含若干输出神经元,介于输入层和输出层的所有层都叫隐藏层,由于神经元的作用,任何权重的微小变化都会导致输出的微小变化,即这种变化是平滑的。

神经元的各种组合方式得到性质不一的神经网络结构 :



前馈神经网络



反向传播神经网络



循环神经网络



卷积神经网络



自编码器



Google DeepMind 记忆神经网络(用于AlphaGo)

2.6.3 一个简单的神经网络例子

假设随机变量 , 使用3层神经网络拟合该分布:

  1. import numpy as np
  2. import matplotlib
  3. matplotlib.use('Agg')
  4. import matplotlib.pyplot as plt
  5. import random
  6. import math
  7. import keras
  8. from keras.models import Sequential
  9. from keras.layers.core import Dense,Dropout,Activation
  10. def gd(x,m,s):
  11. left=1/(math.sqrt(2*math.pi)*s)
  12. right=math.exp(-math.pow(x-m,2)/(2*math.pow(s,2)))
  13. return left*right
  14. def pt(x, y1, y2):
  15. if len(x) != len(y1) or len(x) != len(y2):
  16. print 'input error.'
  17. return
  18. plt.figure(num=1, figsize=(20, 6))
  19. plt.title('NN fitting Gaussian distribution', size=14)
  20. plt.xlabel('x', size=14)
  21. plt.ylabel('y', size=14)
  22. plt.plot(x, y1, color='b', linestyle='--', label='Gaussian distribution')
  23. plt.plot(x, y2, color='r', linestyle='-', label='NN fitting')
  24. plt.legend(loc='upper left')
  25. plt.savefig('ann.png', format='png')
  26. def ann(train_d, train_l, prd_d):
  27. if len(train_d) == 0 or len(train_d) != len(train_l):
  28. print 'training data error.'
  29. return
  30. model = Sequential()
  31. model.add(Dense(30, input_dim=1))
  32. model.add(Activation('relu'))
  33. model.add(Dense(30))
  34. model.add(Activation('relu'))
  35. model.add(Dense(1, activation='sigmoid'))
  36. model.compile(loss='mse',
  37. optimizer='rmsprop',
  38. metrics=['accuracy'])
  39. model.fit(train_d,train_l,batch_size=250, nb_epoch=50, validation_split=0.2)
  40. p = model.predict(prd_d,batch_size=250)
  41. return p
  42. if __name__ == '__main__':
  43. x = np.linspace(-5, 5, 10000)
  44. idx = random.sample(x, 900)
  45. train_d = []
  46. train_l = []
  47. for i in idx:
  48. train_d.append(x[i])
  49. train_l.append(gd(x[i],0,1))
  50. y1 = []
  51. y2 = []
  52. for i in x:
  53. y1.append(gd(i,0,1))
  54. y2 = ann(np.array(train_d).reshape(len(train_d), 1), np.array(train_l), np.array(x).reshape(len(x), 1))
  55. pt(x, y1, y2.tolist())


References

如有遗漏请提醒我补充:
1、《Understanding the Bias-Variance Tradeoff》
http://scott.fortmann-roe.com/docs/BiasVariance.html
2、《Boosting Algorithms as Gradient Descent in Function Space》
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.51.6893&rep=rep1&type=pdf
3、《Optimal Action Extraction for Random Forests and
Boosted Trees》
http://www.cse.wustl.edu/~ychen/public/OAE.pdf
4、《Applying Neural Network Ensemble Concepts for Modelling Project Success》
http://www.iaarc.org/publications/fulltext/Applying_Neural_Network_Ensemble_Concepts_for_Modelling_Project_Success.pdf
5、《Introduction to Boosted Trees》
https://homes.cs.washington.edu/~tqchen/data/pdf/BoostedTree.pdf
6、《Machine Learning:Perceptrons》
http://ml.informatik.uni-freiburg.de/_media/documents/teaching/ss09/ml/perceptrons.pdf
7、《An overview of gradient descent optimization algorithms》
http://sebastianruder.com/optimizing-gradient-descent/
8、《Ad Click Prediction: a View from the Trenches》
https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf
9、《ADADELTA: AN ADAPTIVE LEARNING RATE METHOD》
http://www.matthewzeiler.com/pubs/googleTR2012/googleTR2012.pdf
9、《Improving the Convergence of Back-Propagation Learning with Second Order Methods》
http://yann.lecun.com/exdb/publis/pdf/becker-lecun-89.pdf
10、《ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION》
https://arxiv.org/pdf/1412.6980v8.pdf
11、《Adaptive Subgradient Methods for Online Learning and Stochastic Optimization》
http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
11、《Sparse Allreduce: Efficient Scalable Communication for Power-Law Data》
https://arxiv.org/pdf/1312.3020.pdf
12、《Asynchronous Parallel Stochastic Gradient Descent》
https://arxiv.org/pdf/1505.04956v5.pdf
13、《Large Scale Distributed Deep Networks》
https://papers.nips.cc/paper/4687-large-scale-distributed-deep-networks.pdf
14、《Introduction to Optimization —— Second Order Optimization Methods》
https://ipvs.informatik.uni-stuttgart.de/mlr/marc/teaching/13-Optimization/04-secondOrderOpt.pdf
15、《On the complexity of steepest descent, Newton’s and regularized Newton’s methods for nonconvex unconstrained optimization》
http://www.maths.ed.ac.uk/ERGO/pubs/ERGO-09-013.pdf
16、《On Discriminative vs. Generative classifiers: A comparison of logistic regression and naive Bayes 》
http://papers.nips.cc/paper/2020-on-discriminative-vs-generative-classifiers-a-comparison-of-logistic-regression-and-naive-bayes.pdf
17、《Parametric vs Nonparametric Models》
http://mlss.tuebingen.mpg.de/2015/slides/ghahramani/gp-neural-nets15.pdf
18、《XGBoost: A Scalable Tree Boosting System》
https://arxiv.org/abs/1603.02754
19、一个可视化CNN的网站
http://shixialiu.com/publications/cnnvis/demo/
20、《Computer vision: LeNet-5, AlexNet, VGG-19, GoogLeNet》
http://euler.stat.yale.edu/~tba3/stat665/lectures/lec18/notebook18.html
21、François Chollet在Quora上的专题问答:
https://www.quora.com/session/Fran%C3%A7ois-Chollet/1
22、《将Keras作为tensorflow的精简接口》
https://keras-cn.readthedocs.io/en/latest/blog/keras_and_tensorflow/
23、《Upsampling and Image Segmentation with Tensorflow and TF-Slim》
https://warmspringwinds.github.io/tensorflow/tf-slim/2016/11/22/upsampling-and-image-segmentation-with-tensorflow-and-tf-slim/
24、《Click-Through Rate Estimation for Rare Events in Online Advertising

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注