Revrse committed on
Commit
05a0149
·
verified ·
1 Parent(s): 5528c9c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +71 -0
  2. best.pt +3 -0
  3. requirements.txt +149 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import numpy as np
4
+ from ultralytics import YOLO
5
+ import inference
6
+ from langchain.chat_models import AzureChatOpenAI
7
+ import os
8
+ from langchain.schema import HumanMessage
9
+ import json
10
+
11
+
12
@st.cache_resource
def _load_models():
    """Build the two classifiers and the Azure chat client once per process.

    Streamlit re-executes this script from the top on every widget
    interaction, so constructing the models at module level reloads the
    YOLO weights and re-creates the clients on every button click.
    ``st.cache_resource`` keeps the constructed objects alive between reruns.

    NOTE(review): cached resources are shared by every session of the app;
    confirm that concurrent inference on a shared model instance is
    acceptable for this deployment.

    Returns:
        A ``(yolo_model, roboflow_model, chat_client)`` tuple.

    Raises:
        KeyError: if the ``BASE_URL`` or ``OPENAI_API_KEY`` environment
            variable is not set.
    """
    yolo = YOLO('best.pt')
    roboflow = inference.get_model("web-icon-classification/1")
    chat = AzureChatOpenAI(
        openai_api_base=os.environ['BASE_URL'],
        openai_api_version="2024-02-15-preview",
        deployment_name="gpt-4",
        openai_api_key=os.environ["OPENAI_API_KEY"],
        openai_api_type="azure",
        temperature=0,
        request_timeout=30,
        max_retries=3,
    )
    return yolo, roboflow, chat


# Same module-level names as before, so the rest of the script is unaffected.
yolo_model, roboflow_model, chat4 = _load_models()
24
+
25
def initiate_prompt(icon_name):
    """Build the few-shot prompt that asks the LLM for alternative icon names.

    The prompt contains one worked example ("Settings") followed by the real
    query. The query is formatted exactly like the example -- the name in
    double quotes after ``User Input: `` and a blank line before
    ``Model Response:`` -- so the model sees a consistent pattern to continue.

    Args:
        icon_name: Predicted icon class. Converted with ``str()`` so that a
            non-string label does not raise ``TypeError``.

    Returns:
        The complete prompt text, ending with ``Model Response:``.
    """
    few_shot_example = '''Given the name of an app icon, return a list of alternative names that represent similar functionality in the context of a web or mobile app.

User Input: "Settings"

Expected Output: Generate a list of alternative names that convey the same or similar functionality as "Settings" in the context of web or mobile apps.

Model Response: {
"alternatives": ["Preferences", "Options", "Controls", "Configuration", "Setup"]
}

'''
    # json.dumps yields a double-quoted form with embedded quotes and
    # backslashes escaped, matching the quoted "Settings" in the example.
    quoted_name = json.dumps(str(icon_name).strip(), ensure_ascii=False)
    return few_shot_example + "User Input: " + quoted_name + "\n\nModel Response:"
39
+
40
def _names_with_alternatives(class_name):
    """Return ``[class_name, *alternatives]`` using the LLM for the alternatives.

    If the chat request fails, or its reply is not JSON holding a list under
    the ``"alternatives"`` key, a warning is shown and ``[class_name]`` is
    returned, so a successful classification is never discarded because of
    the LLM step.
    """
    try:
        reply = chat4.predict_messages(
            messages=[HumanMessage(content=initiate_prompt(class_name))]
        )
        alternatives = json.loads(reply.content)['alternatives']
        if not isinstance(alternatives, list):
            raise TypeError("'alternatives' is not a list")
    except Exception as err:  # UI boundary: report the problem, keep the app running
        st.warning(f"Could not fetch alternative names for '{class_name}': {err}")
        return [class_name]
    return [class_name, *alternatives]


def _classify(predict_top_class, model_label):
    """Run one classifier and expand its top class with alternative names.

    Args:
        predict_top_class: Zero-argument callable returning the top class name.
        model_label: Human-readable model name used in the warning message.

    Returns:
        A list of names starting with the predicted class, or the string
        ``"None"`` when the classifier itself fails.
    """
    try:
        top_class = predict_top_class()
    # `Exception` rather than a bare `except:` so that BaseException-derived
    # signals (e.g. KeyboardInterrupt) are not swallowed.
    except Exception as err:
        st.warning(f"{model_label} classification failed: {err}")
        return "None"
    return _names_with_alternatives(top_class)


def _yolo_top_class(image):
    """Return the name of the highest-probability class from the YOLO model."""
    result = yolo_model(image)[0]
    return result.names[result.probs.top1]


def _roboflow_top_class(image):
    """Return the first predicted class reported by the Roboflow model."""
    return roboflow_model.infer(image)[0].predicted_classes[0]


st.title("App/Web Icon Classification Comparison")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if st.button("Classify Image"):
        with st.spinner('Classifying...'):
            classes_1 = _classify(lambda: _yolo_top_class(image), "Yolov8-x")
            classes_2 = _classify(lambda: _roboflow_top_class(image), "ViT")

        # Show both models' results side by side for comparison.
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Yolov8-x Prediction")
            st.write(f"Predicted Class: {classes_1}")
        with col2:
            st.subheader("ViT Prediction")
            st.write(f"Predicted Class: {classes_2}")
