Monday, September 9, 2013

Learning curve generator for machine learning models in Python and scikit-learn

This particular program draws the learning curve for the Gaussian Naive Bayes model, but the drawing function is generic: pass it any scikit-learn model and it will generate the learning curve for the data set loaded in the script.
I have used the scikit-learn library and the "digits" data set for the calculations.
The plot uses the simple RMS (root-mean-square) error as its metric.
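To make the metric concrete, here is a small worked example of the same RMS formula that compute_error uses, with a misclassification rate alongside for comparison. The helper names rms_error and misclassification_rate are hypothetical, chosen for this sketch. Because the digit labels are treated as numbers, RMS penalizes predicting 8 for a true 5 more than predicting 6, while the misclassification rate counts both as one wrong answer.

```python
import numpy as np

def rms_error(y_true, y_pred):
    """Root-mean-square error: sqrt(mean((y - yfit) ** 2)), the formula in compute_error."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

def misclassification_rate(y_true, y_pred):
    """Fraction of predictions whose label differs from the true label."""
    return np.mean(np.asarray(y_true) != np.asarray(y_pred))

# Three digit labels; in each prediction the true 5 is misclassified.
y_true = [3, 5, 7]
print(rms_error(y_true, [3, 8, 7]))               # sqrt(9/3) ≈ 1.732
print(rms_error(y_true, [3, 6, 7]))               # sqrt(1/3) ≈ 0.577
print(misclassification_rate(y_true, [3, 8, 7]))  # 1/3 ≈ 0.333
print(misclassification_rate(y_true, [3, 6, 7]))  # 1/3 ≈ 0.333
```

So the RMS curve reflects how far apart the wrong digits are numerically, not just how often the model is wrong.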
Hope this helps.
Note: %pylab inline only works in IPython. If you are not using IPython, remove that line and make sure numpy and matplotlib.pyplot are imported as np and plt.

%pylab inline

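import numpy as np
import matplotlib.pyplot as plt
from sklearn.naive_bayes import GaussianNB
from sklearn.datasets import load_digits
#train_test_split lived in sklearn.cross_validation in older releases; it is in sklearn.model_selection now
from sklearn.model_selection import train_test_split

#loading the digits dataset (8x8 images of handwritten digits, labels 0-9)
digits = load_digits()

#separating the data into a training set (80%) and a held-out test set (20%)
data_train,data_test,target_train,target_test = train_test_split(digits.data,digits.target,test_size = 0.20, random_state = 42)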

#assigning the Gaussian Naive Bayes Model
clf = GaussianNB()

#compute the rms error between the model's predictions for x and the true labels y
def compute_error(x, y, model):
    yfit = model.predict(x)
    return np.sqrt(np.mean((y - yfit) ** 2))
    

def drawLearningCurve(model):
    #50 training-set sizes between 2 and 200 samples
    sizes = np.linspace(2, 200, 50).astype(int)
    #one error value per training-set size
    train_error = np.zeros(sizes.shape)
    crossval_error = np.zeros(sizes.shape)
    
    for i,size in enumerate(sizes):
        
        #fit the model on the first `size` training samples only
        model.fit(data_train[:size,:],target_train[:size])

        #compute the validation error on the held-out test set
        crossval_error[i] = compute_error(data_test,target_test,model)

        #compute the training error on the samples the model was fit on
        train_error[i] = compute_error(data_train[:size,:],target_train[:size],model)
       
    #draw the plot
    fig,ax = plt.subplots()
    ax.plot(sizes,crossval_error,lw = 2, label='cross validation error')
    ax.plot(sizes,train_error, lw = 2, label='training error')
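    ax.set_xlabel('training set size')
    ax.set_ylabel('rms error')

    ax.legend(loc = 0)
    #show the full range of training-set sizes
    ax.set_xlim(0,sizes.max())
    ax.set_title('Learning Curve')

drawLearningCurve(clf)
#needed to display the figure when not running under %pylab inline
plt.show()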
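drawLearningCurve is generic in the model, but it reads the data from the global variables data_train, data_test, target_train and target_test. Below is a minimal sketch, assuming a recent scikit-learn where train_test_split is imported from sklearn.model_selection, that passes the data in explicitly and returns the two error arrays instead of plotting them. The names learning_curve_errors and rms_error are hypothetical, and DecisionTreeClassifier is only there to illustrate swapping in a second model; it is not part of the original program.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

def rms_error(x, y, model):
    """RMS error between the model's predictions on x and the labels y."""
    yfit = model.predict(x)
    return np.sqrt(np.mean((y - yfit) ** 2))

def learning_curve_errors(model, X_train, y_train, X_test, y_test, sizes):
    """Return (train_error, test_error) arrays, one entry per training-set size."""
    train_error = np.zeros(len(sizes))
    test_error = np.zeros(len(sizes))
    for i, size in enumerate(sizes):
        # Fit on the first `size` training samples only.
        model.fit(X_train[:size], y_train[:size])
        train_error[i] = rms_error(X_train[:size], y_train[:size], model)
        test_error[i] = rms_error(X_test, y_test, model)
    return train_error, test_error

digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.20, random_state=42)
sizes = np.linspace(2, 200, 50).astype(int)

# The same helper works for any estimator that has fit() and predict().
for model in (GaussianNB(), DecisionTreeClassifier(random_state=0)):
    train_err, test_err = learning_curve_errors(
        model, X_train, y_train, X_test, y_test, sizes)
    print(type(model).__name__, train_err[-1], test_err[-1])
```

The returned arrays can be plotted the same way drawLearningCurve does. scikit-learn also ships its own sklearn.model_selection.learning_curve helper, which scores the model over several cross-validation folds rather than one fixed train/test split.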

3 comments:

  1. Great!
    Keep it up.
    Hoping to see more on this ;)

  2. Thanks but can u please write a comment before each line which explain it