Liz's Blog

Python學習筆記#16:機器學習之SVM實作篇

| Comments

學習SVM(Support Vector Machine)的概論時,我自己很喜歡這篇說明,算是用比較簡單又算詳細的方式來解釋hyperplane的。

Udemy
課程名稱:Python for Data Science and Machine Learning Bootcamp
講師:Jose Portilla

1.載入套件

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

2.取得Scikit Learn上的乳癌資料集

from sklearn.datasets import load_breast_cancer

#資料集是以dictionary的形式存在
cancer = load_breast_cancer()
cancer.keys()

#印出資料集的描述文件
print(cancer['DESCR'])

cancer['feature_names']

3.轉成data frame來做操作

df_feat = pd.DataFrame(cancer['data'],columns=cancer['feature_names'])
df_feat.info()
cancer['target']
df_feat.head()

這邊講師有說明過去探索資料時都有做資料視覺化來掃描趨勢,但這章沒做的原因在於除非你本身有癌症相關知識,否則有許多變數是沒辦法解釋的,所以這個資料集沒有做。

4.將資料分成訓練組及測試組

from sklearn.model_selection import train_test_split

X = df_feat
y = cancer['target']
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.30, random_state=101)

5.載入Support Vector Classifier套件

from sklearn.svm import SVC
model = SVC()

#使用Support Vector Classifier來建立模型
model.fit(X_train,y_train)

6.利用測試組資料來測試模型結果

predictions = model.predict(X_test)

#載入classification_report & confusion_matrix來評估模型好壞
from sklearn.metrics import classification_report,confusion_matrix
print(confusion_matrix(y_test,predictions))
print('\n')
print(classification_report(y_test,predictions))

結果發現,分類結果都只出現在特定的類別,所以要修正模型。剛在跑model.fit(X_train,y_train)時,有看到輸出結果中有許多參數可以修改,像是C及gamma值。由於要選到適合的C值及gamma值,有好幾種排列組合,這時候就可以使用格狀組合的參數來找到最佳結果的Gridsearch。

7.載入GridSearchCV

from sklearn.model_selection import GridSearchCV

#GridSearchCV是建立一個dictionary來組合要測試的參數
param_grid = {'C':[0.1,1,10,100,1000],'gamma':[1,0.1,0.01,0.001,0.0001]}

#GridSearchCV算是一個meta-estimator,參數中帶有estimator,像是SVC,重點是會創造一個新的estimator,但又表現的一模一樣。也就是estimator=SVC時,就是作為分類器
#Verbose可設定為任一整數,它只是代表數字越高,文字解釋越多
grid = GridSearchCV(SVC(),param_grid,verbose=3)

#利用剛剛設定的參數來找到最適合的模型
grid.fit(X_train,y_train)

#顯示最佳參數組合
grid.best_params_

#顯示最佳estimator參數
grid.best_estimator_

#利用剛剛的最佳參數再重新預測測試組
grid_predictions = grid.predict(X_test)

#評估新參數的預測結果好壞
print(confusion_matrix(y_test,predictions))
print('\n')
print(classification_report(y_test,grid_predictions))

Comments

comments powered by Disqus