mobenta commited on
Commit
8755ade
·
verified ·
1 Parent(s): 2c99a8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -129
app.py CHANGED
@@ -11,6 +11,9 @@ import tempfile
11
  import os
12
  import spaces
13
  import pandas as pd
 
 
 
14
 
15
  # Configure logging
16
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -43,123 +46,71 @@ def fetch_stock_data(ticker='TSLA', start='2010-01-01', end=None, interval='1d')
43
  data = stock.history(start=start, end=end, interval=interval)
44
  if data.empty:
45
  logging.warning(f"No data fetched for {ticker} in the range {start} to {end}")
46
- return pd.DataFrame() # Return empty DataFrame instead of raising an exception
47
  logging.debug(f"Fetched data with {len(data)} rows")
48
  return data
49
  except Exception as e:
50
  logging.error(f"Error fetching data: {e}")
51
- return pd.DataFrame() # Return empty DataFrame on any exception
52
 
53
  def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indicators=None):
54
- try:
55
- logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
56
- title = f"{ticker.upper()} Price Data (Timeframe: {timeframe})"
57
-
58
- plt.rcParams["axes.titlesize"] = 10
59
- my_style = mpf.make_mpf_style(base_mpf_style='charles')
60
-
61
- # Calculate indicators if selected
62
- addplot = []
63
- if indicators:
64
- if 'RSI' in indicators:
65
- delta = data['Close'].diff(1)
66
- gain = delta.where(delta > 0, 0)
67
- loss = -delta.where(delta < 0, 0)
68
- avg_gain = gain.rolling(window=14).mean()
69
- avg_loss = loss.rolling(window=14).mean()
70
- rs = avg_gain / avg_loss
71
- rsi = 100 - (100 / (1 + rs))
72
- addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
73
- if 'SMA21' in indicators:
74
- logging.debug("Calculating SMA 21")
75
- sma_21 = data['Close'].rolling(window=21).mean()
76
- addplot.append(mpf.make_addplot(sma_21, color='purple', linestyle='dashed'))
77
- if 'SMA50' in indicators:
78
- logging.debug("Calculating SMA 50")
79
- sma_50 = data['Close'].rolling(window=50).mean()
80
- addplot.append(mpf.make_addplot(sma_50, color='orange', linestyle='dashed'))
81
- if 'SMA200' in indicators:
82
- logging.debug("Calculating SMA 200")
83
- sma_200 = data['Close'].rolling(window=200).mean()
84
- addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
85
- if 'VWAP' in indicators:
86
- logging.debug("Calculating VWAP")
87
- vwap = (data['Volume'] * (data['High'] + data['Low'] + data['Close']) / 3).cumsum() / data['Volume'].cumsum()
88
- addplot.append(mpf.make_addplot(vwap, color='blue', linestyle='dashed'))
89
- if 'Bollinger Bands' in indicators:
90
- logging.debug("Calculating Bollinger Bands")
91
- sma = data['Close'].rolling(window=20).mean()
92
- std = data['Close'].rolling(window=20).std()
93
- upper_band = sma + (std * 2)
94
- lower_band = sma - (std * 2)
95
- addplot.append(mpf.make_addplot(upper_band, color='green', linestyle='dashed'))
96
- addplot.append(mpf.make_addplot(lower_band, color='green', linestyle='dashed'))
97
-
98
- fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
99
- fig.suptitle(title, y=0.98)
100
-
101
- # Save chart image
102
- fig.savefig(filename, dpi=300)
103
- plt.close(fig)
104
-
105
- # Open and add financial data to the image
106
- image = Image.open(filename)
107
- draw = ImageDraw.Draw(image)
108
- font = ImageFont.load_default() # Use default font, you can also use custom fonts if available
109
-
110
- # Financial metrics to add
111
- metrics = {
112
- "Ticker": ticker,
113
- "Latest Close": f"${data['Close'].iloc[-1]:,.2f}",
114
- "Volume": f"{data['Volume'].iloc[-1]:,.0f}"
115
- }
116
-
117
- # Add additional metrics if indicators are present
118
- if 'SMA21' in indicators:
119
- metrics["SMA 21"] = f"${data['Close'].rolling(window=21).mean().iloc[-1]:,.2f}"
120
- if 'SMA50' in indicators:
121
- metrics["SMA 50"] = f"${data['Close'].rolling(window=50).mean().iloc[-1]:,.2f}"
122
- if 'SMA200' in indicators:
123
- metrics["SMA 200"] = f"${data['Close'].rolling(window=200).mean().iloc[-1]:,.2f}"
124
-
125
- # Draw metrics on the image
126
- y_text = image.height - 50 # Starting y position for text
127
- for key, value in metrics.items():
128
- text = f"{key}: {value}"
129
- draw.text((10, y_text), text, font=font, fill=(255, 255, 255)) # White color text
130
- y_text += 20
131
-
132
- # Resize image
133
- new_size = (image.width * 3, image.height * 3)
134
- resized_image = image.resize(new_size, Image.LANCZOS)
135
- resized_image.save(filename)
136
-
137
- logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
138
- except Exception as e:
139
- logging.error(f"Error creating or resizing chart: {e}")
140
- raise
141
 
