Kr08 commited on
Commit
c992acc
·
verified ·
1 Parent(s): 9ece7f3

Update stock_analysis.py

Browse files
Files changed (1) hide show
  1. stock_analysis.py +20 -17
stock_analysis.py CHANGED
@@ -10,32 +10,35 @@ def is_business_day(a_date):
10
  return a_date.weekday() < 5
11
 
12
  def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
13
- predictions = []
14
- confidence_intervals = []
15
-
16
  if series.shape[1] > 1:
17
- series = series['Close'].values.tolist()
18
 
19
  if model == "ARIMA":
20
  model = ARIMA(series, order=(5, 1, 0))
21
  model_fit = model.fit()
22
- forecast = model_fit.forecast(steps=forecast_period, alpha=(1 - CONFIDENCE_INTERVAL))
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Check if forecast is a numpy array (newer statsmodels) or a ForecastResults object (older statsmodels)
25
- if isinstance(forecast, np.ndarray):
26
- predictions = forecast
27
- confidence_intervals = model_fit.get_forecast(steps=forecast_period).conf_int()
28
- else:
29
- predictions = forecast.predicted_mean
30
- confidence_intervals = forecast.conf_int()
31
  elif model == "Prophet":
32
  # Implement Prophet forecasting method
33
  pass
34
  elif model == "LSTM":
35
  # Implement LSTM forecasting method
36
  pass
 
 
37
 
38
- return predictions, confidence_intervals
39
 
40
  def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
41
  stock_name, ticker_name = stock.split(":")
@@ -49,14 +52,14 @@ def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method,
49
  predictions, confidence_intervals = forecast_series(series, model=forecast_method)
50
 
51
  last_date = pd.to_datetime(series['Date'].values[-1])
52
- forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=FORECAST_PERIOD)
53
- forecast_dates = [date for date in forecast_dates if is_business_day(date)]
54
 
55
  forecast = pd.DataFrame({
56
  "Date": forecast_dates,
57
  "Forecast": predictions,
58
- "Lower_CI": confidence_intervals[:, 0],
59
- "Upper_CI": confidence_intervals[:, 1]
60
  })
61
 
62
  if graph_type == 'Line Graph':
 
10
  return a_date.weekday() < 5
11
 
12
  def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
 
 
 
13
  if series.shape[1] > 1:
14
+ series = series['Close'].values
15
 
16
  if model == "ARIMA":
17
  model = ARIMA(series, order=(5, 1, 0))
18
  model_fit = model.fit()
19
+ forecast = model_fit.forecast(steps=forecast_period)
20
+
21
+ # Get confidence intervals
22
+ conf_int = model_fit.get_forecast(steps=forecast_period).conf_int()
23
+ lower_ci = conf_int.iloc[:, 0] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 0]
24
+ upper_ci = conf_int.iloc[:, 1] if isinstance(conf_int, pd.DataFrame) else conf_int[:, 1]
25
+
26
+ # Ensure all arrays have the same length
27
+ min_length = min(len(forecast), len(lower_ci), len(upper_ci))
28
+ predictions = forecast[:min_length]
29
+ lower_ci = lower_ci[:min_length]
30
+ upper_ci = upper_ci[:min_length]
31
 
 
 
 
 
 
 
 
32
  elif model == "Prophet":
33
  # Implement Prophet forecasting method
34
  pass
35
  elif model == "LSTM":
36
  # Implement LSTM forecasting method
37
  pass
38
+ else:
39
+ raise ValueError(f"Unsupported model: {model}")
40
 
41
+ return predictions, pd.DataFrame({'Lower_CI': lower_ci, 'Upper_CI': upper_ci})
42
 
43
  def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
44
  stock_name, ticker_name = stock.split(":")
 
52
  predictions, confidence_intervals = forecast_series(series, model=forecast_method)
53
 
54
  last_date = pd.to_datetime(series['Date'].values[-1])
55
+ forecast_dates = pd.date_range(start=last_date + timedelta(days=1), periods=len(predictions))
56
+ forecast_dates = [date for date in forecast_dates if is_business_day(date)][:len(predictions)]
57
 
58
  forecast = pd.DataFrame({
59
  "Date": forecast_dates,
60
  "Forecast": predictions,
61
+ "Lower_CI": confidence_intervals['Lower_CI'],
62
+ "Upper_CI": confidence_intervals['Upper_CI']
63
  })
64
 
65
  if graph_type == 'Line Graph':