Upload server.py

#1
by Nos7 - opened
Files changed (1) hide show
  1. server.py +42 -55
server.py CHANGED
@@ -3,6 +3,8 @@ import pandas as pd
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import joblib
 
 
6
 
7
  import os
8
  import shutil
@@ -20,6 +22,7 @@ else:
20
 
21
  data=pd.read_csv('data/heart.xls')
22
 
 
23
  data.info() #checking the info
24
 
25
  data_corr=data.corr()
@@ -27,7 +30,41 @@ data_corr=data.corr()
27
  plt.figure(figsize=(20,20))
28
  sns.heatmap(data=data_corr,annot=True)
29
  #Heatmap for data
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  feature_value=np.array(data_corr['output'])
32
  for i in range(len(feature_value)):
33
  if feature_value[i]<0:
@@ -45,9 +82,6 @@ feature_selected #selected features which are very much correalated
45
 
46
  clean_data=data[feature_selected]
47
 
48
- from sklearn.tree import DecisionTreeClassifier #using sklearn decisiontreeclassifier
49
- from sklearn.model_selection import train_test_split
50
-
51
  #making input and output dataset
52
  X=clean_data.iloc[:,1:]
53
  Y=clean_data['output']
@@ -63,7 +97,7 @@ x_train=sc.fit_transform(x_train)
63
  x_test=sc.transform(x_test)
64
 
65
  #training our model
66
- dt=DecisionTreeClassifier(criterion='entropy',max_depth=6)
67
  dt.fit(x_train,y_train)
68
  #dt.compile(x_trqin)
69
 
@@ -79,9 +113,10 @@ print("\nThe accuracy of decisiontreelassifier on Heart disease prediction datas
79
 
80
  joblib.dump(dt, 'heart_disease_dt_model.pkl')
81
 
82
- from concrete.ml.sklearn.tree import DecisionTreeClassifier
 
83
 
84
- fhe_compatible = DecisionTreeClassifier.from_sklearn_model(dt, x_train, n_bits = 10)
85
  fhe_compatible.compile(x_train)
86
 
87
 
@@ -99,51 +134,3 @@ dev.save()
99
  # Setup the server
100
  server = FHEModelServer(path_dir=fhe_directory)
101
  server.load()
102
-
103
-
104
-
105
-
106
-
107
-
108
-
109
- ####### client
110
-
111
- from concrete.ml.deployment import FHEModelDev, FHEModelClient, FHEModelServer
112
-
113
- # Setup the client
114
- client = FHEModelClient(path_dir=fhe_directory, key_dir="/tmp/keys_client")
115
- serialized_evaluation_keys = client.get_serialized_evaluation_keys()
116
-
117
-
118
- # Load the dataset and select the relevant features
119
- data = pd.read_csv('data/heart.xls')
120
-
121
- # Perform the correlation analysis
122
- data_corr = data.corr()
123
-
124
- # Select features based on correlation with 'output'
125
- feature_value = np.array(data_corr['output'])
126
- for i in range(len(feature_value)):
127
- if feature_value[i] < 0:
128
- feature_value[i] = -feature_value[i]
129
-
130
- features_corr = pd.DataFrame(feature_value, index=data_corr['output'].index, columns=['correlation'])
131
- feature_sorted = features_corr.sort_values(by=['correlation'], ascending=False)
132
- feature_selected = feature_sorted.index
133
-
134
- # Clean the data by selecting the most correlated features
135
- clean_data = data[feature_selected]
136
-
137
- # Extract the first row of feature data for prediction (excluding 'output' column)
138
- sample_data = clean_data.iloc[0, 1:].values.reshape(1, -1) # Reshape to 2D array for model input
139
-
140
- encrypted_data = client.quantize_encrypt_serialize(sample_data)
141
-
142
-
143
-
144
- ##### end client
145
-
146
- encrypted_result = server.run(encrypted_data, serialized_evaluation_keys)
147
-
148
- result = client.deserialize_decrypt_dequantize(encrypted_result)
149
- print(result)
 
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import joblib
6
+ from sklearn.tree import DecisionTreeClassifier, XGBClassifier #using sklearn decisiontreeclassifier
7
+ from sklearn.model_selection import train_test_split
8
 
9
  import os
10
  import shutil
 
22
 
23
  data=pd.read_csv('data/heart.xls')
24
 
25
+
26
  data.info() #checking the info
27
 
28
  data_corr=data.corr()
 
30
  plt.figure(figsize=(20,20))
31
  sns.heatmap(data=data_corr,annot=True)
32
  #Heatmap for data
33
+ """
34
+ # Get the Data
35
+ X_train, y_train, X_val, y_val = train_test_split()
36
+ classifier = XGBClassifier()
37
+ # Training the Model
38
+ classifier = classifier.fit(X_train, y_train)
39
+ # Trained Model Evaluation on Validation Dataset
40
+ confidence = classifier.score(X_val, y_val)
41
+ # Validation Data Prediction
42
+ y_pred = classifier.predict(X_val)
43
+ # Model Validation Accuracy
44
+ accuracy = accuracy_score(y_val, y_pred)
45
+ # Model Confusion Matrix
46
+ conf_mat = confusion_matrix(y_val, y_pred)
47
+ # Model Classification Report
48
+ clf_report = classification_report(y_val, y_pred)
49
+ # Model Cross Validation Score
50
+ score = cross_val_score(classifier, X_val, y_val, cv=3)
51
+
52
+ try:
53
+ # Load Trained Model
54
+ clf = load(str(self.model_save_path + saved_model_name + ".joblib"))
55
+ except Exception as e:
56
+ print("Model not found...")
57
+
58
+ if test_data is not None:
59
+ result = clf.predict(test_data)
60
+ print(result)
61
+ else:
62
+ result = clf.predict(self.test_features)
63
+ accuracy = accuracy_score(self.test_labels, result)
64
+ clf_report = classification_report(self.test_labels, result)
65
+ print(accuracy, clf_report)
66
+ """
67
+ ####################
68
  feature_value=np.array(data_corr['output'])
69
  for i in range(len(feature_value)):
70
  if feature_value[i]<0:
 
82
 
83
  clean_data=data[feature_selected]
84
 
 
 
 
85
  #making input and output dataset
86
  X=clean_data.iloc[:,1:]
87
  Y=clean_data['output']
 
97
  x_test=sc.transform(x_test)
98
 
99
  #training our model
100
+ dt=XGBClassifier(criterion='entropy',max_depth=6)
101
  dt.fit(x_train,y_train)
102
  #dt.compile(x_trqin)
103
 
 
113
 
114
  joblib.dump(dt, 'heart_disease_dt_model.pkl')
115
 
116
+ from concrete.ml.sklearn import DecisionTreeClassifier as ConcreteDecisionTreeClassifier
117
+ from concrete.ml.sklearn import XGBClassifier as ConcreteXGBClassifier
118
 
119
+ fhe_compatible = ConcreteXGBClassifier.from_sklearn_model(dt, x_train, n_bits = 10) #de FHE
120
  fhe_compatible.compile(x_train)
121
 
122
 
 
134
  # Setup the server
135
  server = FHEModelServer(path_dir=fhe_directory)
136
  server.load()