best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60bbffada39a6907f8cd6592ace509782cbaecac253bf006fd28f85dc72c20e8
3
+ size 112752945
requirements.txt ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohttp==3.9.5
2
+ aioresponses==0.7.6
3
+ aiosignal==1.3.1
4
+ altair==5.3.0
5
+ annotated-types==0.6.0
6
+ anyio==4.3.0
7
+ APScheduler==3.10.1
8
+ async-timeout==4.0.3
9
+ attrs==23.2.0
10
+ awscli==1.29.54
11
+ backoff==2.2.1
12
+ blinker==1.7.0
13
+ boto3==1.28.23
14
+ botocore==1.31.54
15
+ cachetools==5.3.3
16
+ certifi==2024.2.2
17
+ cffi==1.16.0
18
+ charset-normalizer==3.3.2
19
+ click==8.1.7
20
+ colorama==0.4.4
21
+ coloredlogs==15.0.1
22
+ contourpy==1.2.1
23
+ cryptography==42.0.5
24
+ cycler==0.12.1
25
+ Cython==3.0.0
26
+ dataclasses-json==0.6.4
27
+ defusedxml==0.7.1
28
+ distro==1.9.0
29
+ docker==6.1.3
30
+ docutils==0.16
31
+ exceptiongroup==1.2.1
32
+ fastapi==0.110.2
33
+ filelock==3.13.4
34
+ flatbuffers==24.3.25
35
+ fonttools==4.51.0
36
+ frozenlist==1.4.1
37
+ fsspec==2024.3.1
38
+ gitdb==4.0.11
39
+ GitPython==3.1.43
40
+ GPUtil==1.4.0
41
+ h11==0.14.0
42
+ httpcore==1.0.5
43
+ httpx==0.27.0
44
+ humanfriendly==10.0
45
+ idna==3.7
46
+ imageio==2.34.1
47
+ inference==0.9.22
48
+ iniconfig==2.0.0
49
+ Jinja2==3.1.3
50
+ jmespath==1.0.1
51
+ jsonpatch==1.33
52
+ jsonpointer==2.4
53
+ jsonschema==4.21.1
54
+ jsonschema-specifications==2023.12.1
55
+ kiwisolver==1.4.5
56
+ langchain==0.1.16
57
+ langchain-community==0.0.34
58
+ langchain-core==0.1.46
59
+ langchain-text-splitters==0.0.1
60
+ langsmith==0.1.51
61
+ lazy_loader==0.4
62
+ markdown-it-py==3.0.0
63
+ MarkupSafe==2.1.5
64
+ marshmallow==3.21.1
65
+ matplotlib==3.8.4
66
+ mdurl==0.1.2
67
+ mpmath==1.3.0
68
+ multidict==6.0.5
69
+ mypy-extensions==1.0.0
70
+ networkx==3.3
71
+ numpy==1.25.2
72
+ onnxruntime==1.15.1
73
+ openai==1.23.6
74
+ opencv-python==4.8.0.76
75
+ opencv-python-headless==4.9.0.80
76
+ orjson==3.10.1
77
+ packaging==23.2
78
+ pandas==2.2.2
79
+ pendulum==3.0.0
80
+ piexif==1.1.3
81
+ pillow==10.3.0
82
+ pluggy==1.5.0
83
+ prettytable==3.10.0
84
+ prometheus-fastapi-instrumentator==6.0.0
85
+ prometheus_client==0.20.0
86
+ protobuf==4.25.3
87
+ psutil==5.9.8
88
+ PuLP==2.8.0
89
+ py-cpuinfo==9.0.0
90
+ pyarrow==16.0.0
91
+ pyasn1==0.6.0
92
+ pybase64==1.3.2
93
+ pycparser==2.22
94
+ pydantic==2.7.1
95
+ pydantic_core==2.18.2
96
+ pydeck==0.9.0b1
97
+ Pygments==2.17.2
98
+ pyparsing==3.1.2
99
+ pytest==8.1.2
100
+ pytest-asyncio==0.21.1
101
+ python-dateutil==2.9.0.post0
102
+ python-dotenv==1.0.1
103
+ pytz==2024.1
104
+ PyYAML==6.0.1
105
+ redis==5.0.4
106
+ referencing==0.35.0
107
+ requests==2.31.0
108
+ requests-toolbelt==1.0.0
109
+ rich==13.5.2
110
+ rpds-py==0.18.0
111
+ rsa==4.7.2
112
+ s3transfer==0.6.2
113
+ scikit-image==0.23.2
114
+ scipy==1.13.0
115
+ seaborn==0.13.2
116
+ shapely==2.0.1
117
+ six==1.16.0
118
+ skypilot==0.4.1
119
+ smmap==5.0.1
120
+ sniffio==1.3.1
121
+ SQLAlchemy==2.0.29
122
+ starlette==0.37.2
123
+ streamlit==1.33.0
124
+ structlog==24.1.0
125
+ supervision==0.20.0
126
+ sympy==1.12
127
+ tabulate==0.9.0
128
+ tenacity==8.2.3
129
+ thop==0.1.1.post2209072238
130
+ tifffile==2024.4.24
131
+ time-machine==2.14.1
132
+ toml==0.10.2
133
+ tomli==2.0.1
134
+ toolz==0.12.1
135
+ torch==2.3.0
136
+ torchvision==0.18.0
137
+ tornado==6.4
138
+ tqdm==4.66.2
139
+ typer==0.9.0
140
+ typing-inspect==0.9.0
141
+ typing_extensions==4.11.0
142
+ tzdata==2024.1
143
+ tzlocal==5.2
144
+ ultralytics==8.2.3
145
+ urllib3==1.26.18
146
+ wcwidth==0.2.13
147
+ websocket-client==1.8.0
148
+ yarl==1.9.4
149
+ zxing-cpp==2.2.0