142
  def combine_images(image_paths, output_path='combined_chart.png'):
143
- try:
144
- logging.debug(f"Combining images {image_paths} into {output_path}")
145
- images = [Image.open(path) for path in image_paths]
146
-
147
- # Calculate total width and max height for combined image
148
- total_width = sum(img.width for img in images)
149
- max_height = max(img.height for img in images)
150
-
151
- combined_image = Image.new('RGB', (total_width, max_height))
152
- x_offset = 0
153
- for img in images:
154
- combined_image.paste(img, (x_offset, 0))
155
- x_offset += img.width
156
-
157
- combined_image.save(output_path)
158
- logging.debug(f"Combined image saved to {output_path}")
159
- return output_path
160
- except Exception as e:
161
- logging.error(f"Error combining images: {e}")
162
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
165
  try:
@@ -167,12 +118,14 @@ def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, q
167
 
168
  tickers = [ticker1, ticker2, ticker3, ticker4]
169
  chart_paths = []
 
170
 
171
  for i, ticker in enumerate(tickers):
172
  if ticker:
173
  data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
174
  if data.empty:
175
  return f"No data available for {ticker} in the specified date range.", None
 
176
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
177
  chart_path = temp_chart.name
178
  create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
@@ -184,13 +137,26 @@ def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, q
184
  combine_images(chart_paths, combined_chart_path)
185
  insights = predict(Image.open(combined_chart_path), query)
186
  return insights, combined_chart_path
187
-
188
- # No comparative analysis, just return the single chart
189
- if chart_paths:
190
- insights = predict(Image.open(chart_paths[0]), query)
191
- return insights, chart_paths[0]
 
 
 
 
 
 
 
 
192
  else:
193
- return "No tickers provided.", None
 
 
 
 
 
194
  except Exception as e:
195
  logging.error(f"Error in Gradio interface: {e}")
196
  return f"Error processing image or query: {e}", None
@@ -198,42 +164,46 @@ def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, q
198
  def gradio_app():
199
  with gr.Blocks() as demo:
200
  gr.Markdown("""
201
- ## 📈Stock Analysis Dashboard
202
 
203
  This application provides a comprehensive stock analysis tool that allows users to input up to four stock tickers, specify date ranges, and select various financial indicators. The core functionalities include:
204
 
205
  1. **Data Fetching and Chart Creation**: Historical stock data is fetched from Yahoo Finance, and candlestick charts are generated with optional financial indicators like RSI, SMA, VWAP, and Bollinger Bands.
206
 
207
- 2. **Text Analysis and Insights Generation**: The application uses a pre-trained model based on the **Paligema** architecture to analyze the input chart and text query, generating insightful analysis based on the provided financial data and context.
 
 
 
 
208
 
209
- 3. **User Interface**: Users can interactively select stocks, date ranges, intervals, and indicators. The app also supports the analysis of single tickers or comparative analysis across multiple tickers.
210
 
211
- 4. **Logging and Debugging**: Detailed logging helps in debugging and tracking the application's processes.
212
 
213
- 5. **Enhanced Image Processing**: The app adds financial metrics and annotations to the generated charts, ensuring clear presentation of data.
214
 
215
- This tool leverages the Paligema model to provide detailed insights into stock market trends, offering an interactive and educational experience for users.
216
  """)
217
 
218
  with gr.Row():
219
- ticker1 = gr.Textbox(label="Primary Ticker", value="GC=F")
220
- ticker2 = gr.Textbox(label="Secondary Ticker", value="CL=F")
221
- ticker3 = gr.Textbox(label="Third Ticker", value="SPY")
222
- ticker4 = gr.Textbox(label="Fourth Ticker", value="EURUSD=X")
223
 
224
  with gr.Row():
225
  start_date = gr.Textbox(label="Start Date", value="2022-01-01")
226
  end_date = gr.Textbox(label="End Date", value=datetime.date.today().isoformat())
