rajora commited on
Commit
061d93e
·
1 Parent(s): b9592e9

Add gradio app HF

Browse files
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from utils.model import load_model
3
+ from utils.transformations import transform_image
4
+ from utils.prediction import predict_class
5
+
6
+ # Load the model only once when the application starts
7
+ net = load_model()
8
+
9
+ def classify_image(image_path):
10
+ """
11
+ Main function for the Gradio interface.
12
+ """
13
+ image_tensor = transform_image(image_path)
14
+ predicted_class, confidence = predict_class(image_tensor, net)
15
+ return f"Predicted Class: {predicted_class}, Confidence: {confidence:.2f}"
16
+
17
+ iface = gr.Interface(
18
+ fn=classify_image,
19
+ inputs=gr.Image(type="filepath"),
20
+ outputs="text",
21
+ title="Galaxy Classification",
22
+ description="Upload an image of a galaxy to classify it."
23
+ )
24
+
25
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ _libgcc_mutex=0.1=conda_forge
5
+ _openmp_mutex=4.5=2_gnu
6
+ aiofiles=23.2.1=pypi_0
7
+ altair=5.3.0=pypi_0
8
+ annotated-types=0.7.0=pypi_0
9
+ anyio=4.4.0=pypi_0
10
+ asttokens=2.4.1=pyhd8ed1ab_0
11
+ attrs=23.2.0=pypi_0
12
+ bzip2=1.0.8=h4bc722e_7
13
+ ca-certificates=2024.7.4=hbcca054_0
14
+ certifi=2024.7.4=pypi_0
15
+ charset-normalizer=3.3.2=pypi_0
16
+ click=8.1.7=pypi_0
17
+ comm=0.2.2=pyhd8ed1ab_0
18
+ contourpy=1.2.1=pypi_0
19
+ cycler=0.12.1=pypi_0
20
+ debugpy=1.8.2=py312h7070661_0
21
+ decorator=5.1.1=pyhd8ed1ab_0
22
+ dnspython=2.6.1=pypi_0
23
+ efficientnet-pytorch=0.7.1=pypi_0
24
+ email-validator=2.2.0=pypi_0
25
+ exceptiongroup=1.2.2=pyhd8ed1ab_0
26
+ executing=2.0.1=pyhd8ed1ab_0
27
+ fastapi=0.111.1=pypi_0
28
+ fastapi-cli=0.0.4=pypi_0
29
+ ffmpy=0.3.2=pypi_0
30
+ filelock=3.15.4=pypi_0
31
+ fonttools=4.53.1=pypi_0
32
+ fsspec=2024.6.1=pypi_0
33
+ gradio=4.38.1=pypi_0
34
+ gradio-client=1.1.0=pypi_0
35
+ h11=0.14.0=pypi_0
36
+ httpcore=1.0.5=pypi_0
37
+ httptools=0.6.1=pypi_0
38
+ httpx=0.27.0=pypi_0
39
+ huggingface-hub=0.24.0=pypi_0
40
+ idna=3.7=pypi_0
41
+ importlib-metadata=8.0.0=pyha770c72_0
42
+ importlib-resources=6.4.0=pypi_0
43
+ importlib_metadata=8.0.0=hd8ed1ab_0
44
+ ipykernel=6.29.5=pyh3099207_0
45
+ ipython=8.26.0=pyh707e725_0
46
+ jedi=0.19.1=pyhd8ed1ab_0
47
+ jinja2=3.1.4=pypi_0
48
+ joblib=1.4.2=pypi_0
49
+ jsonschema=4.23.0=pypi_0
50
+ jsonschema-specifications=2023.12.1=pypi_0
51
+ jupyter_client=8.6.2=pyhd8ed1ab_0
52
+ jupyter_core=5.7.2=py312h7900ff3_0
53
+ keyutils=1.6.1=h166bdaf_0
54
+ kiwisolver=1.4.5=pypi_0
55
+ krb5=1.21.3=h659f571_0
56
+ ld_impl_linux-64=2.40=hf3520f5_7
57
+ libedit=3.1.20191231=he28a2e2_2
58
+ libexpat=2.6.2=h59595ed_0
59
+ libffi=3.4.2=h7f98852_5
60
+ libgcc-ng=14.1.0=h77fa898_0
61
+ libgomp=14.1.0=h77fa898_0
62
+ libnsl=2.0.1=hd590300_0
63
+ libsodium=1.0.18=h36c2ea0_1
64
+ libsqlite=3.46.0=hde9e2c9_0
65
+ libstdcxx-ng=14.1.0=hc0a3c3a_0
66
+ libuuid=2.38.1=h0b41bf4_0
67
+ libxcrypt=4.4.36=hd590300_1
68
+ libzlib=1.3.1=h4ab18f5_1
69
+ markdown-it-py=3.0.0=pypi_0
70
+ markupsafe=2.1.5=pypi_0
71
+ matplotlib=3.9.1=pypi_0
72
+ matplotlib-inline=0.1.7=pyhd8ed1ab_0
73
+ mdurl=0.1.2=pypi_0
74
+ mpmath=1.3.0=pypi_0
75
+ ncurses=6.5=h59595ed_0
76
+ nest-asyncio=1.6.0=pyhd8ed1ab_0
77
+ networkx=3.3=pypi_0
78
+ numpy=2.0.0=pypi_0
79
+ nvidia-cublas-cu12=12.1.3.1=pypi_0
80
+ nvidia-cuda-cupti-cu12=12.1.105=pypi_0
81
+ nvidia-cuda-nvrtc-cu12=12.1.105=pypi_0
82
+ nvidia-cuda-runtime-cu12=12.1.105=pypi_0
83
+ nvidia-cudnn-cu12=8.9.2.26=pypi_0
84
+ nvidia-cufft-cu12=11.0.2.54=pypi_0
85
+ nvidia-curand-cu12=10.3.2.106=pypi_0
86
+ nvidia-cusolver-cu12=11.4.5.107=pypi_0
87
+ nvidia-cusparse-cu12=12.1.0.106=pypi_0
88
+ nvidia-nccl-cu12=2.20.5=pypi_0
89
+ nvidia-nvjitlink-cu12=12.5.82=pypi_0
90
+ nvidia-nvtx-cu12=12.1.105=pypi_0
91
+ openssl=3.3.1=h4bc722e_2
92
+ orjson=3.10.6=pypi_0
93
+ packaging=24.1=pyhd8ed1ab_0
94
+ pandas=2.2.2=pypi_0
95
+ parso=0.8.4=pyhd8ed1ab_0
96
+ pexpect=4.9.0=pyhd8ed1ab_0
97
+ pickleshare=0.7.5=py_1003
98
+ pillow=10.4.0=pypi_0
99
+ pip=24.0=pyhd8ed1ab_0
100
+ platformdirs=4.2.2=pyhd8ed1ab_0
101
+ prompt-toolkit=3.0.47=pyha770c72_0
102
+ psutil=6.0.0=py312h9a8786e_0
103
+ ptyprocess=0.7.0=pyhd3deb0d_0
104
+ pure_eval=0.2.2=pyhd8ed1ab_0
105
+ pydantic=2.8.2=pypi_0
106
+ pydantic-core=2.20.1=pypi_0
107
+ pydub=0.25.1=pypi_0
108
+ pygments=2.18.0=pyhd8ed1ab_0
109
+ pyparsing=3.1.2=pypi_0
110
+ python=3.12.4=h194c7f8_0_cpython
111
+ python-dateutil=2.9.0=pyhd8ed1ab_0
112
+ python-dotenv=1.0.1=pypi_0
113
+ python-multipart=0.0.9=pypi_0
114
+ python_abi=3.12=4_cp312
115
+ pytz=2024.1=pypi_0
116
+ pyyaml=6.0.1=pypi_0
117
+ pyzmq=26.0.3=py312h8fd38d8_0
118
+ readline=8.2=h8228510_1
119
+ referencing=0.35.1=pypi_0
120
+ requests=2.32.3=pypi_0
121
+ rich=13.7.1=pypi_0
122
+ rpds-py=0.19.0=pypi_0
123
+ ruff=0.5.3=pypi_0
124
+ scikit-learn=1.5.1=pypi_0
125
+ scipy=1.14.0=pypi_0
126
+ semantic-version=2.10.0=pypi_0
127
+ setuptools=71.0.3=pyhd8ed1ab_0
128
+ shellingham=1.5.4=pypi_0
129
+ six=1.16.0=pyh6c4a22f_0
130
+ sniffio=1.3.1=pypi_0
131
+ stack_data=0.6.2=pyhd8ed1ab_0
132
+ starlette=0.37.2=pypi_0
133
+ sympy=1.13.1=pypi_0
134
+ threadpoolctl=3.5.0=pypi_0
135
+ tk=8.6.13=noxft_h4845f30_101
136
+ tomlkit=0.12.0=pypi_0
137
+ toolz=0.12.1=pypi_0
138
+ torch=2.3.1=pypi_0
139
+ torchvision=0.18.1=pypi_0
140
+ tornado=6.4.1=py312h9a8786e_0
141
+ tqdm=4.66.4=pypi_0
142
+ traitlets=5.14.3=pyhd8ed1ab_0
143
+ typer=0.12.3=pypi_0
144
+ typing_extensions=4.12.2=pyha770c72_0
145
+ tzdata=2024.1=pypi_0
146
+ urllib3=2.2.2=pypi_0
147
+ uvicorn=0.30.1=pypi_0
148
+ uvloop=0.19.0=pypi_0
149
+ watchfiles=0.22.0=pypi_0
150
+ wcwidth=0.2.13=pyhd8ed1ab_0
151
+ websockets=11.0.3=pypi_0
152
+ wheel=0.43.0=pyhd8ed1ab_1
153
+ xz=5.2.6=h166bdaf_0
154
+ zeromq=4.3.5=h75354e8_4
155
+ zipp=3.19.2=pyhd8ed1ab_0
utils/__pycache__/model.cpython-312.pyc ADDED
Binary file (2.12 kB). View file
 
