xszheng2020 commited on
Commit
44a6f1e
·
verified ·
1 Parent(s): 86e566d

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +181 -0
  2. data.csv +129 -0
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import sklearn
2
+ import gradio as gr
3
+ # import joblib
4
+ import pandas as pd
5
+ import numpy as np
6
+ import lightgbm as lgb
7
+ from sklearn.model_selection import train_test_split
8
+ from PIL import Image
9
+ # import datasets
10
+
11
+ # pipe = joblib.load("./model.pkl")
12
+
13
+ title = "RegMix"
14
+ description = "TBD."
15
+
16
+ df = pd.read_csv('data.csv')
17
+ headers = df.columns.tolist()
18
+
19
+ inputs = [gr.Dataframe(headers=headers, row_count = (8, "dynamic"), datatype='number', col_count=(4,"fixed"), label="Dataset", interactive=1)]
20
+ outputs = [gr.ScatterPlot(), gr.Image(), gr.Dataframe(row_count = (2, "dynamic"), col_count=(2, "fixed"), datatype='number', label="Results", headers=["True Loss", "Pred Loss"])]
21
+
22
+ def infer(inputs):
23
+ df = pd.DataFrame(inputs, columns=headers)
24
+
25
+ X_columns = df.columns[0:-1]
26
+ y_column = df.columns[-1]
27
+
28
+ df_train, df_val = train_test_split(df, test_size=0.125, random_state=42)
29
+
30
+ hyper_params = {
31
+ 'task': 'train',
32
+ 'boosting_type': 'gbdt',
33
+ 'objective': 'regression',
34
+ 'metric': ['l1','l2'],
35
+ "num_iterations": 1000,
36
+ 'seed': 42,
37
+ 'learning_rate': 1e-2,
38
+ }
39
+
40
+ target = df_train[y_column]
41
+ eval_target = df_val[y_column]
42
+
43
+ np.random.seed(42)
44
+
45
+ gbm = lgb.LGBMRegressor(**hyper_params)
46
+
47
+ reg = gbm.fit(df_train[X_columns].values, target,
48
+ eval_set=[(df_val[X_columns].values, eval_target)],
49
+ eval_metric='l2',
50
+ callbacks=[
51
+ lgb.early_stopping(stopping_rounds=3),
52
+ ]
53
+ )
54
+
55
+ predictions = reg.predict(df_val[X_columns].values)
56
+ df_val['Prediction'] = predictions
57
+
58
+ ####
59
+ import matplotlib.pyplot as plt
60
+ plt.rcParams["font.family"] = "Times New Roman" # !!!!
61
+ plt.rcParams.update({'font.size': 24})
62
+ plt.rcParams.update({'axes.labelpad': 20})
63
+
64
+ from matplotlib import cm
65
+ from matplotlib.ticker import LinearLocator
66
+
67
+ fig, ax = plt.subplots(figsize=(12, 12), layout='compressed', subplot_kw={"projection": "3d"})
68
+
69
+ stride = 0.025
70
+ X = np.arange(0, 1+stride, stride)
71
+ Y = np.arange(0, 1+stride, stride)
72
+
73
+ X, Y = np.meshgrid(X, Y)
74
+ Z = []
75
+ for (x,y) in zip(X.reshape(-1), Y.reshape(-1)):
76
+ if (x+y)>1:
77
+ Z.append(np.inf)
78
+ else:
79
+ Z.append(
80
+ reg.predict(np.asarray([x, y, 1-x-y]).reshape(1, -1)
81
+ )[0])
82
+ Z = np.asarray(Z).reshape(len(np.arange(0, 1+stride, stride)), len(np.arange(0, 1+stride, stride)))
83
+
84
+ # Plot the surface.
85
+ surf = ax.plot_surface(X, Y, Z,
86
+ edgecolor='white',
87
+ lw=0.5, rstride=2, cstride=2,
88
+ alpha=0.85,
89
+ cmap='coolwarm',
90
+ vmin=min(Z[Z!=np.inf]),
91
+ vmax=max(Z[Z!=np.inf]),
92
+ # linewidth=8,
93
+ antialiased=False, )
94
+
95
+ ax.zaxis.set_major_locator(LinearLocator(10))
96
+ ax.zaxis.set_major_formatter('{x:.02f}')
97
+
98
+ ax.view_init(elev=25, azim=45, roll=0) #####
99
+
100
+ ax.contourf(X, Y, Z, zdir='z',
101
+ offset=np.min(Z)-0.35,
102
+ cmap=cm.coolwarm)
103
+
104
+ from matplotlib.patches import Circle
105
+ from mpl_toolkits.mplot3d import art3d
106
+
107
+ def add_point(ax, x, y, z, fc = None, ec = None, radius = 0.005):
108
+ xy_len, z_len = ax.get_figure().get_size_inches()
109
+ axis_length = [x[1] - x[0] for x in [ax.get_xbound(), ax.get_ybound(), ax.get_zbound()]]
110
+ axis_rotation = {'z': ((x, y, z), axis_length[1]/axis_length[0]),
111
+ 'y': ((x, z, y), axis_length[2]/axis_length[0]*xy_len/z_len),
112
+ 'x': ((y, z, x), axis_length[2]/axis_length[1]*xy_len/z_len)}
113
+ for a, ((x0, y0, z0), ratio) in axis_rotation.items():
114
+ p = Circle((x0, y0), radius, lw=1.5,
115
+ # width = radius, height = radius*ratio,
116
+ fc=fc,
117
+ ec=ec)
118
+ ax.add_patch(p)
119
+ art3d.pathpatch_2d_to_3d(p, z=z0, zdir=a)
120
+
121
+
122
+ add_point(ax, X.reshape(-1)[np.argmin(Z)], Y.reshape(-1)[np.argmin(Z)], np.min(Z),
123
+ fc='Red',
124
+ ec='Red', radius=0.015)
125
+
126
+ add_point(ax, X.reshape(-1)[np.argmin(Z)], Y.reshape(-1)[np.argmin(Z)], np.min(Z)-0.35,
127
+ fc='Red',
128
+ ec='Red', radius=0.015)
129
+
130
+
131
+ ax.set_xlabel('Github (%)', fontdict={
132
+ 'size':24
133
+ })
134
+ ax.set_ylabel('Hacker News (%)', fontdict={
135
+ 'size':24
136
+ })
137
+
138
+ ax.set_xticks(np.arange(0, 1, 0.2), [str(np.round(num, 1)) for num in np.arange(0, 100, 20)], )
139
+ ax.set_yticks(np.arange(0, 1, 0.2), [str(np.round(num, 1)) for num in np.arange(0, 100, 20)], )
140
+
141
+ ax.set_zticks(np.arange(np.min(Z), np.max(Z[Z!=np.inf]), 0.2), [str(np.round(num, 1)) for num in np.arange(np.min(Z), np.max(Z[Z!=np.inf]), 0.2)], )
142
+
143
+ ax.zaxis.labelpad=1
144
+
145
+ ax.set_zlim(np.min(Z)-0.35, max(Z[Z!=np.inf])+0.01)
146
+ ax.set_xlim(0, 1)
147
+ ax.set_ylim(0, 1)
148
+ ax.set_box_aspect(aspect=None, zoom=0.775)
149
+
150
+ ax.zaxis._axinfo['juggled'] = (1,2,2)
151
+
152
+ # Add a color bar which maps values to colors.
153
+ cbar = fig.colorbar(surf,
154
+ shrink=0.5,
155
+ aspect=25, pad=0.01
156
+ )
157
+ cbar.ax.set_ylabel('Prediction', fontdict={
158
+ 'size':32
159
+ },
160
+ # rotation=270,
161
+ # labelpad=-90
162
+ )
163
+
164
+
165
+ filename = "tmp.png"
166
+ plt.savefig(filename, bbox_inches='tight', pad_inches=0.1)
167
+ ####
168
+ return [gr.ScatterPlot(
169
+ value=df_val,
170
+ x="Prediction",
171
+ y="Target",
172
+ title="Scatter",
173
+ tooltip=["Prediction", "Target"],
174
+ x_lim=[min(min(predictions), min(df_val[y_column]))-0.25, max(max(predictions), max(df_val[y_column]))+0.25],
175
+ y_lim=[min(min(predictions), min(df_val[y_column]))-0.25, max(max(predictions), max(df_val[y_column]))+0.25]
176
+ ),
177
+ gr.Image(Image.open('tmp.png')),
178
+ df_val[['Target', 'Prediction']], ]
179
+
180
+ gr.Interface(infer, inputs = inputs, outputs = outputs, title = title,
181
+ description = description, examples=[df], cache_examples=False, allow_flagging='never').launch(debug=False)
data.csv ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Github,Hacker News,Philpapers,Target
2
+ 0.616,0.227,0.157,5.306806564331055
3
+ 0.025,0.131,0.844,6.012040138244629
4
+ 0.183,0.28,0.538,5.527338981628418
5
+ 0.032,0.548,0.421,5.821677207946777
6
+ 0.002,0.347,0.651,6.326714515686035
7
+ 0.083,0.827,0.09,5.659005641937256
8
+ 0.011,0.478,0.511,6.063345909118652
9
+ 0.699,0.002,0.3,5.410695552825928
10
+ 0.363,0.389,0.249,5.3828229904174805
11
+ 0.542,0.257,0.201,5.31920337677002
12
+ 0.262,0.311,0.427,5.465338230133057
13
+ 0.677,0.251,0.072,5.3108978271484375
14
+ 0.705,0.203,0.092,5.305531024932861
15
+ 0.68,0.208,0.112,5.305020809173584
16
+ 0.109,0.22,0.671,5.645120620727539
17
+ 0.307,0.074,0.619,5.510026931762695
18
+ 0.198,0.69,0.112,5.476624488830566
19
+ 0.197,0.401,0.402,5.487608432769775
20
+ 0.056,0.338,0.606,5.778319358825684
21
+ 0.34,0.256,0.405,5.406895637512207
22
+ 0.534,0.3,0.167,5.321551322937012
23
+ 0.022,0.0,0.977,6.284101486206055
24
+ 0.732,0.228,0.04,5.291826248168945
25
+ 0.18,0.1,0.72,5.596135139465332
26
+ 0.279,0.539,0.182,5.439793586730957
27
+ 0.228,0.442,0.33,5.461205959320068
28
+ 0.278,0.271,0.45,5.446168899536133
29
+ 0.184,0.511,0.305,5.51557731628418
30
+ 0.008,0.542,0.45,6.084725379943848
31
+ 0.659,0.121,0.22,5.333937644958496
32
+ 0.134,0.041,0.825,5.712595462799072
33
+ 0.046,0.002,0.952,6.052066326141357
34
+ 0.022,0.236,0.742,6.000083923339844
35
+ 0.573,0.044,0.382,5.3919196128845215
36
+ 0.57,0.181,0.249,5.336789608001709
37
+ 0.244,0.424,0.333,5.454492568969727
38
+ 0.464,0.031,0.505,5.44106388092041
39
+ 0.349,0.046,0.605,5.503855705261231
40
+ 0.435,0.019,0.545,5.472304344177246
41
+ 0.011,0.571,0.418,6.020427703857422
42
+ 0.083,0.794,0.123,5.637621879577637
43
+ 0.433,0.125,0.442,5.397417545318604
44
+ 0.032,0.457,0.512,5.846551895141602
45
+ 0.248,0.128,0.624,5.5152506828308105
46
+ 0.159,0.747,0.094,5.518942832946777
47
+ 0.03,0.322,0.648,5.870717525482178
48
+ 0.389,0.248,0.363,5.399688720703125
49
+ 0.487,0.234,0.279,5.344883918762207
50
+ 0.385,0.363,0.252,5.373892784118652
51
+ 0.793,0.029,0.178,5.366541862487793
52
+ 0.62,0.38,0.0,5.323075294494629
53
+ 0.024,0.635,0.34,5.89354133605957
54
+ 0.848,0.152,0.0,5.330050468444824
55
+ 0.082,0.257,0.661,5.695749759674072
56
+ 0.111,0.747,0.142,5.571730136871338
57
+ 0.997,0.001,0.002,5.432588577270508
58
+ 0.484,0.064,0.452,5.41372537612915
59
+ 0.257,0.023,0.72,5.593489646911621
60
+ 0.908,0.064,0.028,5.33869743347168
61
+ 0.407,0.575,0.018,5.356371879577637
62
+ 0.716,0.209,0.074,5.299798488616943
63
+ 0.499,0.467,0.034,5.316855430603027
64
+ 0.463,0.09,0.447,5.408260822296143
65
+ 0.347,0.164,0.49,5.455391883850098
66
+ 0.22,0.31,0.47,5.478835105895996
67
+ 0.085,0.899,0.017,5.63015079498291
68
+ 0.831,0.042,0.126,5.347104549407959
69
+ 0.083,0.845,0.072,5.637035846710205
70
+ 0.009,0.352,0.639,6.105093955993652
71
+ 0.373,0.177,0.45,5.426303386688232
72
+ 0.0,1.0,0.0,6.3570756912231445
73
+ 0.001,0.0,0.999,7.201324462890625
74
+ 0.577,0.032,0.391,5.384527683258057
75
+ 0.699,0.248,0.053,5.30719518661499
76
+ 0.131,0.379,0.491,5.566975593566895
77
+ 0.042,0.865,0.093,5.747323036193848
78
+ 0.009,0.773,0.218,6.052563190460205
79
+ 0.593,0.198,0.209,5.319425582885742
80
+ 0.335,0.063,0.602,5.493361473083496
81
+ 0.508,0.2,0.292,5.362349033355713
82
+ 0.001,0.073,0.926,6.6276702880859375
83
+ 0.472,0.164,0.364,5.364439487457275
84
+ 0.021,0.415,0.563,5.923987865447998
85
+ 0.995,0.0,0.005,5.450720310211182
86
+ 0.613,0.221,0.166,5.330704689025879
87
+ 0.238,0.668,0.093,5.449023246765137
88
+ 0.521,0.08,0.399,5.382278919219971
89
+ 0.102,0.138,0.76,5.70582389831543
90
+ 0.627,0.02,0.353,5.393834590911865
91
+ 0.027,0.955,0.018,5.853610515594482
92
+ 0.215,0.713,0.071,5.452286243438721
93
+ 0.265,0.092,0.643,5.527949333190918
94
+ 0.178,0.002,0.82,5.708561897277832
95
+ 0.028,0.029,0.943,6.10589599609375
96
+ 0.002,0.305,0.693,6.344947338104248
97
+ 0.608,0.358,0.034,5.328068733215332
98
+ 0.579,0.226,0.195,5.31940221786499
99
+ 0.171,0.04,0.789,5.646611213684082
100
+ 0.056,0.483,0.461,5.769045352935791
101
+ 0.175,0.358,0.467,5.5205817222595215
102
+ 0.223,0.713,0.065,5.465519428253174
103
+ 0.359,0.095,0.546,5.459996223449707
104
+ 0.051,0.672,0.276,5.736810684204102
105
+ 0.727,0.198,0.075,5.289270401000977
106
+ 0.019,0.203,0.778,6.054460525512695
107
+ 0.12,0.877,0.003,5.571919918060303
108
+ 0.771,0.026,0.203,5.370471477508545
109
+ 0.642,0.091,0.267,5.33230447769165
110
+ 0.209,0.089,0.702,5.567383766174316
111
+ 0.603,0.036,0.361,5.384149551391602
112
+ 0.185,0.31,0.504,5.524736404418945
113
+ 0.489,0.328,0.183,5.332746982574463
114
+ 0.014,0.245,0.742,6.066995620727539
115
+ 0.75,0.241,0.009,5.288382530212402
116
+ 0.527,0.352,0.121,5.32275915145874
117
+ 0.291,0.301,0.408,5.430193901062012
118
+ 0.046,0.755,0.199,5.744025707244873
119
+ 0.031,0.949,0.02,5.835524559020996
120
+ 0.252,0.015,0.733,5.597743988037109
121
+ 0.524,0.004,0.471,5.446127414703369
122
+ 0.0,0.619,0.381,6.416290760040283
123
+ 0.08,0.903,0.017,5.660269260406494
124
+ 0.0,0.87,0.13,6.306185245513916
125
+ 0.209,0.205,0.586,5.519876003265381
126
+ 0.057,0.533,0.41,5.761989116668701
127
+ 0.307,0.597,0.096,5.392508506774902
128
+ 0.008,0.98,0.011,6.026959896087647
129
+ 0.865,0.039,0.096,5.351530075073242