# trangngiosds's picture
# initial commit
# fadd436
# raw
# history blame
# 5.56 kB
import streamlit as st
import json
from autogluon.multimodal import MultiModalPredictor
import pandas as pd
from geopy.geocoders import GoogleV3
import os
import tempfile
# Use the full browser width so the five-column location row fits on one line.
st.set_page_config(layout="wide")

# Streamlit re-executes this whole script on every interaction, so the last
# estimated price lives in session state to survive reruns. 0 is the
# "nothing estimated yet" placeholder (it is falsy when displayed).
if "price_text" not in st.session_state:
    st.session_state["price_text"] = 0
@st.cache_resource
def load_geocoder():
    """Create the Google Maps geocoder once and reuse it across reruns.

    The API key is taken from the ``GOOGLE_MAP_API_KEY`` environment variable.
    """
    api_key = os.environ.get("GOOGLE_MAP_API_KEY")
    return GoogleV3(api_key=api_key)


geocoder = load_geocoder()
@st.cache_resource
def load_mm_text_no_price_model():
    """Load the AutoGluon multimodal price predictor once and reuse it across reruns."""
    model_dir = "models/mm-text-no-price/"
    return MultiModalPredictor.load(model_dir, verbosity=0)


mm_text_no_price_predictor = load_mm_text_no_price_model()
@st.cache_resource
def load_city_map():
    """Load the city mapping from ``city-map.json``.

    Returns a dict whose keys are city codes and whose values are the
    display names used in the city selector and the geocoding query.
    The file is decoded explicitly as UTF-8 (instead of the platform's
    locale encoding) and closed deterministically via ``with``.
    """
    with open("city-map.json", encoding="utf-8") as f:
        return json.load(f)


city_map = load_city_map()
@st.cache_resource
def load_city_district_map():
    """Load the per-city district mapping from ``city-district-map.json``.

    Returns a dict keyed by city code. Each value maps a district id
    (a JSON string key, converted to ``int`` by callers) to the district's
    display name. The file is decoded explicitly as UTF-8 and closed
    deterministically via ``with``.
    """
    with open("city-district-map.json", encoding="utf-8") as f:
        return json.load(f)


city_district_map = load_city_district_map()
# Closed category sets for the model's categorical input columns; the
# single-row input DataFrame is cast to these dtypes before prediction.
# Labels are Vietnamese ("Không có" means "none"), and they presumably
# must match the categories seen at training time, so do not edit them.
CERT_STATUS = pd.CategoricalDtype(
    categories=["Không có", "hợp đồng", "sổ đỏ / sổ hồng"], ordered=False
)
DIRECTION = pd.CategoricalDtype(
    categories=[
        "Không có",
        "Tây - Nam",
        "Đông - Nam",
        "Đông - Bắc",
        "Tây - Bắc",
        "Nam",
        "Tây",
        "Bắc",
        "Đông",
    ],
    ordered=False,
)
# City codes, i.e. the keys of city-map.json, in file order.
CITY = pd.CategoricalDtype(categories=list(city_map), ordered=False)
# Every district id from every city, flattened into one list of ints.
# A nested comprehension avoids the quadratic `sum(list_of_lists, [])`.
DISTRICT = pd.CategoricalDtype(
    categories=[
        int(district_id)
        for districts in city_district_map.values()
        for district_id in districts
    ],
    ordered=False,
)
@st.cache_data(show_spinner=False)
def geocode_location(query):
    """Return ``(latitude, longitude)`` for *query*, or ``None`` if there is no match.

    Results are cached because Streamlit re-runs the whole script on every
    widget interaction. Without caching, each rerun would send a new
    Geocoding API request for the same address.
    """
    result = geocoder.geocode(query=query, region="vn", language="vi")
    if result is None:
        return None
    return result.latitude, result.longitude


location_options = st.columns([1, 1, 2, 1, 1])
with location_options[0]:
    # Options are (code, name) pairs; only the name is shown.
    city = st.selectbox(
        "Choose city", options=city_map.items(), format_func=lambda x: x[1]
    )
with location_options[1]:
    # The district list depends on the selected city's code.
    district = st.selectbox(
        "Choose district",
        options=city_district_map[city[0]].items(),
        format_func=lambda x: x[1],
    )
with location_options[2]:
    precise_location = st.text_input("Enter precise location")
    region = city[1] + ", " + district[1]
    # Full free-text address: "<precise>, <city>, <district>" (precise part optional).
    location = (precise_location + ", " if precise_location else "") + region
    coordinates = geocode_location(location)
    if coordinates is None and precise_location:
        # The detailed address was not found; fall back to city + district only.
        coordinates = geocode_location(region)
    if coordinates is None:
        # Geocoding returned no match at all: do not crash, let the user type them.
        st.warning(
            "Could not geocode this location; "
            "please enter latitude and longitude manually."
        )
        coordinates = (0.0, 0.0)
    latitude, longitude = coordinates
