import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yasa
import mne
from sklearn.model_selection import train_test_split, cross_validate, GroupKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
from sleepstagingidal.data import *
from sleepstagingidal.dataa import *
from sleepstagingidal.feature_extraction import *
from sleepstagingidal.cross_validation import *
[REM/NO REM] Pre-extracted basic features
In this quick experiment we are going to use the pre-extracted features to perform a binary classification between REM and non-REM sleep stages.
Load and filter the data
The first thing we have to do is load the features we previously extracted:
# Load the previously extracted features; the first CSV column is the index.
df = pd.read_csv(path_data, index_col=0)
df.head()
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | ... | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | Label | Patient | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.495066 | 0.112756 | 0.112814 | 0.064382 | 0.165632 | 0.049350 | 0.362000 | 0.116360 | 0.171950 | 0.159524 | 0.138463 | 0.051704 | 0.543319 | 0.110169 | 0.118650 | 0.060939 | 0.117710 | 0.049213 | 0.476496 | 0.148164 | 0.147317 | 0.106843 | 0.090631 | 0.030549 | 0.531814 | 0.170641 | 0.118979 | 0.070764 | 0.080622 | 0.027180 | 0.475519 | 0.188196 | 0.133852 | 0.095128 | 0.080400 | 0.026905 | 0.527519 | 0.122290 | 0.133960 | 0.065811 | ... | 0.089403 | 0.032273 | 0.200669 | 0.291248 | 0.213199 | 0.142235 | 0.147633 | 0.005017 | 0.200751 | 0.292008 | 0.212910 | 0.141983 | 0.147257 | 0.005091 | 0.197748 | 0.173513 | 0.231294 | 0.181297 | 0.190539 | 0.025610 | 0.178459 | 0.252283 | 0.216593 | 0.158911 | 0.186102 | 0.007653 | 0.721550 | 0.071222 | 0.073576 | 0.043658 | 0.064126 | 0.025867 | 0.687305 | 0.084597 | 0.083275 | 0.047782 | 0.070681 | 0.026361 | Sleep stage W | PSG29.edf |
1 | 0.465074 | 0.117853 | 0.177654 | 0.101286 | 0.104765 | 0.033368 | 0.347451 | 0.205411 | 0.208832 | 0.093656 | 0.111676 | 0.032974 | 0.630552 | 0.109801 | 0.111556 | 0.065251 | 0.059425 | 0.023416 | 0.480078 | 0.183739 | 0.148267 | 0.069932 | 0.085786 | 0.032198 | 0.616104 | 0.143776 | 0.113066 | 0.057664 | 0.053632 | 0.015759 | 0.545867 | 0.166890 | 0.141956 | 0.065644 | 0.060475 | 0.019168 | 0.531884 | 0.154121 | 0.136999 | 0.072946 | ... | 0.092982 | 0.030478 | 0.206725 | 0.284727 | 0.203176 | 0.151529 | 0.147839 | 0.006005 | 0.205184 | 0.285613 | 0.203717 | 0.151011 | 0.148545 | 0.005929 | 0.357687 | 0.172565 | 0.202435 | 0.112405 | 0.134012 | 0.020895 | 0.186969 | 0.243792 | 0.201004 | 0.168602 | 0.191435 | 0.008197 | 0.756929 | 0.068393 | 0.065536 | 0.032712 | 0.050156 | 0.026275 | 0.811640 | 0.054457 | 0.050954 | 0.026253 | 0.037776 | 0.018920 | Sleep stage W | PSG29.edf |
2 | 0.493321 | 0.083727 | 0.160615 | 0.093394 | 0.127338 | 0.041605 | 0.509627 | 0.099740 | 0.174744 | 0.106081 | 0.083801 | 0.026007 | 0.597762 | 0.134673 | 0.127823 | 0.057744 | 0.060519 | 0.021479 | 0.519414 | 0.140630 | 0.167105 | 0.080623 | 0.068262 | 0.023966 | 0.510118 | 0.185070 | 0.156140 | 0.068210 | 0.062267 | 0.018194 | 0.534198 | 0.161565 | 0.153183 | 0.065650 | 0.065872 | 0.019532 | 0.586799 | 0.139837 | 0.121766 | 0.056820 | ... | 0.068092 | 0.022911 | 0.209975 | 0.288701 | 0.212362 | 0.143128 | 0.140391 | 0.005442 | 0.211023 | 0.287946 | 0.212669 | 0.142524 | 0.140397 | 0.005442 | 0.236724 | 0.155031 | 0.223585 | 0.181925 | 0.178715 | 0.024020 | 0.189490 | 0.254528 | 0.207994 | 0.163306 | 0.176389 | 0.008293 | 0.720894 | 0.077628 | 0.080070 | 0.042218 | 0.057003 | 0.022188 | 0.761949 | 0.065250 | 0.071764 | 0.036249 | 0.047086 | 0.017702 | Sleep stage W | PSG29.edf |
3 | 0.496456 | 0.078696 | 0.145985 | 0.073315 | 0.168728 | 0.036820 | 0.415612 | 0.150171 | 0.162990 | 0.084801 | 0.145732 | 0.040695 | 0.610062 | 0.156481 | 0.096959 | 0.047532 | 0.071230 | 0.017735 | 0.481011 | 0.199789 | 0.142475 | 0.060787 | 0.085889 | 0.030048 | 0.532062 | 0.180797 | 0.130148 | 0.055573 | 0.080988 | 0.020431 | 0.493871 | 0.184382 | 0.134594 | 0.071134 | 0.086786 | 0.029233 | 0.543152 | 0.187464 | 0.110963 | 0.051094 | ... | 0.080532 | 0.022249 | 0.183674 | 0.338252 | 0.201751 | 0.136399 | 0.135297 | 0.004626 | 0.188792 | 0.336624 | 0.200293 | 0.135144 | 0.134579 | 0.004568 | 0.236594 | 0.153612 | 0.215825 | 0.169071 | 0.198682 | 0.026217 | 0.159044 | 0.288443 | 0.207075 | 0.160674 | 0.177189 | 0.007575 | 0.583039 | 0.117859 | 0.100131 | 0.045193 | 0.102050 | 0.051728 | 0.657698 | 0.107005 | 0.083062 | 0.039670 | 0.079945 | 0.032621 | Sleep stage W | PSG29.edf |
4 | 0.499096 | 0.090408 | 0.128249 | 0.119394 | 0.126275 | 0.036577 | 0.355009 | 0.114010 | 0.227973 | 0.142860 | 0.120449 | 0.039700 | 0.665468 | 0.097059 | 0.102477 | 0.055994 | 0.057387 | 0.021614 | 0.567872 | 0.108560 | 0.147445 | 0.074687 | 0.072923 | 0.028513 | 0.649271 | 0.120768 | 0.096039 | 0.061096 | 0.056664 | 0.016162 | 0.606666 | 0.124536 | 0.117359 | 0.064167 | 0.063759 | 0.023512 | 0.545407 | 0.135738 | 0.135981 | 0.053538 | ... | 0.102558 | 0.032455 | 0.191534 | 0.316908 | 0.221889 | 0.128541 | 0.135913 | 0.005214 | 0.188943 | 0.317868 | 0.222645 | 0.128482 | 0.136874 | 0.005188 | 0.316837 | 0.164763 | 0.211233 | 0.130939 | 0.150966 | 0.025262 | 0.172198 | 0.274478 | 0.218936 | 0.148539 | 0.178593 | 0.007258 | 0.718889 | 0.060736 | 0.083312 | 0.042121 | 0.064617 | 0.030325 | 0.728514 | 0.062010 | 0.082649 | 0.035992 | 0.059911 | 0.030924 | Sleep stage W | PSG29.edf |
5 rows × 86 columns
In addition, we are going to filter our data to keep only the complete patients. We can do so using the `info.csv` file we previously created:
# Load the per-recording info table (File, DifferentStages, Complete).
df_info = pd.read_csv(path_info)
df_info.head()
File | DifferentStages | Complete | |
---|---|---|---|
0 | PSG29.edf | 3 | False |
1 | PSG12.edf | 5 | True |
2 | PSG17.edf | 5 | True |
3 | PSG10.edf | 4 | False |
4 | PSG23.edf | 3 | False |
If we keep only the complete patients from this dataframe and join it with the features dataframe, we can filter out the incomplete patients and train only with the complete ones:
# Keep only the recordings flagged as complete (``Complete`` is a boolean column).
df_info_complete = df_info[df_info["Complete"]]
len(df_info_complete)
30
We see that we are left with 30 out of 36 patients. Next we’ll join both dataframes:
# Keep only the rows belonging to complete patients. An inner join is used
# (instead of a right join) so that a complete patient with no extracted
# features cannot introduce rows with NaN features/labels.
df_complete = df.merge(
    right=df_info_complete,
    how="inner",
    left_on="Patient",
    right_on="File",
)
# After the merge "Patient" duplicates "File", and the info columns are not features.
df_complete = df_complete.drop(["Patient", "DifferentStages", "Complete"], axis=1)
df.shape, df_complete.shape
((27680, 86), (24121, 86))
# Peek at the first rows of the filtered dataframe.
df_complete.head(n=5)
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | ... | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | Label | File | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.828777 | 0.111004 | 0.036347 | 0.013696 | 0.008029 | 0.002147 | 0.866095 | 0.077208 | 0.029721 | 0.013948 | 0.010277 | 0.002751 | 0.893198 | 0.075901 | 0.015894 | 0.007203 | 0.006068 | 0.001736 | 0.897216 | 0.066448 | 0.016855 | 0.006646 | 0.010521 | 0.002313 | 0.839334 | 0.129764 | 0.020567 | 0.005832 | 0.003590 | 0.000913 | 0.874244 | 0.099975 | 0.015362 | 0.006066 | 0.003442 | 0.000911 | 0.871306 | 0.087882 | 0.021908 | 0.007706 | ... | 0.007495 | 0.002202 | 0.559114 | 0.187103 | 0.049373 | 0.076654 | 0.121630 | 0.006125 | 0.569851 | 0.184307 | 0.048731 | 0.073804 | 0.117428 | 0.005878 | 0.599962 | 0.072144 | 0.036558 | 0.069310 | 0.177553 | 0.044474 | 0.314669 | 0.139663 | 0.104543 | 0.142964 | 0.274585 | 0.023577 | 0.863405 | 0.085994 | 0.017926 | 0.008112 | 0.013523 | 0.011039 | 0.842687 | 0.082653 | 0.018246 | 0.008742 | 0.021826 | 0.025847 | Sleep stage W | PSG12.edf |
1 | 0.829209 | 0.104899 | 0.032051 | 0.022996 | 0.008892 | 0.001952 | 0.822251 | 0.105734 | 0.035186 | 0.024547 | 0.009898 | 0.002384 | 0.895753 | 0.062918 | 0.021768 | 0.011074 | 0.006787 | 0.001699 | 0.894197 | 0.061178 | 0.017440 | 0.012895 | 0.012034 | 0.002256 | 0.857363 | 0.103283 | 0.023190 | 0.010939 | 0.004160 | 0.001065 | 0.883609 | 0.079566 | 0.021131 | 0.010778 | 0.003813 | 0.001103 | 0.885997 | 0.071112 | 0.021702 | 0.012148 | ... | 0.007887 | 0.002063 | 0.545761 | 0.190174 | 0.052838 | 0.079403 | 0.126536 | 0.005288 | 0.550894 | 0.191843 | 0.051842 | 0.077397 | 0.122869 | 0.005155 | 0.589353 | 0.073115 | 0.034576 | 0.073617 | 0.185505 | 0.043835 | 0.287737 | 0.143439 | 0.113177 | 0.149858 | 0.284439 | 0.021351 | 0.897045 | 0.053111 | 0.016674 | 0.010652 | 0.013242 | 0.009276 | 0.873665 | 0.057653 | 0.016445 | 0.013007 | 0.019579 | 0.019652 | Sleep stage W | PSG12.edf |
2 | 0.842406 | 0.100744 | 0.029417 | 0.019372 | 0.006586 | 0.001473 | 0.870060 | 0.081720 | 0.022897 | 0.017082 | 0.006613 | 0.001629 | 0.909758 | 0.051352 | 0.021887 | 0.010397 | 0.005378 | 0.001229 | 0.921258 | 0.047191 | 0.015626 | 0.007201 | 0.007109 | 0.001616 | 0.895852 | 0.069141 | 0.022387 | 0.008700 | 0.003240 | 0.000680 | 0.905155 | 0.062947 | 0.019297 | 0.008643 | 0.003221 | 0.000736 | 0.904595 | 0.058342 | 0.020455 | 0.009654 | ... | 0.006376 | 0.001663 | 0.536013 | 0.199818 | 0.045428 | 0.075261 | 0.137987 | 0.005494 | 0.550753 | 0.195377 | 0.043902 | 0.072385 | 0.132292 | 0.005291 | 0.638595 | 0.061563 | 0.034590 | 0.063653 | 0.165439 | 0.036159 | 0.294978 | 0.148989 | 0.095591 | 0.137203 | 0.301252 | 0.021988 | 0.889069 | 0.060842 | 0.020352 | 0.010035 | 0.011098 | 0.008603 | 0.869873 | 0.061609 | 0.020109 | 0.010276 | 0.017389 | 0.020744 | Sleep stage W | PSG12.edf |
3 | 0.826268 | 0.114365 | 0.034281 | 0.016519 | 0.006972 | 0.001596 | 0.834174 | 0.086591 | 0.045124 | 0.020168 | 0.011114 | 0.002829 | 0.886998 | 0.068077 | 0.027164 | 0.009995 | 0.006277 | 0.001489 | 0.899655 | 0.052545 | 0.023429 | 0.009772 | 0.011983 | 0.002616 | 0.852234 | 0.110114 | 0.023666 | 0.009198 | 0.003865 | 0.000923 | 0.885270 | 0.079228 | 0.022851 | 0.008011 | 0.003732 | 0.000907 | 0.893624 | 0.060173 | 0.029191 | 0.008671 | ... | 0.007102 | 0.002081 | 0.558408 | 0.174786 | 0.052194 | 0.077200 | 0.131827 | 0.005585 | 0.563630 | 0.175305 | 0.051710 | 0.075393 | 0.128516 | 0.005446 | 0.645756 | 0.058029 | 0.035434 | 0.060650 | 0.158191 | 0.041940 | 0.283484 | 0.138802 | 0.114658 | 0.142425 | 0.298707 | 0.021923 | 0.887805 | 0.058118 | 0.023766 | 0.008757 | 0.012120 | 0.009433 | 0.873901 | 0.055302 | 0.022631 | 0.009176 | 0.018615 | 0.020375 | Sleep stage W | PSG12.edf |
4 | 0.892040 | 0.072569 | 0.016537 | 0.013553 | 0.004228 | 0.001073 | 0.887429 | 0.069672 | 0.017843 | 0.017650 | 0.005630 | 0.001776 | 0.949596 | 0.032758 | 0.009363 | 0.005098 | 0.002614 | 0.000571 | 0.924791 | 0.048967 | 0.010825 | 0.007579 | 0.006528 | 0.001310 | 0.928209 | 0.053080 | 0.011061 | 0.004976 | 0.002208 | 0.000466 | 0.927354 | 0.051554 | 0.012094 | 0.006003 | 0.002475 | 0.000520 | 0.937408 | 0.039077 | 0.012161 | 0.006905 | ... | 0.006053 | 0.001503 | 0.549693 | 0.203476 | 0.043815 | 0.072618 | 0.124463 | 0.005936 | 0.558714 | 0.200516 | 0.043050 | 0.071087 | 0.120879 | 0.005754 | 0.723112 | 0.048052 | 0.024039 | 0.054247 | 0.118286 | 0.032265 | 0.305310 | 0.156931 | 0.093174 | 0.134177 | 0.286451 | 0.023958 | 0.929078 | 0.041854 | 0.010671 | 0.006873 | 0.006212 | 0.005312 | 0.914073 | 0.043103 | 0.011320 | 0.007723 | 0.010389 | 0.013391 | Sleep stage W | PSG12.edf |
5 rows × 86 columns
We see that we have removed the 3559 rows (27680 − 24121) corresponding to incomplete patients.
Defining the metrics
AUC is a good metric for binary problems because it evaluates the classifier's scores across all decision thresholds instead of at a single one.
To measure the performance of our classifier, we're going to use the AUC and the accuracy, the latter just for completeness. Because we want to obtain the most realistic estimate possible, we are going to perform a Patient-Fold: in each fold we train with all the patients but one and then test on the held-out patient. This will give us a good estimate of how well our pipeline generalizes to unseen patients.
This can be achieved using `cross_validate` in conjunction with `GroupKFold` (with as many splits as there are patients) and setting the `groups` parameter to the `File` column of our previous dataframe:
# Scorers computed on every fold. Named ``scoring_metrics`` so it does not
# shadow the ``sklearn.metrics`` module imported at the top of the file.
scoring_metrics = {
    "Accuracy": "accuracy",
    "AUC": "roc_auc",
}

