fix bugs: remove import of psycopg2, fix product_name→name in default db filter
dd74596
import gradio as gr | |
import pandas as pd | |
from prophet import Prophet | |
import plotly.graph_objs as go | |
import re | |
import logging | |
import os | |
import torch | |
from chronos import ChronosPipeline | |
import numpy as np | |
import requests | |
import tempfile | |
# Database passwords: try the Google Colab secret store first and fall back to
# environment variables when anything in the Colab path fails.
try:
    from google.colab import userdata
    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
    CH_PASSWORD = userdata.get('FASHION_CH_PASS')
except Exception:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate. A missing environment variable raises
    # KeyError here, i.e. the module fails at import time.
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
    CH_PASSWORD = os.environ['FASHION_CH_PASS']

# Only let warnings and above through from the "prophet" and "cmdstanpy" loggers.
logging.getLogger("prophet").setLevel(logging.WARNING)
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
# Lookup table: lowercase Russian month name -> zero-padded month number
# ("01".."12"). The tuple below is in calendar order, so the position of each
# name determines its number.
russian_months = {
    month_name: f"{month_number:02d}"
    for month_number, month_name in enumerate(
        (
            "январь", "февраль", "март", "апрель",
            "май", "июнь", "июль", "август",
            "сентябрь", "октябрь", "ноябрь", "декабрь",
        ),
        start=1,
    )
}
def read_and_process_file(file):
    """Parse an uploaded statistics CSV into a date-indexed DataFrame.

    Two layouts are recognised from the second line of the file:

    * second line blank -> treated as a "Google" export: comma-separated,
      the real header is on the third line (the first two lines are skipped);
    * otherwise -> treated as a "Yandex" export: semicolon-separated, header
      on the first line, only columns 0 and 2 are used. Monthly periods are
      expected as "<Russian month name> <yyyy>", weekly periods as
      "dd.mm.yyyy".

    The period type is "Week" when the first three lines mention "неделя" or
    "week" (case-insensitive), otherwise "Month".

    Args:
        file: Object with a ``name`` attribute holding the path of the CSV.

    Returns:
        Tuple ``(data, period_type, period_col)``: the DataFrame indexed by
        the parsed period column with one float value column, the string
        "Week" or "Month", and the name of the period column.

    Raises:
        ValueError: If the file has fewer than two lines.
        KeyError: If a monthly Yandex period uses an unknown month name.
    """
    # Read the file once; the encoding is explicit because the exports contain
    # Cyrillic text and the platform default is not guaranteed to be UTF-8.
    with open(file.name, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError("The uploaded file must contain at least two lines")

    # Period detection looks at (up to) the first three lines only.
    header_text = ''.join(lines[:3]).lower()
    is_weekly = any(word in header_text for word in ("неделя", "week"))
    period_type = "Week" if is_weekly else "Month"

    if lines[1].strip() == '':
        # "Google" layout: two preamble lines precede the header row.
        data = pd.read_csv(file.name, skiprows=2)
        period_col = data.columns[0]
    else:
        # "Yandex" layout: header on the first line, values in the third column.
        data = pd.read_csv(file.name, sep=';', skiprows=0, usecols=[0, 2])
        period_col = data.columns[0]
        if period_type == "Month":
            # Turn "<month name> <yyyy>" into "yyyy-MM-01".
            data[period_col] = data[period_col].apply(
                lambda x: re.sub(
                    r'(\w+)\s(\d{4})',
                    lambda m: f'{m.group(2)}-{russian_months[m.group(1).lower()]}',
                    x,
                ) + '-01'
            )
        else:
            data[period_col] = pd.to_datetime(data[period_col], format="%d.%m.%Y")

    # Normalise the value column: "<1" becomes 0, spaces are dropped and a
    # decimal comma becomes a decimal point. Assigning by column name (not
    # via ``iloc``) replaces the column, so the result is really float64.
    value_col = data.columns[1]
    data[value_col] = (
        data[value_col]
        .astype(str)
        .str.replace('<1', '0', regex=False)
        .str.replace(' ', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    # Parse the period column as dates and use it as the index.
    data[period_col] = pd.to_datetime(data[period_col])
    data.set_index(period_col, inplace=True)
    return data, period_type, period_col
def get_data_from_db(query):
    """Execute *query* against the remote database and return a DataFrame.

    The Yandex Cloud root CA certificate is downloaded on every call and
    written to a temporary file, because the client takes the CA bundle as a
    file path. The temporary file is removed again before returning.

    Args:
        query: SQL text passed verbatim to ``client.execute``.

    Returns:
        ``pandas.DataFrame`` whose columns are named after the result columns.

    Raises:
        requests.RequestException: If the certificate download fails, times
            out or returns an HTTP error status.
    """
    # NOTE(review): `Client` is not imported in the visible part of this file.
    # The call pattern (execute(..., with_column_types=True), secure, ca_certs)
    # looks like clickhouse_driver.Client — confirm and add the import.
    response = requests.get(
        'https://storage.yandexcloud.net/cloud-certs/RootCA.pem',
        timeout=30,  # do not let a stalled download hang the request forever
    )
    response.raise_for_status()  # never write an HTML error page as a "certificate"

    # delete=False: the file must outlive this `with` so the client can open
    # it by path; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as temp_cert_file:
        temp_cert_file.write(response.content)
        cert_file_path = temp_cert_file.name

    client = None
    try:
        client = Client(host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
                        port=9440,
                        user='user1',
                        password=CH_PASSWORD,
                        database='db1',
                        secure=True,
                        ca_certs=cert_file_path)
        result, columns = client.execute(query, with_column_types=True)
    finally:
        if client is not None:
            client.disconnect()
        os.unlink(cert_file_path)

    # `columns` holds one entry per result column; its first item is the name.
    column_names = [col[0] for col in columns]
    return pd.DataFrame(result, columns=column_names)
def _build_turnover_query(product_name, marketplaces):
    """Return the SQL that selects weekly turnover, scaled by its maximum, as (ds, y)."""
    mp_filter = "', '".join(marketplaces)
    # SECURITY NOTE(review): `product_name` is a raw SQL condition typed by the
    # user and is interpolated verbatim. This is how the UI field is designed
    # to work, but it allows arbitrary SQL in the WHERE clause — make sure the
    # database account can do no harm, or replace this with a parameterised filter.
    return f"""
    select
        cast(start_date as date) as ds,
        1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
    from datamart_all_1
    join week_data
    using (id_week)
    where {product_name}
    and mp in ('{mp_filter}')
    group by ds
    order by ds
    """


def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Forecast a time series taken from an uploaded file or from the database.

    When *file* is ``None`` the series is queried from the database using the
    *product_name* condition and the selected marketplaces, and is treated as
    weekly data. Otherwise the uploaded file is parsed with
    ``read_and_process_file``. The forecast horizon is two years (24 months
    or 104 weeks).

    Args:
        file: Uploaded file object (with a ``name`` path attribute) or ``None``.
        product_name: SQL condition inserted into the WHERE clause (database
            mode only).
        wb: Include the 'wildberries' marketplace (database mode only).
        ozon: Include the 'ozon' marketplace (database mode only).
        model_choice: "Prophet" or "Chronos".

    Returns:
        Tuple ``(figure, yoy_text, csv_path)``: the Plotly figure, a text with
        the year-over-year change, and the path of the CSV that combines the
        original data with the forecast.

    Raises:
        gr.Error: In database mode, if no marketplace is selected, the filter
            is empty, or the query returns no rows.
        ValueError: If *model_choice* is not a supported model.
    """
    if file is None:
        marketplaces = [
            name
            for selected, name in ((wb, 'wildberries'), (ozon, 'ozon'))
            if selected
        ]
        # Reject inputs that could only yield broken or empty SQL before
        # spending a round trip to the database.
        if not marketplaces:
            raise gr.Error("No marketplace selected. Please select at least one marketplace")
        if not product_name or not product_name.strip():
            raise gr.Error("Product name filter is empty. Please enter a filter condition")

        query = _build_turnover_query(product_name, marketplaces)
        print(query)
        data = get_data_from_db(query)
        print(data)
        period_type = "Week"
        period_col = "ds"
        if data.empty:
            raise gr.Error("No data found in database. Please adjust filters")
        # Assign by column name so the column really becomes datetime64;
        # `iloc[:, 0] = ...` may write in place and keep the old dtype.
        data['ds'] = pd.to_datetime(data['ds'], format='%Y-%m-%d')
        data.set_index('ds', inplace=True)
    else:
        data, period_type, period_col = read_and_process_file(file)

    # `year` is the number of periods in one year; the horizon is two years.
    if period_type == "Month":
        year = 12
        n_periods = 24
        freq = "MS"
    else:
        year = 52
        n_periods = year * 2
        freq = "W"

    # Both models expect the columns to be called 'ds' (date) and 'y' (value).
    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    # Create Plotly figure (common for both models)
    fig = create_plot(data, forecast)

    # Combine original data and forecast side by side on the date index.
    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)

    # NOTE(review): fixed file name in the working directory — concurrent
    # requests overwrite each other's output.
    combined_file = 'combined_data.csv'
    combined_df.to_csv(combined_file)
    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
def forecast_prophet(df, n_periods, freq, year):
    """Fit a default Prophet model on *df* and forecast *n_periods* ahead.

    Args:
        df: DataFrame with the columns 'ds' (dates) and 'y' (values).
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string of the series (e.g. "MS", "W").
        year: Number of periods that make up one year.

    Returns:
        Tuple ``(forecast, yoy_change)``: Prophet's prediction frame (history
        plus future rows) and the relative change between the sum of the
        first forecast year and the sum of the last observed year.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    sum_last_year_original = df['y'].iloc[-year:].sum()

    # The last `n_periods` rows of `forecast` are the future part. Positive
    # offsets are used because the former negative slice
    # `[-n_periods:-n_periods + year]` collapses to `[-n:0]` (empty) when
    # n_periods == year.
    horizon_start = len(forecast) - n_periods
    sum_first_year_forecast = forecast['yhat'].iloc[horizon_start:horizon_start + year].sum()

    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
# Lazily created Chronos pipeline, shared by all calls in this process.
_chronos_pipeline = None


def _get_chronos_pipeline():
    """Return the Chronos pipeline, loading the pretrained model on first use only."""
    global _chronos_pipeline
    if _chronos_pipeline is None:
        _chronos_pipeline = ChronosPipeline.from_pretrained(
            "amazon/chronos-t5-mini",
            device_map="cpu",
            torch_dtype=torch.bfloat16,
        )
    return _chronos_pipeline


def forecast_chronos(df, n_periods, freq, year):
    """Forecast *n_periods* ahead with the pretrained Chronos model.

    Args:
        df: DataFrame with the columns 'ds' (dates) and 'y' (values).
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string used to build the forecast dates.
        year: Number of periods that make up one year.

    Returns:
        Tuple ``(forecast, yoy_change)``: a DataFrame with the columns 'ds',
        'yhat' (median of the samples), 'yhat_lower' (10% quantile) and
        'yhat_upper' (90% quantile), and the relative change between the sum
        of the first forecast year and the sum of the last observed year.

    Raises:
        ValueError: If 'y' contains values that cannot be converted to float.
    """
    pipeline = _get_chronos_pipeline()

    # Check for non-numeric values so the user sees the offending rows.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)
    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # Build the future dates. `date_range` rolls a start date that is not on
    # the frequency anchor forward to the next anchor, so blindly dropping the
    # first element would skip a real period; instead keep the first
    # `n_periods` dates strictly after the last observation.
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidate_dates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidate_dates[candidate_dates > last_date][:n_periods]

    # Summarise the sample paths of the (single) series by quantiles per step.
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })

    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def create_plot(data, forecast):
    """Build the Plotly figure with the observed series and the forecast band.

    Args:
        data: Date-indexed DataFrame; its first column is plotted as "Observed".
        forecast: DataFrame with the columns 'ds', 'yhat', 'yhat_lower' and
            'yhat_upper'.

    Returns:
        A ``plotly.graph_objs.Figure`` with four line traces.
    """
    forecast_dates = forecast['ds']
    band_line = dict(color='pink')

    # Order matters: the upper bound uses fill='tonexty', which fills towards
    # the trace added immediately before it (the lower bound).
    traces = [
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=forecast_dates, y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')),
        go.Scatter(x=forecast_dates, y=forecast['yhat_lower'], fill=None, mode='lines', line=band_line, name='Lower CI'),
        go.Scatter(x=forecast_dates, y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=band_line, name='Upper CI'),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    layout_options = {
        'title': 'Observed Time Series and Forecast with Confidence Intervals',
        'xaxis_title': 'Date',
        'yaxis_title': 'Values',
        'legend': dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        'hovermode': 'x unified',
    }
    fig.update_layout(**layout_options)
    return fig
# Create Gradio interface using Blocks.
# Each `gr.Row()` below holds one line of the page, top to bottom: inputs,
# the Compute button, then the three outputs.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# Time Series Forecasting")
    gr.Markdown("Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data.")
    with gr.Row():
        # Optional upload; forecast_time_series queries the database when no file is given.
        file_input = gr.File(label="Upload Time Series CSV")
    with gr.Row():
        # Marketplace selection, passed as the `wb` / `ozon` arguments.
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)
    with gr.Row():
        # Passed as `product_name`; the default value is a SQL condition.
        product_name_input = gr.Textbox(label="Product Name Filter", value="name like '%пуховик%'")
    with gr.Row():
        model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")
    with gr.Row():
        compute_button = gr.Button("Compute")
    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")
    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")
    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")
    # The order of `inputs` matches the parameter order of forecast_time_series;
    # the order of `outputs` matches the order of the values it returns.
    compute_button.click(
        forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output]
    )
# Launch the interface
interface.launch(debug=True)