Spaces:
Sleeping
Sleeping
Update stock_analysis.py
Browse files- stock_analysis.py +20 -17
stock_analysis.py
CHANGED
@@ -10,32 +10,35 @@ def is_business_day(a_date):
|
|
10 |
return a_date.weekday() < 5
|
11 |
|
12 |
def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
|
13 |
-
predictions = []
|
14 |
-
confidence_intervals = []
|
15 |
-
|
16 |
if series.shape[1] > 1:
|
17 |
-
series = series['Close'].values
|
18 |
|
19 |
if model == "ARIMA":
|
20 |
model = ARIMA(series, order=(5, 1, 0))
|
21 |
model_fit = model.fit()
|
22 |
-
forecast = model_fit.forecast(steps=forecast_period
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
-
# Check if forecast is a numpy array (newer statsmodels) or a ForecastResults object (older statsmodels)
|
25 |
-
if isinstance(forecast, np.ndarray):
|
26 |
-
predictions = forecast
|
27 |
-
confidence_intervals = model_fit.get_forecast(steps=forecast_period).conf_int()
|
28 |
-
else:
|
29 |
-
predictions = forecast.predicted_mean
|
30 |
-
confidence_intervals = forecast.conf_int()
|
31 |
elif model == "Prophet":
|
32 |
# Implement Prophet forecasting method
|
33 |
pass
|
34 |
elif model == "LSTM":
|
35 |
# Implement LSTM forecasting method
|
36 |
pass
|
|
|
|
|
37 |
|
38 |
-
return predictions,
|
39 |
|
40 |
def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
|
41 |
stock_name, ticker_name = stock.split(":")
|
@@ -49,14 +52,14 @@ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method,
|
|
49 |
predictions, confidence_intervals = forecast_series(series, model=forecast_method)
|
50 |
|
51 |
last_date = pd.to_datetime(series['Date'].values[-1])
|
52 |
-
forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=
|
53 |
-
forecast_dates = [date for date in forecast_dates if is_business_day(date)]
|
54 |
|
55 |
forecast = pd.DataFrame({
|
56 |
"Date": forecast_dates,
|
57 |
"Forecast": predictions,
|
58 |
-
"Lower_CI": confidence_intervals[
|
59 |
-
"Upper_CI": confidence_intervals[
|
60 |
})
|
61 |
|
62 |
if graph_type == 'Line Graph':
|
|
|
10 |
return a_date.weekday() < 5
|
11 |
|
12 |
def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
|
|
|
|
|
|
|
13 |
if series.shape[1] > 1:
|
14 |
+
series = series['Close'].values
|
15 |
|
16 |
if model == "ARIMA":
|
17 |
model = ARIMA(series, order=(5, 1, 0))
|
18 |
model_fit = model.fit()
|
19 |
+
forecast = model_fit.forecast(steps=forecast_period)
|
20 |
+
|
21 |
+
# Get confidence intervals
|
22 |
+
conf_int = model_fit.get_forecast(steps=forecast_period).conf_int()
|
23 |
+
lower_ci = conf_int.iloc[:, 0] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 0]
|
24 |
+
upper_ci = conf_int.iloc[:, 1] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 1]
|
25 |
+
|
26 |
+
# Ensure all arrays have the same length
|
27 |
+
min_length = min(len(forecast), len(lower_ci), len(upper_ci))
|
28 |
+
predictions = forecast[:min_length]
|
29 |
+
lower_ci = lower_ci[:min_length]
|
30 |
+
upper_ci = upper_ci[:min_length]
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
elif model == "Prophet":
|
33 |
# Implement Prophet forecasting method
|
34 |
pass
|
35 |
elif model == "LSTM":
|
36 |
# Implement LSTM forecasting method
|
37 |
pass
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Unsupported model: {model}")
|
40 |
|
41 |
+
return predictions, pd.DataFrame({'Lower_CI': lower_ci, 'Upper_CI': upper_ci})
|
42 |
|
43 |
def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
|
44 |
stock_name, ticker_name = stock.split(":")
|
|
|
52 |
predictions, confidence_intervals = forecast_series(series, model=forecast_method)
|
53 |
|
54 |
last_date = pd.to_datetime(series['Date'].values[-1])
|
55 |
+
forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=len(predictions))
|
56 |
+
forecast_dates = [date for date in forecast_dates if is_business_day(date)][:len(predictions)]
|
57 |
|
58 |
forecast = pd.DataFrame({
|
59 |
"Date": forecast_dates,
|
60 |
"Forecast": predictions,
|
61 |
+
"Lower_CI": confidence_intervals['Lower_CI'],
|
62 |
+
"Upper_CI": confidence_intervals['Upper_CI']
|
63 |
})
|
64 |
|
65 |
if graph_type == 'Line Graph':
|