[REM/NO REM] Pre-extracted basic features

Quick experiment predicting whether a sleep stage is REM or not from pre-extracted features.

In this quick experiment we use the pre-extracted features to classify sleep stages as REM or No REM.
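Framing the task as REM vs. No REM comes down to turning the multi-class sleep-stage label into a boolean target, as the cross-validation cell further down does with `df_complete["Label"]=="Sleep stage R"`. A minimal sketch on a toy label column; "Sleep stage W" and "Sleep stage R" appear in this notebook, while "Sleep stage N2" is a hypothetical placeholder for the remaining stages:

```python
import pandas as pd

# Toy label column; "Sleep stage N2" is a hypothetical placeholder name.
labels = pd.Series(["Sleep stage W", "Sleep stage N2", "Sleep stage N2",
                    "Sleep stage R", "Sleep stage W", "Sleep stage N2"])

# Binary target: True for REM epochs, False for every other stage.
y = labels == "Sleep stage R"

# Fraction of REM vs. non-REM epochs, a quick way to spot class imbalance.
print(y.value_counts(normalize=True))
```

Checking the class proportions this way before training helps explain why accuracy alone can look deceptively good on this kind of target.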

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import yasa
import mne

from sklearn.model_selection import train_test_split, cross_validate, GroupKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics

from sleepstagingidal.data import *
from sleepstagingidal.dataa import *
from sleepstagingidal.feature_extraction import *
from sleepstagingidal.cross_validation import *

Load and filter the data

The first step is to load the features we extracted previously:

df = pd.read_csv(path_data, index_col=0)
df.head()
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 ... 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 Label Patient
0 0.495066 0.112756 0.112814 0.064382 0.165632 0.049350 0.362000 0.116360 0.171950 0.159524 0.138463 0.051704 0.543319 0.110169 0.118650 0.060939 0.117710 0.049213 0.476496 0.148164 0.147317 0.106843 0.090631 0.030549 0.531814 0.170641 0.118979 0.070764 0.080622 0.027180 0.475519 0.188196 0.133852 0.095128 0.080400 0.026905 0.527519 0.122290 0.133960 0.065811 ... 0.089403 0.032273 0.200669 0.291248 0.213199 0.142235 0.147633 0.005017 0.200751 0.292008 0.212910 0.141983 0.147257 0.005091 0.197748 0.173513 0.231294 0.181297 0.190539 0.025610 0.178459 0.252283 0.216593 0.158911 0.186102 0.007653 0.721550 0.071222 0.073576 0.043658 0.064126 0.025867 0.687305 0.084597 0.083275 0.047782 0.070681 0.026361 Sleep stage W PSG29.edf
1 0.465074 0.117853 0.177654 0.101286 0.104765 0.033368 0.347451 0.205411 0.208832 0.093656 0.111676 0.032974 0.630552 0.109801 0.111556 0.065251 0.059425 0.023416 0.480078 0.183739 0.148267 0.069932 0.085786 0.032198 0.616104 0.143776 0.113066 0.057664 0.053632 0.015759 0.545867 0.166890 0.141956 0.065644 0.060475 0.019168 0.531884 0.154121 0.136999 0.072946 ... 0.092982 0.030478 0.206725 0.284727 0.203176 0.151529 0.147839 0.006005 0.205184 0.285613 0.203717 0.151011 0.148545 0.005929 0.357687 0.172565 0.202435 0.112405 0.134012 0.020895 0.186969 0.243792 0.201004 0.168602 0.191435 0.008197 0.756929 0.068393 0.065536 0.032712 0.050156 0.026275 0.811640 0.054457 0.050954 0.026253 0.037776 0.018920 Sleep stage W PSG29.edf
2 0.493321 0.083727 0.160615 0.093394 0.127338 0.041605 0.509627 0.099740 0.174744 0.106081 0.083801 0.026007 0.597762 0.134673 0.127823 0.057744 0.060519 0.021479 0.519414 0.140630 0.167105 0.080623 0.068262 0.023966 0.510118 0.185070 0.156140 0.068210 0.062267 0.018194 0.534198 0.161565 0.153183 0.065650 0.065872 0.019532 0.586799 0.139837 0.121766 0.056820 ... 0.068092 0.022911 0.209975 0.288701 0.212362 0.143128 0.140391 0.005442 0.211023 0.287946 0.212669 0.142524 0.140397 0.005442 0.236724 0.155031 0.223585 0.181925 0.178715 0.024020 0.189490 0.254528 0.207994 0.163306 0.176389 0.008293 0.720894 0.077628 0.080070 0.042218 0.057003 0.022188 0.761949 0.065250 0.071764 0.036249 0.047086 0.017702 Sleep stage W PSG29.edf
3 0.496456 0.078696 0.145985 0.073315 0.168728 0.036820 0.415612 0.150171 0.162990 0.084801 0.145732 0.040695 0.610062 0.156481 0.096959 0.047532 0.071230 0.017735 0.481011 0.199789 0.142475 0.060787 0.085889 0.030048 0.532062 0.180797 0.130148 0.055573 0.080988 0.020431 0.493871 0.184382 0.134594 0.071134 0.086786 0.029233 0.543152 0.187464 0.110963 0.051094 ... 0.080532 0.022249 0.183674 0.338252 0.201751 0.136399 0.135297 0.004626 0.188792 0.336624 0.200293 0.135144 0.134579 0.004568 0.236594 0.153612 0.215825 0.169071 0.198682 0.026217 0.159044 0.288443 0.207075 0.160674 0.177189 0.007575 0.583039 0.117859 0.100131 0.045193 0.102050 0.051728 0.657698 0.107005 0.083062 0.039670 0.079945 0.032621 Sleep stage W PSG29.edf
4 0.499096 0.090408 0.128249 0.119394 0.126275 0.036577 0.355009 0.114010 0.227973 0.142860 0.120449 0.039700 0.665468 0.097059 0.102477 0.055994 0.057387 0.021614 0.567872 0.108560 0.147445 0.074687 0.072923 0.028513 0.649271 0.120768 0.096039 0.061096 0.056664 0.016162 0.606666 0.124536 0.117359 0.064167 0.063759 0.023512 0.545407 0.135738 0.135981 0.053538 ... 0.102558 0.032455 0.191534 0.316908 0.221889 0.128541 0.135913 0.005214 0.188943 0.317868 0.222645 0.128482 0.136874 0.005188 0.316837 0.164763 0.211233 0.130939 0.150966 0.025262 0.172198 0.274478 0.218936 0.148539 0.178593 0.007258 0.718889 0.060736 0.083312 0.042121 0.064617 0.030325 0.728514 0.062010 0.082649 0.035992 0.059911 0.030924 Sleep stage W PSG29.edf

5 rows × 86 columns

Next, we filter the data to keep only the complete patients. We can do so using the info.csv file we created previously:

df_info = pd.read_csv(path_info)
df_info.head()
File DifferentStages Complete
0 PSG29.edf 3 False
1 PSG12.edf 5 True
2 PSG17.edf 5 True
3 PSG10.edf 4 False
4 PSG23.edf 3 False
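In the five rows shown, `Complete` is `True` exactly for the recordings with 5 different stages, which suggests that a complete patient is one whose recording contains all five sleep stages. A quick consistency check on those rows, copied from the output above; note that it only confirms the pattern for the displayed rows, not for the whole file:

```python
import pandas as pd

# The five rows of info.csv shown above.
df_info = pd.DataFrame({
    "File": ["PSG29.edf", "PSG12.edf", "PSG17.edf", "PSG10.edf", "PSG23.edf"],
    "DifferentStages": [3, 5, 5, 4, 3],
    "Complete": [False, True, True, False, False],
})

# Does "Complete" coincide with "all five stages present" on these rows?
consistent = (df_info["Complete"] == (df_info["DifferentStages"] == 5)).all()
print(consistent)  # True
```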

By keeping only the complete patients in df_info and joining the result with the features dataframe, we filter out the incomplete patients and train only on the complete ones:

df_info_complete = df_info[df_info.Complete]
len(df_info_complete)
30

We are left with 30 out of 36 patients. Next, we join both dataframes:

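# Right join on the recording file: keeps the feature rows of the complete patients
df_complete = df.merge(right=df_info_complete, how="right", left_on="Patient", right_on="File")
# Drop the now redundant Patient column and the bookkeeping columns from info.csv
df_complete = df_complete.drop(["Patient", "DifferentStages", "Complete"], axis=1)
df.shape, df_complete.shape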
((27680, 86), (24121, 86))
df_complete.head()
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 ... 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 Label File
0 0.828777 0.111004 0.036347 0.013696 0.008029 0.002147 0.866095 0.077208 0.029721 0.013948 0.010277 0.002751 0.893198 0.075901 0.015894 0.007203 0.006068 0.001736 0.897216 0.066448 0.016855 0.006646 0.010521 0.002313 0.839334 0.129764 0.020567 0.005832 0.003590 0.000913 0.874244 0.099975 0.015362 0.006066 0.003442 0.000911 0.871306 0.087882 0.021908 0.007706 ... 0.007495 0.002202 0.559114 0.187103 0.049373 0.076654 0.121630 0.006125 0.569851 0.184307 0.048731 0.073804 0.117428 0.005878 0.599962 0.072144 0.036558 0.069310 0.177553 0.044474 0.314669 0.139663 0.104543 0.142964 0.274585 0.023577 0.863405 0.085994 0.017926 0.008112 0.013523 0.011039 0.842687 0.082653 0.018246 0.008742 0.021826 0.025847 Sleep stage W PSG12.edf
1 0.829209 0.104899 0.032051 0.022996 0.008892 0.001952 0.822251 0.105734 0.035186 0.024547 0.009898 0.002384 0.895753 0.062918 0.021768 0.011074 0.006787 0.001699 0.894197 0.061178 0.017440 0.012895 0.012034 0.002256 0.857363 0.103283 0.023190 0.010939 0.004160 0.001065 0.883609 0.079566 0.021131 0.010778 0.003813 0.001103 0.885997 0.071112 0.021702 0.012148 ... 0.007887 0.002063 0.545761 0.190174 0.052838 0.079403 0.126536 0.005288 0.550894 0.191843 0.051842 0.077397 0.122869 0.005155 0.589353 0.073115 0.034576 0.073617 0.185505 0.043835 0.287737 0.143439 0.113177 0.149858 0.284439 0.021351 0.897045 0.053111 0.016674 0.010652 0.013242 0.009276 0.873665 0.057653 0.016445 0.013007 0.019579 0.019652 Sleep stage W PSG12.edf
2 0.842406 0.100744 0.029417 0.019372 0.006586 0.001473 0.870060 0.081720 0.022897 0.017082 0.006613 0.001629 0.909758 0.051352 0.021887 0.010397 0.005378 0.001229 0.921258 0.047191 0.015626 0.007201 0.007109 0.001616 0.895852 0.069141 0.022387 0.008700 0.003240 0.000680 0.905155 0.062947 0.019297 0.008643 0.003221 0.000736 0.904595 0.058342 0.020455 0.009654 ... 0.006376 0.001663 0.536013 0.199818 0.045428 0.075261 0.137987 0.005494 0.550753 0.195377 0.043902 0.072385 0.132292 0.005291 0.638595 0.061563 0.034590 0.063653 0.165439 0.036159 0.294978 0.148989 0.095591 0.137203 0.301252 0.021988 0.889069 0.060842 0.020352 0.010035 0.011098 0.008603 0.869873 0.061609 0.020109 0.010276 0.017389 0.020744 Sleep stage W PSG12.edf
3 0.826268 0.114365 0.034281 0.016519 0.006972 0.001596 0.834174 0.086591 0.045124 0.020168 0.011114 0.002829 0.886998 0.068077 0.027164 0.009995 0.006277 0.001489 0.899655 0.052545 0.023429 0.009772 0.011983 0.002616 0.852234 0.110114 0.023666 0.009198 0.003865 0.000923 0.885270 0.079228 0.022851 0.008011 0.003732 0.000907 0.893624 0.060173 0.029191 0.008671 ... 0.007102 0.002081 0.558408 0.174786 0.052194 0.077200 0.131827 0.005585 0.563630 0.175305 0.051710 0.075393 0.128516 0.005446 0.645756 0.058029 0.035434 0.060650 0.158191 0.041940 0.283484 0.138802 0.114658 0.142425 0.298707 0.021923 0.887805 0.058118 0.023766 0.008757 0.012120 0.009433 0.873901 0.055302 0.022631 0.009176 0.018615 0.020375 Sleep stage W PSG12.edf
4 0.892040 0.072569 0.016537 0.013553 0.004228 0.001073 0.887429 0.069672 0.017843 0.017650 0.005630 0.001776 0.949596 0.032758 0.009363 0.005098 0.002614 0.000571 0.924791 0.048967 0.010825 0.007579 0.006528 0.001310 0.928209 0.053080 0.011061 0.004976 0.002208 0.000466 0.927354 0.051554 0.012094 0.006003 0.002475 0.000520 0.937408 0.039077 0.012161 0.006905 ... 0.006053 0.001503 0.549693 0.203476 0.043815 0.072618 0.124463 0.005936 0.558714 0.200516 0.043050 0.071087 0.120879 0.005754 0.723112 0.048052 0.024039 0.054247 0.118286 0.032265 0.305310 0.156931 0.093174 0.134177 0.286451 0.023958 0.929078 0.041854 0.010671 0.006873 0.006212 0.005312 0.914073 0.043103 0.011320 0.007723 0.010389 0.013391 Sleep stage W PSG12.edf

5 rows × 86 columns

This removes the 3559 rows (27680 - 24121) belonging to incomplete patients.
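The merge is one way to do this filtering. A sketch of an equivalent boolean-mask filter, on toy frames whose column names mirror the notebook's (the values and the `PSG99.edf` recording are made up): `isin` keeps only the feature rows whose `Patient` is among the complete files. One difference worth knowing is that with `how="right"`, a complete patient that had no feature rows would still appear as a single row of NaNs, whereas the mask (or `how="inner"`) simply drops it.

```python
import pandas as pd

# Toy feature table: one row per epoch, tagged with its recording.
df = pd.DataFrame({
    "0": [0.1, 0.2, 0.3, 0.4, 0.5],
    "Label": ["Sleep stage W", "Sleep stage R", "Sleep stage W",
              "Sleep stage R", "Sleep stage W"],
    "Patient": ["PSG12.edf", "PSG12.edf", "PSG29.edf", "PSG17.edf", "PSG29.edf"],
})
# Toy info table: the hypothetical PSG99.edf is complete but has no feature rows.
df_info = pd.DataFrame({
    "File": ["PSG12.edf", "PSG29.edf", "PSG17.edf", "PSG99.edf"],
    "Complete": [True, False, True, True],
})
complete_files = df_info.loc[df_info["Complete"], "File"]

# Boolean-mask filter: keep epochs that belong to a complete recording.
df_mask = df[df["Patient"].isin(complete_files)]

# Right join, as in the notebook: PSG99.edf shows up as a row of NaNs.
df_right = df.merge(df_info[df_info["Complete"]], how="right",
                    left_on="Patient", right_on="File")

print(len(df), len(df_mask), len(df_right))  # 5 3 4
```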

Defining the metrics

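AUC is a good metric for binary problems: it measures how well the classifier ranks positive samples above negative ones, independently of any decision threshold and of the class proportions.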
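To see why accuracy alone can be misleading here, consider a degenerate classifier that ignores its input and labels every epoch as No REM. On a toy target with 10% REM epochs (an illustrative proportion, not one measured on this dataset), it reaches 90% accuracy while its AUC is 0.5, i.e. chance level:

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Toy binary target: 10 REM epochs (True) out of 100.
y_true = np.array([True] * 10 + [False] * 90)

# A degenerate "classifier" that predicts No REM for every epoch.
y_pred = np.zeros_like(y_true)        # hard labels, all False
y_score = np.zeros(len(y_true))       # constant REM score for every epoch

acc = accuracy_score(y_true, y_pred)  # 0.9, driven purely by class imbalance
auc = roc_auc_score(y_true, y_score)  # 0.5, no ability to rank REM above No REM
print(acc, auc)
```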

To measure the performance of our classifier, we use the AUC and the accuracy, the latter only for completeness. Because we want the most realistic estimate possible, we perform a Patient-Fold (leave-one-patient-out) validation: in each fold we train on all patients but one and test on the held-out patient. This gives a good estimate of how well our pipeline generalizes to unseen patients.

This can be achieved using cross_validate together with GroupKFold, setting n_splits to the number of patients and the groups parameter to the File column of the dataframe above:

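# Named scorers; calling this `metrics` would shadow the sklearn.metrics import
scoring = {
    "Accuracy": "accuracy",
    "AUC": "roc_auc",
}
cvg = cross_validate(RandomForestClassifier(),
                     X=df_complete.drop(["Label", "File"], axis=1),  # the 84 extracted features
                     y=df_complete["Label"]=="Sleep stage R",         # True for REM, False otherwise
                     scoring=scoring,
                     return_train_score=True,
                     # One fold per patient: each fold holds out a single recording
                     cv=GroupKFold(n_splits=len(df_complete.File.unique())),
                     groups=df_complete["File"],
                     n_jobs=5)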
CPU times: user 146 ms, sys: 104 ms, total: 250 ms
Wall time: 2min 37s
cvg
{'fit_time': array([27.03762555, 26.27106881, 26.99674702, 26.86529589, 26.24467683,
        25.45121002, 23.67596412, 24.07731962, 24.42808008, 24.3764236 ,
        25.62880945, 25.02343202, 24.51707554, 25.30792546, 25.24074316,
        24.36477757, 24.58056784, 25.27363634, 24.7898047 , 24.37082696,
        23.90679312, 24.51267362, 24.12866235, 24.50727344, 24.42071271,
        24.26417112, 24.38420892, 24.73950148, 24.41905499, 24.89244962]),
 'score_time': array([0.06685781, 0.06159592, 0.05747652, 0.05841327, 0.0619328 ,
        0.05348492, 0.0628283 , 0.05884314, 0.06110263, 0.05982661,
        0.05692792, 0.05831432, 0.05882573, 0.05572987, 0.06257844,
        0.05658269, 0.05508018, 0.05524349, 0.0544374 , 0.05899739,
        0.05471539, 0.05301929, 0.05601811, 0.05660391, 0.05370593,
        0.05541277, 0.05205274, 0.05044127, 0.05035877, 0.04688859]),
 'test_Accuracy': array([0.91080797, 0.94065934, 0.88548753, 0.904     , 0.86869871,
        0.9144197 , 0.87573271, 0.90973036, 0.93184489, 0.81560284,
        0.85      , 0.91746411, 0.87860577, 0.80853659, 0.81295844,
        0.96924969, 0.75709001, 0.89440994, 0.98461538, 0.97293814,
        0.87696335, 0.90944882, 0.94218134, 0.92368421, 0.91125828,
        0.85066667, 0.95392954, 0.86573427, 0.83875969, 0.9492635 ]),
 'train_Accuracy': array([1.        , 0.99995692, 1.        , 1.        , 0.99995702,
        1.        , 1.        , 1.        , 1.        , 0.99995704,
        0.99995705, 1.        , 1.        , 1.        , 0.99995709,
        1.        , 1.        , 1.        , 1.        , 1.        ,
        0.99995719, 1.        , 1.        , 1.        , 1.        ,
        1.        , 1.        , 1.        , 1.        , 0.99995746]),
 'test_AUC': array([0.50650465, 0.41852717, 0.80062372, 0.8343038 , 0.82448892,
        0.75410263, 0.85782295, 0.94624178, 0.73788735, 0.77484245,
        0.91584851, 0.24220774, 0.69733775, 0.40263327, 0.31369251,
        0.45761421, 0.67280278, 0.83897209, 0.31743707, 0.32397487,
        0.48285964, 0.66438714, 0.35597914, 0.68175239, 0.36966178,
        0.8807308 , 0.30128259, 0.76383847, 0.65165771, 0.51384872]),
 'train_AUC': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])}
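Because `n_splits` equals the number of distinct recordings (30 here, one score per fold above), each `GroupKFold` fold holds out exactly one patient, which is the same partition scikit-learn's `LeaveOneGroupOut` produces, possibly in a different order. A small check on toy groups; the patient names and fold sizes are made up:

```python
import numpy as np
from sklearn.model_selection import GroupKFold, LeaveOneGroupOut

# Toy data: 4 hypothetical patients with different numbers of epochs.
groups = np.array(["A"] * 3 + ["B"] * 5 + ["C"] * 2 + ["D"] * 4)
X = np.zeros((len(groups), 1))
n_patients = len(np.unique(groups))

def held_out_patients(cv):
    # One frozenset per fold with the patients present in its test split.
    return {frozenset(groups[test]) for _, test in cv.split(X, groups=groups)}

gkf_folds = held_out_patients(GroupKFold(n_splits=n_patients))
logo_folds = held_out_patients(LeaveOneGroupOut())

# Every fold tests on exactly one patient, and both splitters agree.
print(all(len(f) == 1 for f in gkf_folds), gkf_folds == logo_folds)  # True True
```

`LeaveOneGroupOut` states the intent more explicitly and does not need the number of patients up front; with `n_splits` set as in the notebook, both yield the same folds.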
# The two boxes sit at x = 1 and x = 2, so label those positions directly
plt.boxplot([cvg['test_AUC'], cvg['test_Accuracy']])
plt.xticks([1, 2], ["AUC", "Accuracy"])
plt.ylabel("Test score")
plt.title("Results of performing Patient-Fold with default RandomForest [REM/No REM]")
plt.show()
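Alongside the boxplot, the per-patient scores printed above can be summarized numerically. The arrays below are copied from `cvg['test_AUC']` and `cvg['test_Accuracy']`. They give a mean test AUC of roughly 0.61, with 11 of the 30 held-out patients below 0.5, while every fold keeps an accuracy above 0.75 and the training AUC is 1.0 in every fold, a gap that hints at the forest fitting the training patients much better than it generalizes:

```python
import numpy as np

# Per-patient test scores copied from the cross_validate output above.
test_auc = np.array([
    0.50650465, 0.41852717, 0.80062372, 0.8343038, 0.82448892,
    0.75410263, 0.85782295, 0.94624178, 0.73788735, 0.77484245,
    0.91584851, 0.24220774, 0.69733775, 0.40263327, 0.31369251,
    0.45761421, 0.67280278, 0.83897209, 0.31743707, 0.32397487,
    0.48285964, 0.66438714, 0.35597914, 0.68175239, 0.36966178,
    0.8807308, 0.30128259, 0.76383847, 0.65165771, 0.51384872,
])
test_acc = np.array([
    0.91080797, 0.94065934, 0.88548753, 0.904, 0.86869871,
    0.9144197, 0.87573271, 0.90973036, 0.93184489, 0.81560284,
    0.85, 0.91746411, 0.87860577, 0.80853659, 0.81295844,
    0.96924969, 0.75709001, 0.89440994, 0.98461538, 0.97293814,
    0.87696335, 0.90944882, 0.94218134, 0.92368421, 0.91125828,
    0.85066667, 0.95392954, 0.86573427, 0.83875969, 0.9492635,
])

print(f"AUC:      mean={test_auc.mean():.3f}, std={test_auc.std():.3f}, "
      f"min={test_auc.min():.3f}, max={test_auc.max():.3f}")
print(f"Accuracy: mean={test_acc.mean():.3f}, min={test_acc.min():.3f}")
# Patients on which the model ranks REM epochs worse than chance.
print("Folds with AUC < 0.5:", int((test_auc < 0.5).sum()), "of", len(test_auc))
```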