# Page header captured together with the file (not part of the program):
#   author: alexander-lazarin
#   commit dd74596: fix bugs: remove import of psycopg2, fix product_name→name in default db filter
#   file size: 10.2 kB
import gradio as gr
import pandas as pd
from prophet import Prophet
import plotly.graph_objs as go
import re
import logging
import os
import torch
from chronos import ChronosPipeline
import numpy as np
import requests
import tempfile
# Database passwords: read from the Colab secret store when it is available,
# otherwise from environment variables. Both variables are required; a missing
# one raises KeyError at import time.
# NOTE(review): PG_PASSWORD is only referenced by commented-out code in this
# file -- confirm whether it is still needed.
try:
    from google.colab import userdata
    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
    CH_PASSWORD = userdata.get('FASHION_CH_PASS')
except Exception:
    # Covers "not running in Colab" (ImportError) as well as a failed secret
    # lookup. `Exception` instead of a bare `except` so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
    CH_PASSWORD = os.environ['FASHION_CH_PASS']
# Raise the log threshold of the forecasting libraries to WARNING.
for _library_logger in ("prophet", "cmdstanpy"):
    logging.getLogger(_library_logger).setLevel(logging.WARNING)
# Lower-case Russian month names mapped to zero-padded month numbers
# ("январь" -> "01" ... "декабрь" -> "12"), built from the calendar order.
russian_months = {
    month_name: f"{month_number:02d}"
    for month_number, month_name in enumerate(
        (
            "январь", "февраль", "март", "апрель",
            "май", "июнь", "июль", "август",
            "сентябрь", "октябрь", "ноябрь", "декабрь",
        ),
        start=1,
    )
}
def read_and_process_file(file):
    """Parse an uploaded time-series CSV into a date-indexed DataFrame.

    Two layouts are recognised, told apart by the second line of the file:

    * blank second line -> comma-separated file whose first two lines are
      skipped (the original code labels this source 'Google');
    * anything else -> semicolon-separated file of which columns 0 and 2
      are used (labelled 'Yandex'); its dates are ``dd.mm.YYYY`` for weekly
      data or ``<Russian month name> <YYYY>`` for monthly data.

    The period is "Week" when "неделя" or "week" occurs (case-insensitively)
    in the first three lines, otherwise "Month".

    Args:
        file: Uploaded file object exposing its path as ``.name``; a plain
            path string is accepted as well.

    Returns:
        Tuple ``(data, period_type, period_col)``: ``data`` is indexed by the
        parsed dates and holds the value column as float, ``period_type`` is
        "Week" or "Month", ``period_col`` is the name of the date column.

    Raises:
        ValueError: The file has fewer than two lines, or a value cannot be
            converted to float.
        KeyError: A monthly semicolon file contains a month name that is not
            in ``russian_months``.
    """
    path = getattr(file, 'name', file)

    # Read once, explicitly as UTF-8 (pandas' default for read_csv) so the
    # Cyrillic keyword check does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError("The uploaded file must contain at least two lines")

    # Files shorter than three lines are fine: slicing never over-reads.
    header_text = ''.join(lines[:3]).lower()
    if any(word in header_text for word in ("неделя", "week")):
        period_type = "Week"
    else:
        period_type = "Month"

    if lines[1].strip() == '':
        # 'Google' layout: title line + blank line precede the CSV header.
        data = pd.read_csv(path, skiprows=2)
    else:
        # 'Yandex' layout: semicolon-separated, date in col 0, value in col 2.
        data = pd.read_csv(path, sep=';', skiprows=0, usecols=[0, 2])
        date_col = data.columns[0]
        if period_type == "Month":
            month_year = re.compile(r'(\w+)\s(\d{4})')

            def _to_iso_date(text):
                # "<month name> <year>" -> "<year>-<MM>-01"
                return month_year.sub(
                    lambda m: f'{m.group(2)}-{russian_months[m.group(1).lower()]}',
                    text,
                ) + '-01'

            data[date_col] = data[date_col].apply(_to_iso_date)
        else:
            data[date_col] = pd.to_datetime(data[date_col], format="%d.%m.%Y")

    period_col, value_col = data.columns[0], data.columns[1]

    # Normalise the value column: "<1" counts as 0, thousands are separated by
    # spaces and the decimal mark may be a comma. Assigning by column name
    # (instead of through iloc) guarantees the result is a float64 column.
    data[value_col] = (
        data[value_col]
        .astype(str)
        .str.replace('<1', '0', regex=False)
        .str.replace(' ', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    # Process the date column and set it as the index
    data[period_col] = pd.to_datetime(data[period_col])
    data.set_index(period_col, inplace=True)
    return data, period_type, period_col
def get_data_from_db(query):
    """Execute *query* on the remote database and return the rows as a DataFrame.

    The CA certificate is downloaded on every call, stored in a temporary
    file for the duration of the query and removed afterwards.

    NOTE(review): `Client` is not imported anywhere in the visible file, so
    this function raises NameError as written. The call signature
    (`secure`, `ca_certs`, `execute(..., with_column_types=True)`) matches
    `clickhouse_driver.Client` -- confirm and add the import.

    Args:
        query: SQL text passed to the client unchanged.

    Returns:
        pandas.DataFrame whose columns are named after the result columns.

    Raises:
        requests.RequestException: The certificate download failed, timed
            out, or returned an HTTP error status.
    """
    response = requests.get(
        'https://storage.yandexcloud.net/cloud-certs/RootCA.pem',
        timeout=30,
    )
    # Fail here rather than hand an HTML error page to the TLS layer.
    response.raise_for_status()

    # delete=False because the client opens the file by path later on; the
    # file is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as temp_cert_file:
        temp_cert_file.write(response.content)
        cert_file_path = temp_cert_file.name

    try:
        client = Client(host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
                        port=9440,
                        user='user1',
                        password=CH_PASSWORD,
                        database='db1',
                        secure=True,
                        ca_certs=cert_file_path)
        try:
            result, columns = client.execute(query, with_column_types=True)
        finally:
            client.disconnect()
    finally:
        os.remove(cert_file_path)

    # Each entry of `columns` is a (name, type) pair; only the name is used.
    column_names = [col[0] for col in columns]
    return pd.DataFrame(result, columns=column_names)
def _build_turnover_query(product_name, marketplaces):
    """Return the SQL selecting weekly turnover (scaled by its maximum) per week start."""
    mp_filter = "', '".join(marketplaces)
    # SECURITY NOTE(review): `product_name` is a raw SQL predicate typed by
    # the user and is spliced into the statement verbatim. That is the
    # contract of the UI field, but it is SQL injection by design -- make
    # sure the database account is strictly read-only.
    return f"""
    select
        cast(start_date as date) as ds,
        1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
    from datamart_all_1
    join week_data
        using (id_week)
    where {product_name}
        and mp in ('{mp_filter}')
    group by ds
    order by ds
    """


def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Forecast two years ahead from an uploaded CSV or from the database.

    Args:
        file: Uploaded CSV (handled by `read_and_process_file`) or None to
            query the database instead.
        product_name: SQL predicate placed after WHERE in the database
            query; ignored when `file` is given.
        wb: Include the 'wildberries' marketplace in the database query.
        ozon: Include the 'ozon' marketplace in the database query.
        model_choice: "Prophet" or "Chronos".

    Returns:
        Tuple ``(figure, yoy_text, csv_path)``: the Plotly figure, a
        formatted year-over-year change string, and the path of a CSV that
        holds the observed data next to the forecast.

    Raises:
        gr.Error: No marketplace is selected, or the query returned no rows.
        ValueError: `model_choice` is not one of the supported models.
    """
    if file is None:
        marketplaces = [
            name
            for name, selected in (('wildberries', wb), ('ozon', ozon))
            if selected
        ]
        if not marketplaces:
            # Without this guard the filter becomes `mp in ('')`, which costs
            # a database round trip only to end in the generic "no data" error.
            raise gr.Error("Please select at least one marketplace")

        query = _build_turnover_query(product_name, marketplaces)
        print(query)
        data = get_data_from_db(query)
        print(data)
        if data.empty:
            raise gr.Error("No data found in database. Please adjust filters")

        period_type = "Week"
        period_col = "ds"
        # Column assignment (not iloc) so the column really becomes datetime64.
        data['ds'] = pd.to_datetime(data['ds'], format='%Y-%m-%d')
        data.set_index('ds', inplace=True)
    else:
        data, period_type, period_col = read_and_process_file(file)

    # `year` is the number of periods in one year; the horizon is two years.
    # NOTE(review): pandas' "W" alias is anchored on Sunday; if the observed
    # weeks start on another weekday the forecast dates will not line up with
    # the history -- confirm this is acceptable.
    if period_type == "Month":
        year, freq = 12, "MS"
    else:
        year, freq = 52, "W"
    n_periods = 2 * year

    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})

    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    # Create Plotly figure (common for both models)
    fig = create_plot(data, forecast)

    # Combine original data and forecast side by side on the date index
    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)

    # NOTE(review): fixed file name in the working directory; concurrent
    # requests overwrite each other's download.
    combined_file = 'combined_data.csv'
    combined_df.to_csv(combined_file)

    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
def forecast_prophet(df, n_periods, freq, year):
    """Fit Prophet on *df* and forecast *n_periods* steps ahead.

    Args:
        df: DataFrame with a 'ds' date column and a 'y' value column.
        n_periods: Number of future periods to predict.
        freq: pandas frequency alias used to generate the future dates.
        year: Number of periods that make up one year.

    Returns:
        Tuple ``(forecast, yoy_change)``: the frame returned by
        ``model.predict`` and the relative change between the sum of the
        first forecast year and the sum of the last observed year.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    sum_last_year_original = df['y'].iloc[-year:].sum()

    # The last `n_periods` rows are the future part. Taking the first `year`
    # of them with two separate slices stays correct when n_periods <= year;
    # the single slice [-n_periods:-n_periods + year] would end at index 0
    # (an empty selection) when n_periods == year.
    future_yhat = forecast['yhat'].iloc[len(forecast) - n_periods:]
    sum_first_year_forecast = future_yhat.iloc[:year].sum()

    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def forecast_chronos(df, n_periods, freq, year):
    """Forecast *n_periods* steps ahead with the pretrained Chronos model.

    Args:
        df: DataFrame with a 'ds' date column and a 'y' value column.
        n_periods: Number of future periods to predict.
        freq: pandas frequency alias used to generate the future dates.
        year: Number of periods that make up one year.

    Returns:
        Tuple ``(forecast, yoy_change)``: ``forecast`` has the columns 'ds',
        'yhat' (median of the samples), 'yhat_lower' (10 % quantile) and
        'yhat_upper' (90 % quantile); ``yoy_change`` is the relative change
        between the sum of the first forecast year and the sum of the last
        observed year.

    Raises:
        ValueError: The 'y' column holds values that are not numeric or
            cannot be converted to float32.
    """
    # Validate the input before paying for the model load.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)

    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    # NOTE(review): the model is loaded on every call; consider caching it.
    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )

    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # Future dates strictly after the last observation. Dropping the first
    # element unconditionally is only right when the last date lies on the
    # frequency anchor; otherwise date_range already starts after it and a
    # whole period would be skipped. Filtering handles both cases.
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidates[candidates > last_date][:n_periods]

    # Quantiles are taken across axis 0 of the first forecast entry
    # (presumably the sample axis of the single input series -- TODO confirm).
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })

    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def create_plot(data, forecast):
    """Build the Plotly figure with the observed series, forecast and its band.

    Args:
        data: Date-indexed DataFrame; its first column is plotted as observed.
        forecast: DataFrame with 'ds', 'yhat', 'yhat_lower', 'yhat_upper'.

    Returns:
        plotly.graph_objs.Figure with four line traces.
    """
    forecast_dates = forecast['ds']
    band_line = dict(color='pink')

    # Order matters: the upper bound fills down to the trace added just
    # before it ('tonexty'), i.e. the lower bound.
    traces = [
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=forecast_dates, y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')),
        go.Scatter(x=forecast_dates, y=forecast['yhat_lower'], fill=None, mode='lines', line=band_line, name='Lower CI'),
        go.Scatter(x=forecast_dates, y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=band_line, name='Upper CI'),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    layout = dict(
        title='Observed Time Series and Forecast with Confidence Intervals',
        xaxis_title='Date',
        yaxis_title='Values',
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        hovermode='x unified',
    )
    fig.update_layout(**layout)
    return fig
# Gradio UI (Blocks): inputs stacked in rows, a Compute button, and three
# outputs (chart, YoY text, downloadable CSV) fed by forecast_time_series.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown(value="# Time Series Forecasting")
    gr.Markdown(
        value="Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data."
    )

    # --- inputs ---
    with gr.Row():
        file_input = gr.File(label="Upload Time Series CSV")

    with gr.Row():
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)

    with gr.Row():
        product_name_input = gr.Textbox(
            label="Product Name Filter",
            value="name like '%пуховик%'",
        )

    with gr.Row():
        model_choice = gr.Radio(
            choices=["Prophet", "Chronos"],
            label="Choose Model",
            value="Prophet",
        )

    with gr.Row():
        compute_button = gr.Button(value="Compute")

    # --- outputs ---
    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")

    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")

    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")

    # The argument order of `inputs` mirrors forecast_time_series' signature.
    compute_button.click(
        fn=forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output],
    )
# Launch the interface only when this file is executed as a script, so that
# importing the module (e.g. from a test) does not start a blocking server.
if __name__ == "__main__":
    interface.launch(debug=True)