with location_options[3]:
    # Pre-filled from geocoding but user-editable.
    latitude = st.number_input(
        "Enter latitude", value=latitude, step=1e-8, format="%.7f"
    )
with location_options[4]:
    longitude = st.number_input(
        "Enter longitude", value=longitude, step=1e-8, format="%.7f"
    )
# Numeric property attributes. For the two width fields, 0 means "unknown"
# and is converted to NaN when the model input row is assembled.
numerical_options = st.columns(6)
(
    area_col,
    bedrooms_col,
    bathrooms_col,
    floors_col,
    front_width_col,
    road_width_col,
) = numerical_options

with area_col:
    area = st.number_input("Area (m2)", min_value=1.0)
with bedrooms_col:
    bedrooms = st.number_input("Number of bedrooms", min_value=1, value=1)
with bathrooms_col:
    bathrooms = st.number_input("Number of bathrooms", min_value=1, value=1)
with floors_col:
    floors = st.number_input("Number of floors", min_value=1, value=1)
with front_width_col:
    front_width = st.number_input(
        "Front width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
    )
with road_width_col:
    road_width = st.number_input(
        "Road width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
    )
# Posting date plus the categorical attributes; the selectable options come
# straight from the categorical dtypes so inputs always fall in-category.
cat_time_columns = st.columns(4)
date_col, cert_col, direction_col, balcony_col = cat_time_columns

with date_col:
    timestamp = st.date_input("Date posted", format="DD/MM/YYYY")
with cert_col:
    cert_status = st.selectbox(
        "Certification status", options=CERT_STATUS.categories
    )
with direction_col:
    direction = st.selectbox("Direction", options=DIRECTION.categories)
with balcony_col:
    balcony_direction = st.selectbox(
        "Balcony direction", options=DIRECTION.categories
    )
description = st.text_area("Description")
# The listing title is approximated by the first sentence of the description.
title = description.split(".", maxsplit=1)[0]

uploaded_image = st.file_uploader("Upload an image")
image_tmp = None
if uploaded_image:
    # Only the extension is kept as the suffix, so the temp path stays a
    # simple "tmpXXXX.ext" instead of embedding the user's whole file name.
    _, image_suffix = os.path.splitext(uploaded_image.name)
    # The object stays referenced at module level, so the file is not
    # deleted before the predictor reads it by path later in this run.
    # NOTE(review): on Windows an open NamedTemporaryFile cannot be
    # reopened by name. Confirm the deployment target.
    image_tmp = tempfile.NamedTemporaryFile(suffix=image_suffix)
    # getvalue() returns the whole upload regardless of the stream cursor.
    image_tmp.write(uploaded_image.getvalue())
    # Flush buffered bytes so readers opening the path see the full image.
    image_tmp.flush()
# Single-row model input. Column names (including the "Lattitude" spelling)
# are presumably those the predictor was trained on, so do not rename them
# without retraining.
df = pd.DataFrame(
    [
        {
            "Title": title,
            "Area": area,
            "Location": location,
            "Time stamp": timestamp,
            "Certification status": cert_status,
            "Direction": direction,
            "Bedrooms": bedrooms,
            "Bathrooms": bathrooms,
            # 0 is the UI's "N/A" value; the model gets NaN for missing widths.
            "Front width": front_width or float("nan"),
            "Floor": floors,
            "Description": description,
            "Image URL": image_tmp.name if image_tmp else None,
            "Road width": road_width or float("nan"),
            "City_code": city[0],
            "DistrictId": int(district[0]),
            "Lattitude": latitude,
            "Longitude": longitude,
            "Balcony_Direction": balcony_direction,
        }
    ]
).astype(
    # "Image URL" is deliberately not cast to "str". That cast would turn a
    # missing image (None) into the literal path "None". When an image is
    # present, the value is already a str path.
    {
        "Title": "str",
        "Area": "float",
        "Location": "str",
        "Time stamp": "datetime64[ns]",
        "Certification status": CERT_STATUS,
        "Direction": DIRECTION,
        "Bedrooms": "int",
        "Bathrooms": "int",
        "Front width": "float",
        "Floor": "int",
        "Description": "str",
        "Road width": "float",
        "City_code": CITY,
        "DistrictId": DISTRICT,
        "Lattitude": "float",
        "Longitude": "float",
        "Balcony_Direction": DIRECTION,
    }
)
# Predict only on explicit request; the result is stored in session state
# so it stays on screen across later reruns.
if st.button("Get estimated price with text"):
    predictions = mm_text_no_price_predictor.predict(df, as_pandas=False)
    st.session_state.price_text = predictions.item()

estimated_price = st.session_state.price_text
if estimated_price:
    # The raw output is multiplied by 1e6 before display; presumably the
    # model predicts in millions of VND (confirm against training data).
    message = "Estimated price: {0:,} VND".format(int(estimated_price * 1e6))
else:
    message = "No price estimated."
st.text(message)