# Patient-Fold: one group per recording and as many splits as recordings, so
# every fold tests on exactly one held-out patient.
cvg = cross_validate(
    RandomForestClassifier(random_state=0),  # fixed seed -> reproducible results
    X=df_complete.drop(["Label", "File"], axis=1),  # features only
    y=df_complete["Label"] == "Sleep stage R",  # binary target: REM vs. non-REM
    scoring=scoring_metrics,
    return_train_score=True,
    cv=GroupKFold(n_splits=df_complete["File"].nunique()),
    groups=df_complete["File"],
    n_jobs=5,
)
CPU times: user 146 ms, sys: 104 ms, total: 250 ms
Wall time: 2min 37s
# Display the raw cross-validation results: per-fold fit/score times and the
# train/test Accuracy and AUC obtained for each held-out patient.
cvg
{'fit_time': array([27.03762555, 26.27106881, 26.99674702, 26.86529589, 26.24467683,
25.45121002, 23.67596412, 24.07731962, 24.42808008, 24.3764236 ,
25.62880945, 25.02343202, 24.51707554, 25.30792546, 25.24074316,
24.36477757, 24.58056784, 25.27363634, 24.7898047 , 24.37082696,
23.90679312, 24.51267362, 24.12866235, 24.50727344, 24.42071271,
24.26417112, 24.38420892, 24.73950148, 24.41905499, 24.89244962]),
'score_time': array([0.06685781, 0.06159592, 0.05747652, 0.05841327, 0.0619328 ,
0.05348492, 0.0628283 , 0.05884314, 0.06110263, 0.05982661,
0.05692792, 0.05831432, 0.05882573, 0.05572987, 0.06257844,
0.05658269, 0.05508018, 0.05524349, 0.0544374 , 0.05899739,
0.05471539, 0.05301929, 0.05601811, 0.05660391, 0.05370593,
0.05541277, 0.05205274, 0.05044127, 0.05035877, 0.04688859]),
'test_Accuracy': array([0.91080797, 0.94065934, 0.88548753, 0.904 , 0.86869871,
0.9144197 , 0.87573271, 0.90973036, 0.93184489, 0.81560284,
0.85 , 0.91746411, 0.87860577, 0.80853659, 0.81295844,
0.96924969, 0.75709001, 0.89440994, 0.98461538, 0.97293814,
0.87696335, 0.90944882, 0.94218134, 0.92368421, 0.91125828,
0.85066667, 0.95392954, 0.86573427, 0.83875969, 0.9492635 ]),
'train_Accuracy': array([1. , 0.99995692, 1. , 1. , 0.99995702,
1. , 1. , 1. , 1. , 0.99995704,
0.99995705, 1. , 1. , 1. , 0.99995709,
1. , 1. , 1. , 1. , 1. ,
0.99995719, 1. , 1. , 1. , 1. ,
1. , 1. , 1. , 1. , 0.99995746]),
'test_AUC': array([0.50650465, 0.41852717, 0.80062372, 0.8343038 , 0.82448892,
0.75410263, 0.85782295, 0.94624178, 0.73788735, 0.77484245,
0.91584851, 0.24220774, 0.69733775, 0.40263327, 0.31369251,
0.45761421, 0.67280278, 0.83897209, 0.31743707, 0.32397487,
0.48285964, 0.66438714, 0.35597914, 0.68175239, 0.36966178,
0.8807308 , 0.30128259, 0.76383847, 0.65165771, 0.51384872]),
'train_AUC': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])}
# Distribution of the per-patient test scores across all folds.
plt.boxplot([cvg["test_AUC"], cvg["test_Accuracy"]])
# Boxplot draws the boxes at x=1 and x=2; label exactly those positions rather
# than adding a blank tick at x=0, which would widen the axis with empty space.
plt.xticks([1, 2], ["AUC", "Accuracy"])
plt.title("Results of performing Patient-Fold with default RandomForest [REM/No REM]")
plt.show()