|
import functools
import logging
import os
import re

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import psycopg2
import torch
from chronos import ChronosPipeline
from prophet import Prophet
|
|
|
# Database password: taken from the Colab secret store when that works,
# otherwise (not running in Colab, or the lookup fails) from the
# FASHION_PG_PASS environment variable. A missing variable raises KeyError
# at import time.
try:
    from google.colab import userdata

    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
except Exception:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
|
|
|
# Raise the 'prophet' and 'cmdstanpy' loggers to WARNING so only warnings
# and errors from them reach the console.
for _logger_name in ("prophet", "cmdstanpy"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
|
|
|
|
|
# Lower-case Russian month names mapped to two-digit month numbers
# ("январь" -> "01" ... "декабрь" -> "12").
russian_months = {
    month_name: f"{month_number:02d}"
    for month_number, month_name in enumerate(
        (
            "январь", "февраль", "март", "апрель",
            "май", "июнь", "июль", "август",
            "сентябрь", "октябрь", "ноябрь", "декабрь",
        ),
        start=1,
    )
}
|
|
|
def read_and_process_file(file):
    """Load a Google Trends-style or Yandex-style CSV export as a time series.

    The layout decides the parser:

    * A blank second line is treated as Google format. There are two preamble
      lines before the comma-separated header.
    * Anything else is parsed as ';'-separated. Only column 0 (period) and
      column 2 (value) are kept.

    The period granularity is "Week" when the first three lines mention
    "неделя" or "week" (case-insensitive), otherwise "Month".

    For the ';' format:

    * Monthly labels like "январь 2020" are converted via ``russian_months``.
    * Weekly labels are parsed as "%d.%m.%Y".

    Values have "<1" mapped to 0, spaces removed and ',' used as the decimal
    separator, then are converted to float.

    Args:
        file: The uploaded file, either a path (str / os.PathLike) or an
            object exposing the path as ``.name`` (e.g. a tempfile wrapper).

    Returns:
        ``(data, period_type, period_col)``:

        * ``data`` is indexed by the parsed dates; its first column holds
          the cleaned float values.
        * ``period_type`` is "Week" or "Month".
        * ``period_col`` is the name of the original date column, now the
          index name.

    Raises:
        ValueError: If the file has fewer than two lines or a value cannot be
            converted to float.
    """
    path = file if isinstance(file, (str, os.PathLike)) else file.name

    # Read once; 'utf-8-sig' also drops a leading byte-order mark if present.
    with open(path, 'r', encoding='utf-8-sig') as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError("Uploaded file is too short to contain a time series")

    header_text = ''.join(lines[:3]).lower()
    if any(word in header_text for word in ("неделя", "week")):
        period_type = "Week"
    else:
        period_type = "Month"

    if lines[1].strip() == '':
        # Google layout: a first line and a blank line precede the header.
        data = pd.read_csv(path, skiprows=2)
    else:
        data = pd.read_csv(path, sep=';', usecols=[0, 2])
        date_col = data.columns[0]
        if period_type == "Month":
            # "<russian month> YYYY" -> "YYYY-MM-01"
            data[date_col] = data[date_col].apply(
                lambda x: re.sub(
                    r'(\w+)\s(\d{4})',
                    lambda m: f'{m.group(2)}-{russian_months[m.group(1).lower()]}',
                    x,
                ) + '-01'
            )
        else:
            data[date_col] = pd.to_datetime(data[date_col], format="%d.%m.%Y")

    # Assign by label so the column is replaced by a float64 column. Writing
    # through .iloc sets values in place and (pandas >= 2.0) keeps the old,
    # often object, dtype.
    value_col = data.columns[1]
    data[value_col] = (
        data[value_col]
        .astype(str)
        .str.replace('<1', '0', regex=False)
        .str.replace(' ', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    period_col = data.columns[0]
    data[period_col] = pd.to_datetime(data[period_col])
    data = data.set_index(period_col)

    return data, period_type, period_col
|
|
|
def get_data_from_db(query, params=None):
    """Run ``query`` against the ``kroyscappingdb`` Postgres database.

    Args:
        query: SQL text to execute.
        params: Optional query parameters, forwarded to
            ``pandas.read_sql_query`` (None means no parameters).

    Returns:
        The result set as a ``pandas.DataFrame``.
    """
    conn = psycopg2.connect(
        dbname="kroyscappingdb",
        user="read_only",
        password=PG_PASSWORD,
        host="rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net",
        port="6432",
        sslmode="require",
    )
    # try/finally so the connection is closed even when the query raises.
    try:
        return pd.read_sql_query(query, conn, params=params)
    finally:
        conn.close()
|
|
|
def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Gradio handler: load a series, forecast it, and export the results.

    With no uploaded file, the series comes from ``v_datamart``. It is
    ``sum(turnover)`` per ``end_date``, divided by the largest such sum,
    filtered by ``product_name`` and the ticked marketplaces, and treated as
    weekly data. Otherwise the uploaded CSV is parsed by
    ``read_and_process_file``.

    Args:
        file: The uploaded CSV, or None to query the database.
        product_name: SQL boolean expression inserted verbatim into WHERE.
        wb: Include 'wildberries' rows.
        ozon: Include 'ozon' rows.
        model_choice: "Prophet" or "Chronos".

    Returns:
        ``(figure, yoy_text, csv_path)``: the Plotly chart, the formatted YoY
        change, and the path of the CSV combining observations and forecast.

    Raises:
        gr.Error: If no marketplace is selected or the query returns no rows.
        ValueError: If ``model_choice`` is not recognised.
    """
    if file is None:
        marketplaces = [
            name for name, selected in (('wildberries', wb), ('ozon', ozon)) if selected
        ]
        if not marketplaces:
            raise gr.Error("Select at least one marketplace")
        mp_filter = "', '".join(marketplaces)

        # NOTE(security): product_name is user-supplied SQL pasted verbatim
        # into the WHERE clause (by design it is a filter expression), so
        # anyone with access to this UI can run arbitrary SQL under the DB
        # user. Only expose the app to trusted users.
        query = f"""
            select
                to_char(dm.end_date, 'yyyy-mm-dd') as ds,
                1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
            from v_datamart dm
            where {product_name}
            and mp in ('{mp_filter}')
            group by ds
            order by ds
        """
        print(query)
        data = get_data_from_db(query)
        period_type = "Week"
        period_col = "ds"

        if data.empty:
            raise gr.Error("No data found in database. Please adjust filters")

        # Replace the columns by label so they get real dtypes: ds becomes
        # datetime64 and y becomes float64 (SQL numeric values may arrive as
        # Decimal objects, which break float arithmetic downstream).
        data[period_col] = pd.to_datetime(data[period_col], format='%Y-%m-%d')
        data['y'] = pd.to_numeric(data['y'])
        data = data.set_index(period_col)
    else:
        data, period_type, period_col = read_and_process_file(file)

    # Two years of periods ahead; `year` is the number of periods per year.
    if period_type == "Month":
        year = 12
        n_periods = 24
        freq = "MS"
    else:
        year = 52
        n_periods = year * 2
        freq = "W"

    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})

    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    fig = create_plot(data, forecast)

    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)
    combined_file = 'combined_data.csv'
    combined_df.to_csv(combined_file)

    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
|
|
|
def forecast_prophet(df, n_periods, freq, year):
    """Fit Prophet on ``df`` and forecast ``n_periods`` steps ahead.

    Args:
        df: History with columns ``ds`` (dates) and ``y`` (values).
        n_periods: Number of future periods to predict.
        freq: pandas frequency alias for the future dates (e.g. "W", "MS").
        year: Number of periods in one year at this frequency.

    Returns:
        ``(forecast, yoy_change)``:

        * ``forecast`` is Prophet's prediction frame for the history plus the
          future dates.
        * ``yoy_change`` compares the first ``year`` forecast ``yhat``
          values to the last ``year`` observed values, as
          (forecast_sum - observed_sum) / observed_sum.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    # The future rows are the last n_periods rows. Select them first and then
    # take the first `year` of them. A single negative slice
    # iloc[-n_periods:-n_periods + year] has stop 0 (an empty window) when
    # n_periods == year.
    future_yhat = forecast['yhat'].iloc[len(forecast) - n_periods:]
    sum_first_year_forecast = future_yhat.iloc[:year].sum()
    sum_last_year_original = df['y'].iloc[-year:].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original

    return forecast, yoy_change
|
|
|
@functools.lru_cache(maxsize=1)
def _load_chronos_pipeline():
    """Load the Chronos pipeline on first use and reuse it on later calls."""
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )


def forecast_chronos(df, n_periods, freq, year):
    """Forecast ``n_periods`` future values of ``df['y']`` with Chronos.

    Args:
        df: History with columns ``ds`` (dates) and ``y`` (values), expected
            in chronological order (the last row is taken as the latest date).
        n_periods: Number of future periods to predict.
        freq: pandas frequency alias for the future dates (e.g. "W", "MS").
        year: Number of periods in one year at this frequency.

    Returns:
        ``(forecast, yoy_change)``:

        * ``forecast`` covers only the future periods. Its columns are
          ``ds``, ``yhat`` (median of the samples), ``yhat_lower`` (10%
          quantile) and ``yhat_upper`` (90% quantile).
        * ``yoy_change`` compares the first ``year`` median values to the
          last ``year`` observed values, as
          (forecast_sum - observed_sum) / observed_sum.

    Raises:
        ValueError: If ``y`` contains values that cannot be converted to float.
    """
    # Validate the data before loading the model so bad input fails fast.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)

    try:
        y_values = df['y'].to_numpy(dtype=np.float32)
    except (TypeError, ValueError) as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {e}") from e

    pipeline = _load_chronos_pipeline()
    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False,
    )

    # Take the n_periods anchored dates strictly after the last observation,
    # as Prophet's make_future_dataframe does. Dropping the first element of
    # the range instead would skip a period whenever the last date is not on
    # the frequency's anchor (e.g. a Monday with "W").
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidates[candidates > last_date][:n_periods]

    # chronos_forecast[0]: sample paths for the single input series, assumed
    # (num_samples, n_periods); reduce across samples to per-step quantiles.
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high,
    })

    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original

    return forecast, yoy_change
