Commit af8e969b authored by Weng's avatar Weng
Browse files

Upload New File

parent 3fb053f7
%% Cell type:markdown id:758926cd tags:
# Random Survival Trees für BU-Tafeln
%% Cell type:code id:df8cdf81 tags:
``` python
#Benötigte Bibliotheken
import pandas as pd
import numpy as np
from random import randint
from sklearn.model_selection import train_test_split
from sksurv.tree import SurvivalTree
from sksurv.tree import SurvivalTreeTruncated
from sksurv.datasets.base import _get_x_y_survival
from sksurv.datasets.base import _get_x_y_survival_truncated
from sksurv.ensemble import RandomSurvivalForest
from sksurv.ensemble import RandomSurvivalForestTruncated
from export import print_tree
import eli5
from eli5.sklearn import PermutationImportance
```
%% Cell type:code id:78289b58 tags:
``` python
#Einlesen der Daten
path="KleinerBUBestand_Test.csv"
data = pd.read_csv(path)
data_train_ges,data_test_ges=train_test_split(data,test_size=0.2, random_state=20)
del data['Unnamed: 0']
```
%% Cell type:code id:3ed487a8 tags:
``` python
data_ges=data
```
%% Cell type:code id:869c86e3 tags:
``` python
x_train_2, y_train_2=_get_x_y_survival(data_train_ges, "Ereignis", "Austrittsalter", 1)
x_train_2.drop(['Eintrittsalter'],axis = 1, inplace = True)
x_test_2, y_test_2=_get_x_y_survival(data_test_ges, "Ereignis", "Austrittsalter", 1)
x_test_2.drop(['Eintrittsalter'],axis = 1, inplace = True)
#links-abgeschnitten und rechts-zensiert
x_train_3, y_train_3=_get_x_y_survival_truncated(data_train_ges, "Ereignis","Eintrittsalter", "Austrittsalter", 1)
x_test_3, y_test_3=_get_x_y_survival_truncated(data_test_ges, "Ereignis", "Eintrittsalter", "Austrittsalter", 1)
```
%% Cell type:code id:269f3b38 tags:
``` python
from lifelines import NelsonAalenFitter
import matplotlib.pyplot as plt
nf1 = NelsonAalenFitter(nelson_aalen_smoothing=False)
nf2 = NelsonAalenFitter(nelson_aalen_smoothing=False)
data_female=data_ges[data_ges['Geschlecht']==1]
data_male=data_ges[data_ges['Geschlecht']==2]
nf1.fit(data_male['Austrittsalter'], data_male['Ereignis'],entry=data_male['Eintrittsalter'])
nf2.fit(data_female['Austrittsalter'], data_female['Ereignis'],entry=data_female['Eintrittsalter'])
#in Hazardfunktion umrechnen
s1=nf1.cumulative_hazard_.to_numpy().reshape(43)
t1=np.diff(s1)
s2=nf2.cumulative_hazard_.to_numpy().reshape(43)
t2=np.diff(s2)
data_DAV = pd.read_csv("Daten/DAV_Tafeln.csv",sep=";",decimal=",")
p_f=data_DAV['DAV Female'].to_numpy(dtype='float')[0:36]
p_m=data_DAV['DAV Male'].to_numpy(dtype='float')[0:36]
plt.plot(t1,color="red",label="Männer Daten")
plt.plot(p_m,color="blue",label="Männer DAV")
plt.plot(t2,color="green",label="Frauen Daten")
plt.plot(p_f,color="orange",label="Frauen DAV")
plt.title("Vergleich Daten mit DAV-Tafel")
plt.show()
```
%%%% Output: display_data
![]()
%% Cell type:markdown id:b9997db0 tags:
## Survival Trees
%% Cell type:code id:e5896a4a tags:
``` python
estimator = SurvivalTreeTruncated(max_depth=5).fit(x_train_3,y_train_3)
estimator.score(x_test_3,y_test_3)
```
%%%% Output: execute_result
0.6693620440194961
%% Cell type:code id:9d625610 tags:
``` python
print_tree(estimator)
```
%%%% Output: stream
The binary tree structure has 39 nodes and has the following tree structure:
node=0 is a split node: go to node 1 if X[:, 2] <= 1.5 else to node 12.
node=1 is a split node: go to node 2 if X[:, 0] <= 36.0 else to node 3.
node=2 is a leaf node.
node=3 is a split node: go to node 4 if X[:, 0] <= 19973.5 else to node 11.
node=4 is a split node: go to node 5 if X[:, 6] <= 6.5 else to node 8.
node=5 is a split node: go to node 6 if X[:, 0] <= 13581.0 else to node 7.
node=6 is a leaf node.
node=7 is a leaf node.
node=8 is a split node: go to node 9 if X[:, 0] <= 11261.0 else to node 10.
node=9 is a leaf node.
node=10 is a leaf node.
node=11 is a leaf node.
node=12 is a split node: go to node 13 if X[:, 5] <= 1.5 else to node 26.
node=13 is a split node: go to node 14 if X[:, 3] <= 1.5 else to node 21.
node=14 is a split node: go to node 15 if X[:, 0] <= 19525.5 else to node 18.
node=15 is a split node: go to node 16 if X[:, 0] <= 16218.5 else to node 17.
node=16 is a leaf node.
node=17 is a leaf node.
node=18 is a split node: go to node 19 if X[:, 0] <= 19537.0 else to node 20.
node=19 is a leaf node.
node=20 is a leaf node.
node=21 is a split node: go to node 22 if X[:, 0] <= 6871.5 else to node 23.
node=22 is a leaf node.
node=23 is a split node: go to node 24 if X[:, 2] <= 3.5 else to node 25.
node=24 is a leaf node.
node=25 is a leaf node.
node=26 is a split node: go to node 27 if X[:, 5] <= 2.5 else to node 34.
node=27 is a split node: go to node 28 if X[:, 1] <= 0.5 else to node 31.
node=28 is a split node: go to node 29 if X[:, 2] <= 2.5 else to node 30.
node=29 is a leaf node.
node=30 is a leaf node.
node=31 is a split node: go to node 32 if X[:, 1] <= 2.5 else to node 33.
node=32 is a leaf node.
node=33 is a leaf node.
node=34 is a split node: go to node 35 if X[:, 4] <= 3.5 else to node 38.
node=35 is a split node: go to node 36 if X[:, 1] <= 1.5 else to node 37.
node=36 is a leaf node.
node=37 is a leaf node.
node=38 is a leaf node.
%% Cell type:markdown id:8f961248 tags:
## Random Survival Forest
%% Cell type:code id:fed384e2 tags:
``` python
random_state=20
rsf2=RandomSurvivalForestTruncated(n_estimators=10,max_depth=5,max_features="sqrt",n_jobs=2,random_state=random_state)
rsf2.fit(x_train_3, y_train_3)
rsf2.score(x_test_3,y_test_3)
#rsf2.score(x_train_3,y_train_3)
```
%%%% Output: execute_result
0.6730808979923856
%% Cell type:code id:8a0b4a40 tags:
``` python
#import joblib
#Nützliche Befehler, wenn man das Modell abspeichern bzw. laden möchte:
#joblib.dump(rsf2, "./random_forest_truncated.joblib")
#loaded_rf_truncated = joblib.load("./random_forest_truncated.joblib")
```
%% Cell type:code id:004eade0 tags:
``` python
perm=PermutationImportance(rsf2,n_iter=20,random_state=random_state)
perm.fit(x_test_3, y_test_3)
eli5.show_weights(perm,feature_names=x_test_3.columns.to_list())
```
%%%% Output: execute_result
<IPython.core.display.HTML object>
%% Cell type:markdown id:dde0d3ce tags:
## Prognose
%% Cell type:code id:cb538b34 tags:
``` python
#Teilmenge definieren
subset=data_test_ges
data_subset_x=x_test_3
```
%% Cell type:code id:216929d3 tags:
``` python
#Modellprognose
table= rsf2.predict_cumulative_hazard_function(data_subset_x,return_array=True)
curve=table.mean(axis=0)
curve=np.diff(curve)[0:40]
#Sterblichkeiten aus Nelson-Aalen-Estimator für Teilbestand
nf = NelsonAalenFitter(nelson_aalen_smoothing=False)
nf.fit(subset['Austrittsalter'], subset['Ereignis'],entry=subset['Eintrittsalter'])
s2=nf.cumulative_hazard_.to_numpy().reshape(43)
t2=np.diff(s2)
```
%% Cell type:code id:fa3bebaf tags:
``` python
plt.plot(np.arange(20,55), curve[0:35],'b',label='Modell')
plt.plot(np.arange(20,55), t2[0:35],'r',label='Daten')
plt.legend()
plt.title("Inzidenzwahrscheinlichkeiten im Vergleich")
plt.show()
```
%%%% Output: display_data
![]()
%% Cell type:code id:17ee73cb tags:
``` python
```
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment