File size: 4,760 Bytes
3464b48
 
 
 
 
 
 
 
 
 
 
 
 
c992acc
3464b48
 
 
 
c992acc
 
 
 
 
 
1d1de4f
c992acc
1d1de4f
 
be4b326
3464b48
 
 
 
 
 
c992acc
 
3464b48
1d1de4f
 
 
 
 
 
c992acc
3464b48
 
 
 
 
 
 
 
 
 
 
 
 
c992acc
1d1de4f
 
 
 
 
 
 
3464b48
 
 
 
c992acc
 
3464b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import timedelta
from statsmodels.tsa.arima.model import ARIMA
from config import FORECAST_PERIOD, ticker_dict, CONFIDENCE_INTERVAL
from data_fetcher import get_stock_data, get_company_info

def is_business_day(a_date):
    """Return True when ``a_date.weekday()`` is below 5.

    With datetime-style objects (Monday == 0) this means Monday through
    Friday; public holidays are not taken into account.
    """
    first_weekend_index = 5
    weekday_index = a_date.weekday()
    return weekday_index < first_weekend_index

def forecast_series(series, model="ARIMA", forecast_period=FORECAST_PERIOD):
    """Forecast future values of a series and their confidence bounds.

    Args:
        series: Historical observations. A 2-D object with more than one
            column (e.g. an OHLC DataFrame) is reduced to the values of its
            'Close' column; 1-D input or a single-column frame is passed to
            the model unchanged.
        model: Name of the forecasting method. Only "ARIMA" is implemented;
            "Prophet" and "LSTM" are recognised placeholders.
        forecast_period: Number of steps ahead to forecast.

    Returns:
        A tuple ``(predictions, intervals)``: ``predictions`` is a 1-D numpy
        array of length ``forecast_period`` and ``intervals`` is a DataFrame
        with columns 'Lower_CI' and 'Upper_CI' holding one row per step.

    Raises:
        NotImplementedError: If ``model`` is "Prophet" or "LSTM".
        ValueError: If ``model`` is any other unsupported name.
    """
    if model in ("Prophet", "LSTM"):
        # These branches were empty placeholders; fail with a clear message
        # instead of an UnboundLocalError further down.
        raise NotImplementedError(f"Forecast model not implemented yet: {model}")
    if model != "ARIMA":
        raise ValueError(f"Unsupported model: {model}")

    # Only 2-D inputs have a second axis; guard so that a plain Series or
    # 1-D array does not blow up on shape[1].
    if getattr(series, "ndim", 1) > 1 and series.shape[1] > 1:
        series = series['Close'].values

    fitted = ARIMA(series, order=(5, 1, 0)).fit()

    # One forecast object supplies both the point predictions and the bounds,
    # so they are guaranteed to have the same length.
    # NOTE(review): CONFIDENCE_INTERVAL is imported by this module but is not
    # applied here; conf_int() is called with its default level — confirm.
    forecast_result = fitted.get_forecast(steps=forecast_period)
    predictions = np.asarray(forecast_result.predicted_mean)
    bounds = np.asarray(forecast_result.conf_int())

    intervals = pd.DataFrame({'Lower_CI': bounds[:, 0], 'Upper_CI': bounds[:, 1]})
    return predictions, intervals

def _add_forecast_traces(fig, forecast):
    """Add the forecast line and its shaded confidence band to ``fig``."""
    fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))

    # The band is drawn as one closed polygon: upper bound left-to-right,
    # then lower bound right-to-left.
    dates = forecast['Date'].tolist()
    fig.add_trace(go.Scatter(
        x=dates + dates[::-1],
        y=forecast['Upper_CI'].tolist() + forecast['Lower_CI'].tolist()[::-1],
        fill='toself',
        fillcolor='rgba(0,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
    ))


def get_stock_graph_and_info(idx, stock, interval, graph_type, forecast_method, start_date, end_date):
    """Build a price chart with forecast for a stock and fetch its company info.

    Args:
        idx: Key into ``ticker_dict`` identifying the index the stock belongs
            to; 'FTSE 100' and 'CAC 40' entries get an exchange suffix.
        stock: String of the form "<name>:<ticker>".
        interval: Passed through to ``get_stock_data``.
        graph_type: 'Line Graph' for a line chart; any other value produces
            a candlestick chart.
        forecast_method: Model name passed to ``forecast_series``.
        start_date: Passed through to ``get_stock_data``.
        end_date: Passed through to ``get_stock_data``.

    Returns:
        A tuple ``(fig, fundamentals)``: the plotly figure and whatever
        ``get_company_info`` returns for the resolved ticker.
    """
    stock_name, ticker_name = stock.split(":")

    index_name = ticker_dict[idx]
    if index_name == 'FTSE 100':
        # Avoid a doubled dot when the ticker already ends with one;
        # endswith() is also safe for an empty ticker.
        ticker_name += 'L' if ticker_name.endswith('.') else '.L'
    elif index_name == 'CAC 40':
        ticker_name += '.PA'

    series = get_stock_data(ticker_name, interval, start_date, end_date)
    predictions, confidence_intervals = forecast_series(series, model=forecast_method)
    predictions = np.asarray(predictions)

    last_date = pd.to_datetime(series['Date'].values[-1])
    horizon = min(len(predictions), len(confidence_intervals))

    # Generate exactly one business day per forecast step. Filtering a
    # calendar-day range afterwards would leave fewer dates than predictions
    # and silently drop the tail of the forecast.
    # NOTE(review): business-day spacing assumes daily data — confirm for
    # other values of `interval`.
    forecast_dates = pd.bdate_range(start=last_date + timedelta(days=1), periods=horizon)

    # Insert plain arrays so columns are combined by position, not by index.
    forecast = pd.DataFrame({
        "Date": forecast_dates,
        "Forecast": predictions[:horizon],
        "Lower_CI": confidence_intervals['Lower_CI'].to_numpy()[:horizon],
        "Upper_CI": confidence_intervals['Upper_CI'].to_numpy()[:horizon]
    })

    if graph_type == 'Line Graph':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
    else:  # Candlestick Graph
        fig = go.Figure(data=[go.Candlestick(x=series['Date'],
                                             open=series['Open'],
                                             high=series['High'],
                                             low=series['Low'],
                                             close=series['Close'],
                                             name='Historical')])

    _add_forecast_traces(fig, forecast)

    fig.update_layout(title=f"Stock Price of {stock_name}",
                      xaxis_title="Date",
                      yaxis_title="Price")

    fundamentals = get_company_info(ticker_name)

    return fig, fundamentals