|
|
|
def create_plot(data, forecast):
    """Build a Plotly figure of the observed series and the forecast band.

    Args:
        data: Observed series. Its index is used as x and its first column
            as y.
        forecast: Frame with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns.

    Returns:
        A ``go.Figure`` with four line traces: Observed, Forecast, Lower CI
        and Upper CI. The area between the two CI lines is filled.
    """
    future_x = forecast['ds']
    band_style = dict(color='pink')

    traces = [
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=future_x, y=forecast['yhat'], mode='lines', name='Forecast',
                   line=dict(color='red')),
        # Lower bound must precede the upper one: fill='tonexty' fills the
        # area between a trace and the trace added just before it.
        go.Scatter(x=future_x, y=forecast['yhat_lower'], fill=None, mode='lines',
                   line=band_style, name='Lower CI'),
        go.Scatter(x=future_x, y=forecast['yhat_upper'], fill='tonexty', mode='lines',
                   line=band_style, name='Upper CI'),
    ]
    figure = go.Figure(data=traces)

    figure.update_layout(
        title='Observed Time Series and Forecast with Confidence Intervals',
        xaxis_title='Date',
        yaxis_title='Values',
        legend=dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        hovermode='x unified',
    )
    return figure
|
|
|
|
|
# Gradio UI: inputs -> forecast_time_series -> chart, YoY text, CSV download.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# Time Series Forecasting")
    gr.Markdown("Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data.")

    with gr.Row():
        file_input = gr.File(label="Upload Time Series CSV")

    # The marketplace and filter inputs only matter when no file is uploaded
    # (forecast_time_series then queries the database instead).
    with gr.Row():
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)

    with gr.Row():
        product_name_input = gr.Textbox(label="Product Name Filter", value="product_name like '%пуховик%'")

    with gr.Row():
        model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")

    with gr.Row():
        compute_button = gr.Button("Compute")

    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")

    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")

    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")

    compute_button.click(
        forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output],
    )

# Start the server only when run as a script (python app.py, a notebook
# cell), so importing this module just builds `interface` without launching.
if __name__ == "__main__":
    interface.launch(debug=True)