gaviego commited on
Commit
4417b5c
·
1 Parent(s): 127e87d

more layers

Browse files
Files changed (4) hide show
  1. app.py +17 -5
  2. mnist.pth +0 -0
  3. model.py +5 -3
  4. requirements.txt +0 -2
app.py CHANGED
@@ -16,9 +16,21 @@ def predict(img):
16
  topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
17
  return [str(k) for k in topk_indices[0].tolist()]
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- sp = gr.Sketchpad(shape=(28, 28))
21
-
22
- gr.Interface(fn=predict,
23
- inputs=sp,
24
- outputs=['label','label']).launch()
 
16
  topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
17
  return [str(k) for k in topk_indices[0].tolist()]
18
 
19
+ with gr.Blocks() as iface:
20
+ gr.Markdown("# MNIST + Gradio End to End")
21
+ gr.HTML("Shows end to end MNIST training with Gradio interface")
22
+ with gr.Row():
23
+ with gr.Column():
24
+ sp = gr.Sketchpad(shape=(28, 28))
25
+ with gr.Row():
26
+ with gr.Column():
27
+ pred_button = gr.Button("Predict")
28
+ with gr.Column():
29
+ clear = gr.Button("Clear")
30
+ with gr.Column():
31
+ label1 = gr.Label(label='1st Pred')
32
+ label2 = gr.Label(label='2nd Pred')
33
 
34
+ pred_button.click(predict, inputs=sp, outputs=[label1,label2])
35
+ clear.click(lambda: None, None, sp, queue=False)
36
+ iface.launch()
 
 
mnist.pth CHANGED
Binary files a/mnist.pth and b/mnist.pth differ
 
model.py CHANGED
@@ -5,11 +5,13 @@ class Net(nn.Module):
5
  def __init__(self):
6
  super(Net, self).__init__()
7
  self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
8
- self.fc2 = nn.Linear(128, 64)
9
- self.fc3 = nn.Linear(64, 10) # There are 10 classes (0 through 9)
 
10
 
11
  def forward(self, x):
12
  x = x.view(x.shape[0], -1) # Flatten the input
13
  x = torch.relu(self.fc1(x))
14
  x = torch.relu(self.fc2(x))
15
- return self.fc3(x)
 
 
5
  def __init__(self):
6
  super(Net, self).__init__()
7
  self.fc1 = nn.Linear(28*28, 128) # MNIST images are 28x28
8
+ self.fc2 = nn.Linear(128, 128)
9
+ self.fc3 = nn.Linear(128, 64)
10
+ self.fc4 = nn.Linear(64, 10) # There are 10 classes (0 through 9)
11
 
12
  def forward(self, x):
13
  x = x.view(x.shape[0], -1) # Flatten the input
14
  x = torch.relu(self.fc1(x))
15
  x = torch.relu(self.fc2(x))
16
+ x = torch.relu(self.fc3(x))
17
+ return self.fc4(x)
requirements.txt CHANGED
@@ -1,5 +1,3 @@
1
- gradio==3.29.0
2
- numpy==1.23.5
3
  Pillow==9.1.0
4
  torch==2.0.1
5
  torchvision==0.15.2
 
 
 
1
  Pillow==9.1.0
2
  torch==2.0.1
3
  torchvision==0.15.2