Stock_Forecasting / stock_analysis.py
Kr08's picture
Create stock_analysis.py
0407f12 verified
raw
history blame
2.9 kB
import pandas as pd
import plotly.graph_objects as go
from datetime import timedelta
from statsmodels.tsa.arima.model import ARIMA
from config import FORECAST_PERIOD, ticker_dict
from data_fetcher import get_stock_data, get_company_info
def is_business_day(a_date):
    """Tell whether *a_date* is a weekday (Monday through Friday).

    Only the day of the week is inspected; public holidays are not
    taken into account.
    """
    # weekday() numbers Monday as 0, so Saturday/Sunday are 5 and 6.
    weekend_days = (5, 6)
    return a_date.weekday() not in weekend_days
def _to_value_list(series):
    """Return the price values of *series* as a new, mutable Python list."""
    if isinstance(series, pd.DataFrame):
        # Prefer the 'Close' column; with several columns and no 'Close' this
        # raises KeyError, exactly as before. A lone column is used as-is.
        if "Close" in series.columns or series.shape[1] > 1:
            series = series["Close"]
        else:
            series = series.iloc[:, 0]
    if isinstance(series, pd.Series):
        return series.values.tolist()
    # Copy so the caller's own sequence is never mutated by the appends below.
    return list(series)


def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
    """Predict the next ``forecast_period`` values of a price series.

    Args:
        series: Price history. A DataFrame (its 'Close' column is used, or its
            only column), a 1-D pandas Series, or a plain sequence of numbers.
        model: Name of the forecasting method. Only "ARIMA" is implemented;
            "Prophet", "LSTM" and any other value currently yield no
            predictions.
        forecast_period: Number of future steps to predict.

    Returns:
        A list of predicted values, empty for unimplemented methods.
    """
    history = _to_value_list(series)
    predictions = []
    if model == "ARIMA":
        for _ in range(forecast_period):
            # Recursive one-step-ahead strategy: refit on the history extended
            # by the previous predictions, then forecast a single step.
            fitted = ARIMA(history, order=(5, 1, 0)).fit()
            yhat = fitted.forecast()[0]
            predictions.append(yhat)
            history.append(yhat)
    elif model in ("Prophet", "LSTM"):
        # TODO: Prophet / LSTM forecasting not implemented; no predictions.
        pass
    return predictions
def _resolve_ticker(idx, ticker_name):
    """Append the exchange suffix implied by the index that *idx* maps to."""
    index_name = ticker_dict[idx]
    if index_name == 'FTSE 100':
        # Avoid doubling the dot when the ticker already ends in '.'.
        # endswith() also copes with an empty ticker, unlike ticker_name[-1].
        return ticker_name + ('L' if ticker_name.endswith('.') else '.L')
    if index_name == 'CAC 40':
        return ticker_name + '.PA'
    return ticker_name


def _next_business_days(last_date, count):
    """Return the first *count* weekdays strictly after *last_date*."""
    days = []
    offset = 1
    while len(days) < count:
        candidate = last_date + timedelta(days=offset)
        if is_business_day(candidate):
            days.append(candidate)
        offset += 1
    return days


def _build_price_figure(series, forecast, graph_type, stock_name):
    """Draw the historical prices plus the forecast line as a plotly Figure."""
    if graph_type == 'Line Graph':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'],
                                 mode='lines', name='Historical'))
    else:  # Candlestick Graph
        fig = go.Figure(data=[go.Candlestick(x=series['Date'],
                                             open=series['Open'],
                                             high=series['High'],
                                             low=series['Low'],
                                             close=series['Close'],
                                             name='Historical')])
    fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'],
                             mode='lines', name='Forecast'))
    fig.update_layout(title=f"Stock Price of {stock_name}",
                      xaxis_title="Date",
                      yaxis_title="Price")
    return fig


def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method):
    """Build a price chart with a forecast overlay and fetch fundamentals.

    Args:
        idx: Key into ``ticker_dict``; the index name it maps to selects the
            exchange suffix ('FTSE 100' -> '.L', 'CAC 40' -> '.PA').
        stock: "<stock name>:<ticker>" string containing exactly one ':'.
        interval: Passed through unchanged to ``get_stock_data``.
        graph_type: 'Line Graph' draws a line chart; any other value draws a
            candlestick chart.
        forecast_method: Passed as ``model`` to ``forecast_series``.

    Returns:
        A ``(fig, fundamentals)`` tuple: the plotly Figure and whatever
        ``get_company_info`` returns for the resolved ticker.
    """
    stock_name, ticker_name = stock.split(":")
    ticker_name = _resolve_ticker(idx, ticker_name)

    # NOTE(review): assumes get_stock_data returns a DataFrame with Date/Open/
    # High/Low/Close columns, inferred from the columns read here.
    series = get_stock_data(ticker_name, interval)
    predictions = forecast_series(series, model=forecast_method)

    last_date = pd.to_datetime(series['Date'].values[-1])
    forecast_dates = _next_business_days(last_date, FORECAST_PERIOD)
    # Pair dates with predictions and drop any surplus on either side, e.g. an
    # unimplemented forecast method returns no predictions at all.
    n = min(len(forecast_dates), len(predictions))
    forecast = pd.DataFrame({"Date": forecast_dates[:n],
                             "Forecast": predictions[:n]})

    fig = _build_price_figure(series, forecast, graph_type, stock_name)
    fundamentals = get_company_info(ticker_name)
    return fig, fundamentals