227
- interval = gr.Dropdown(label="Interval", choices=['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '1d', '5d', '1wk', '1mo', '3mo'], value='1d')
228
 
229
  with gr.Row():
230
  indicators = gr.CheckboxGroup(label="Indicators", choices=['RSI', 'SMA21', 'SMA50', 'SMA200', 'VWAP', 'Bollinger Bands'], value=['SMA21', 'SMA50'])
231
- analysis_type = gr.Radio(label="Analysis Type", choices=['Single Ticker', 'Comparative Analysis'], value='Single Ticker')
232
 
233
  query = gr.Textbox(label="Analysis Query", value="Analyze the price trends.")
234
  analyze_button = gr.Button("Analyze")
235
- output_image = gr.Image(label="Stock Chart")
236
- output_text = gr.Textbox(label="Generated Insights", lines=5)
237
 
238
  analyze_button.click(
239
  fn=gradio_interface,
 
11
  import os
12
  import spaces
13
  import pandas as pd
14
+ import numpy as np
15
+ from scipy import stats
16
+ import seaborn as sns
17
 
18
  # Configure logging
19
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
46
  data = stock.history(start=start, end=end, interval=interval)
47
  if data.empty:
48
  logging.warning(f"No data fetched for {ticker} in the range {start} to {end}")
49
+ return pd.DataFrame()
50
  logging.debug(f"Fetched data with {len(data)} rows")
51
  return data
52
  except Exception as e:
53
  logging.error(f"Error fetching data: {e}")
54
+ return pd.DataFrame()
55
 
56
  def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indicators=None):
57
+ # ... (keep the existing create_stock_chart function as is) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def combine_images(image_paths, output_path='combined_chart.png'):
60
+ # ... (keep the existing combine_images function as is) ...
61
+
62
+ def perform_trend_analysis(data):
63
+ # Perform trend analysis
64
+ close_prices = data['Close']
65
+ time_index = np.arange(len(close_prices))
66
+ slope, intercept, r_value, p_value, std_err = stats.linregress(time_index, close_prices)
67
+
68
+ trend_line = slope * time_index + intercept
69
+
70
+ plt.figure(figsize=(12, 6))
71
+ plt.plot(close_prices.index, close_prices, label='Close Price')
72
+ plt.plot(close_prices.index, trend_line, color='red', label='Trend Line')
73
+ plt.title('Trend Analysis')
74
+ plt.xlabel('Date')
75
+ plt.ylabel('Price')
76
+ plt.legend()
77
+
78
+ trend_filename = 'trend_analysis.png'
79
+ plt.savefig(trend_filename)
80
+ plt.close()
81
+
82
+ trend_strength = abs(r_value)
83
+ trend_direction = "upward" if slope > 0 else "downward"
84
+
85
+ analysis_text = f"Trend Analysis:\n"
86
+ analysis_text += f"The stock shows a {trend_direction} trend.\n"
87
+ analysis_text += f"Trend strength (R-squared): {trend_strength:.2f}\n"
88
+ analysis_text += f"Slope: {slope:.4f}\n"
89
+
90
+ return analysis_text, trend_filename
91
+
92
+ def perform_correlation_analysis(data_dict):
93
+ # Perform correlation analysis
94
+ combined_data = pd.DataFrame({ticker: data['Close'] for ticker, data in data_dict.items()})
95
+ correlation_matrix = combined_data.corr()
96
+
97
+ plt.figure(figsize=(10, 8))
98
+ sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
99
+ plt.title('Correlation Analysis')
100
+
101
+ corr_filename = 'correlation_analysis.png'
102
+ plt.savefig(corr_filename)
103
+ plt.close()
104
+
105
+ analysis_text = f"Correlation Analysis:\n"
106
+ for i in range(len(correlation_matrix.columns)):
107
+ for j in range(i+1, len(correlation_matrix.columns)):
108
+ ticker1 = correlation_matrix.columns[i]
109
+ ticker2 = correlation_matrix.columns[j]
110
+ corr = correlation_matrix.iloc[i, j]
111
+ analysis_text += f"Correlation between {ticker1} and {ticker2}: {corr:.2f}\n"
112
+
113
+ return analysis_text, corr_filename
114
 
115
  def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
116
  try:
 
118
 
119
  tickers = [ticker1, ticker2, ticker3, ticker4]
120
  chart_paths = []
121
+ data_dict = {}
122
 
123
  for i, ticker in enumerate(tickers):
124
  if ticker:
125
  data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
126
  if data.empty:
127
  return f"No data available for {ticker} in the specified date range.", None
128
+ data_dict[ticker] = data
129
  with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
130
  chart_path = temp_chart.name
131
  create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
 
137
  combine_images(chart_paths, combined_chart_path)
138
  insights = predict(Image.open(combined_chart_path), query)
139
  return insights, combined_chart_path
140
+ elif analysis_type == 'Trend Analysis':
141
+ if len(data_dict) > 0:
142
+ first_ticker = list(data_dict.keys())[0]
143
+ analysis_text, trend_chart = perform_trend_analysis(data_dict[first_ticker])
144
+ return analysis_text, trend_chart
145
+ else:
146
+ return "No data available for trend analysis.", None
147
+ elif analysis_type == 'Correlation Analysis':
148
+ if len(data_dict) > 1:
149
+ analysis_text, corr_chart = perform_correlation_analysis(data_dict)
150
+ return analysis_text, corr_chart
151
+ else:
152
+ return "At least two tickers are required for correlation analysis.", None
153
  else:
154
+ # Single ticker analysis
155
+ if chart_paths:
156
+ insights = predict(Image.open(chart_paths[0]), query)
157
+ return insights, chart_paths[0]
158
+ else:
159
+ return "No tickers provided.", None
160
  except Exception as e:
161
  logging.error(f"Error in Gradio interface: {e}")
162
  return f"Error processing image or query: {e}", None
 
164
  def gradio_app():
165
  with gr.Blocks() as demo:
166
  gr.Markdown("""
167
+ ## 📈 Advanced Stock Analysis Dashboard
168
 
169
  This application provides a comprehensive stock analysis tool that allows users to input up to four stock tickers, specify date ranges, and select various financial indicators. The core functionalities include:
170
 
171
  1. **Data Fetching and Chart Creation**: Historical stock data is fetched from Yahoo Finance, and candlestick charts are generated with optional financial indicators like RSI, SMA, VWAP, and Bollinger Bands.
172
 
173
+ 2. **Text Analysis and Insights Generation**: The application uses a pre-trained model based on the Paligema architecture to analyze the input chart and text query, generating insightful analysis based on the provided financial data and context.
174
+
175
+ 3. **Trend Analysis**: Performs trend analysis on a single stock, showing the trend line and providing information about the trend strength and direction.
176
+
177
+ 4. **Correlation Analysis**: Analyzes the correlation between multiple stocks, providing a correlation matrix and heatmap.
178
 
179
+ 5. **User Interface**: Users can interactively select stocks, date ranges, intervals, and indicators. The app supports single ticker analysis, comparative analysis, trend analysis, and correlation analysis.
180
 
181
+ 6. **Logging and Debugging**: Detailed logging helps in debugging and tracking the application's processes.
182
 
183
+ 7. **Enhanced Image Processing**: The app adds financial metrics and annotations to the generated charts, ensuring clear presentation of data.
184
 
185
+ This tool leverages various analysis techniques to provide detailed insights into stock market trends, offering an interactive and educational experience for users.
186
  """)
187
 
188
  with gr.Row():
189
+ ticker1 = gr.Textbox(label="Primary Ticker", value="AAPL")
190
+ ticker2 = gr.Textbox(label="Secondary Ticker", value="MSFT")
191
+ ticker3 = gr.Textbox(label="Third Ticker", value="GOOGL")
192
+ ticker4 = gr.Textbox(label="Fourth Ticker", value="AMZN")
193
 
194
  with gr.Row():
195
  start_date = gr.Textbox(label="Start Date", value="2022-01-01")
196
  end_date = gr.Textbox(label="End Date", value=datetime.date.today().isoformat())
197
+ interval = gr.Dropdown(label="Interval", choices=['1d', '5d', '1wk', '1mo', '3mo'], value='1d')
198
 
199
  with gr.Row():
200
  indicators = gr.CheckboxGroup(label="Indicators", choices=['RSI', 'SMA21', 'SMA50', 'SMA200', 'VWAP', 'Bollinger Bands'], value=['SMA21', 'SMA50'])
201
+ analysis_type = gr.Radio(label="Analysis Type", choices=['Single Ticker', 'Comparative Analysis', 'Trend Analysis', 'Correlation Analysis'], value='Single Ticker')
202
 
203
  query = gr.Textbox(label="Analysis Query", value="Analyze the price trends.")
204
  analyze_button = gr.Button("Analyze")
205
+ output_image = gr.Image(label="Analysis Chart")
206
+ output_text = gr.Textbox(label="Generated Insights", lines=10)
207
 
208
  analyze_button.click(
209
  fn=gradio_interface,