utils/__pycache__/prediction.cpython-312.pyc ADDED
Binary file (1.47 kB). View file
 
utils/__pycache__/transformations.cpython-312.pyc ADDED
Binary file (1.22 kB). View file
 
utils/model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from efficientnet_pytorch import EfficientNet
5
+
6
+ class EffNet(nn.Module):
7
+ def __init__(self, n_classes):
8
+ super(EffNet, self).__init__()
9
+ self.b4 = EfficientNet.from_pretrained('efficientnet-b0')
10
+ self.drop = nn.Dropout(0.2)
11
+ self.fc = nn.Linear(1000, n_classes)
12
+
13
+ def forward(self, image):
14
+ x = self.b4(image)
15
+ x = self.drop(x)
16
+ out = self.fc(x)
17
+ return out
18
+
19
+ def load_model():
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
+ net = EffNet(n_classes=2).to(device)
22
+ model_path = os.path.join(os.path.dirname(__file__), 'models', 'modelo_galaxias.pth')
23
+ net.load_state_dict(torch.load(model_path)) # Adjust path if needed
24
+ net.eval()
25
+ return net
utils/models/modelo_galaxias.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36aef445889deb63f7581a82cda3b0410a8bad183ff8a8dfd211460e8d79f4c4
3
+ size 21459918
utils/prediction.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def predict_class(image_tensor, model):
4
+ """
5
+ Predicts the class of the image using the model.
6
+
7
+ Args:
8
+ image_tensor (torch.Tensor): Transformed image.
9
+ model (torch.nn.Module): The loaded PyTorch model.
10
+
11
+ Returns:
12
+ str: Predicted class (eliptical or spiral).
13
+ float: Confidence of the prediction.
14
+ """
15
+
16
+ with torch.no_grad():
17
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18
+ output = model(image_tensor.to(device))
19
+ probabilities = torch.softmax(output, dim=1)
20
+ predicted_class_index = torch.argmax(probabilities, dim=1).item()
21
+ confidence = torch.max(probabilities).item()
22
+
23
+ classes = ['eliptical', 'spiral']
24
+ predicted_class = classes[predicted_class_index]
25
+
26
+ return predicted_class, confidence
utils/transformations.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torchvision.transforms as transforms
3
+
4
+ IMAGE_SHAPE = (200, 200)
5
+
6
+ def transform_image(image_path):
7
+ """
8
+ Loads, transforms, and prepares the image for the model.
9
+
10
+ Args:
11
+ image_path (str): Path to the image.
12
+
13
+ Returns:
14
+ torch.Tensor: Transformed image ready for the model.
15
+ """
16
+
17
+ transform = transforms.Compose([
18
+ transforms.Resize(IMAGE_SHAPE),
19
+ transforms.RandomHorizontalFlip(),
20
+ transforms.RandomRotation(degrees=10),
21
+ transforms.ToTensor(),
22
+ ])
23
+
24
+ image = Image.open(image_path).convert('RGB')
25
+ image = transform(image)
26
+ image = image.unsqueeze(0) # Add batch dimension
27
+ return image