mobenta commited on
Commit
c1ea847
·
verified ·
1 Parent(s): aeaf850

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +354 -3
app.py CHANGED
@@ -1,9 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def gradio_app():
2
  with gr.Blocks() as demo:
3
  gr.Markdown("""
4
  ## 📈 Advanced Stock Analysis Dashboard
5
-
6
- [... rest of the markdown ...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """)
8
 
9
  ticker_options = [
@@ -152,7 +491,19 @@ def gradio_app():
152
  ticker3 = gr.Dropdown(label="Third Ticker", choices=ticker_options, value="GOOGL: Alphabet Inc.")
153
  ticker4 = gr.Dropdown(label="Fourth Ticker", choices=ticker_options, value="AMZN: Amazon.com, Inc.")
154
 
155
- # ... rest of the code ...
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  analyze_button.click(
158
  fn=gradio_interface,
 
1
+
2
+ import torch
3
+ import yfinance as yf
4
+ import matplotlib.pyplot as plt
5
+ import mplfinance as mpf
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import gradio as gr
8
+ import datetime
9
+ import logging
10
+ from transformers import AutoProcessor, AutoModelForPreTraining
11
+ import tempfile
12
+ import os
13
+ import spaces
14
+ import pandas as pd
15
+ import numpy as np
16
+ from scipy import stats
17
+ import seaborn as sns
18
+
19
+ # Configure logging
20
+ logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
21
+
22
+ # Load the chart_analysis model and processor
23
+ processor = AutoProcessor.from_pretrained("mobenta/chart_analysis")
24
+ model = AutoModelForPreTraining.from_pretrained("mobenta/chart_analysis")
25
+
26
+ @spaces.GPU
27
+ def predict(image, input_text):
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ model.to(device)
30
+
31
+ image = image.convert("RGB")
32
+ inputs = processor(text=input_text, images=image, return_tensors="pt")
33
+ inputs = {k: v.to(device) for k, v in inputs.items()}
34
+
35
+ prompt_length = inputs['input_ids'].shape[1]
36
+ generate_ids = model.generate(**inputs, max_new_tokens=512)
37
+ output_text = processor.batch_decode(generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
38
+
39
+ return output_text
40
+
41
+ def fetch_stock_data(ticker='TSLA', start='2010-01-01', end=None, interval='1d'):
42
+ if end is None:
43
+ end = datetime.date.today().isoformat()
44
+ try:
45
+ logging.debug(f"Fetching data for {ticker} from {start} to {end} with interval {interval}")
46
+ stock = yf.Ticker(ticker)
47
+ data = stock.history(start=start, end=end, interval=interval)
48
+ if data.empty:
49
+ logging.warning(f"No data fetched for {ticker} in the range {start} to {end}")
50
+ return pd.DataFrame()
51
+ logging.debug(f"Fetched data with {len(data)} rows")
52
+ return data
53
+ except Exception as e:
54
+ logging.error(f"Error fetching data: {e}")
55
+ return pd.DataFrame()
56
+
57
+ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indicators=None):
58
+ try:
59
+ logging.debug(f"Creating chart for {ticker} with timeframe {timeframe} and saving to {filename}")
60
+ title = f"{ticker.upper()} Price Data (Timeframe: {timeframe})"
61
+
62
+ plt.rcParams["axes.titlesize"] = 10
63
+ my_style = mpf.make_mpf_style(base_mpf_style='charles')
64
+
65
+ # Calculate indicators if selected
66
+ addplot = []
67
+ if indicators:
68
+ if 'RSI' in indicators:
69
+ delta = data['Close'].diff(1)
70
+ gain = delta.where(delta > 0, 0)
71
+ loss = -delta.where(delta < 0, 0)
72
+ avg_gain = gain.rolling(window=14).mean()
73
+ avg_loss = loss.rolling(window=14).mean()
74
+ rs = avg_gain / avg_loss
75
+ rsi = 100 - (100 / (1 + rs))
76
+ addplot.append(mpf.make_addplot(rsi, panel=2, color='orange', ylabel='RSI'))
77
+ if 'SMA21' in indicators:
78
+ logging.debug("Calculating SMA 21")
79
+ sma_21 = data['Close'].rolling(window=21).mean()
80
+ addplot.append(mpf.make_addplot(sma_21, color='purple', linestyle='dashed'))
81
+ if 'SMA50' in indicators:
82
+ logging.debug("Calculating SMA 50")
83
+ sma_50 = data['Close'].rolling(window=50).mean()
84
+ addplot.append(mpf.make_addplot(sma_50, color='orange', linestyle='dashed'))
85
+ if 'SMA200' in indicators:
86
+ logging.debug("Calculating SMA 200")
87
+ sma_200 = data['Close'].rolling(window=200).mean()
88
+ addplot.append(mpf.make_addplot(sma_200, color='brown', linestyle='dashed'))
89
+ if 'VWAP' in indicators:
90
+ logging.debug("Calculating VWAP")
91
+ vwap = (data['Volume'] * (data['High'] + data['Low'] + data['Close']) / 3).cumsum() / data['Volume'].cumsum()
92
+ addplot.append(mpf.make_addplot(vwap, color='blue', linestyle='dashed'))
93
+ if 'Bollinger Bands' in indicators:
94
+ logging.debug("Calculating Bollinger Bands")
95
+ sma = data['Close'].rolling(window=20).mean()
96
+ std = data['Close'].rolling(window=20).std()
97
+ upper_band = sma + (std * 2)
98
+ lower_band = sma - (std * 2)
99
+ addplot.append(mpf.make_addplot(upper_band, color='green', linestyle='dashed'))
100
+ addplot.append(mpf.make_addplot(lower_band, color='green', linestyle='dashed'))
101
+
102
+ fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
103
+ fig.suptitle(title, y=0.98)
104
+
105
+ # Save chart image
106
+ fig.savefig(filename, dpi=300)
107
+ plt.close(fig)
108
+
109
+ # Open and add financial data to the image
110
+ image = Image.open(filename)
111
+ draw = ImageDraw.Draw(image)
112
+ font = ImageFont.load_default() # Use default font, you can also use custom fonts if available
113
+
114
+ # Financial metrics to add
115
+ metrics = {
116
+ "Ticker": ticker,
117
+ "Latest Close": f"${data['Close'].iloc[-1]:,.2f}",
118
+ "Volume": f"{data['Volume'].iloc[-1]:,.0f}"
119
+ }
120
+
121
+ # Add additional metrics if indicators are present
122
+ if 'SMA21' in indicators:
123
+ metrics["SMA 21"] = f"${data['Close'].rolling(window=21).mean().iloc[-1]:,.2f}"
124
+ if 'SMA50' in indicators:
125
+ metrics["SMA 50"] = f"${data['Close'].rolling(window=50).mean().iloc[-1]:,.2f}"
126
+ if 'SMA200' in indicators:
127
+ metrics["SMA 200"] = f"${data['Close'].rolling(window=200).mean().iloc[-1]:,.2f}"
128
+
129
+ # Draw metrics on the image
130
+ y_text = image.height - 50 # Starting y position for text
131
+ for key, value in metrics.items():
132
+ text = f"{key}: {value}"
133
+ draw.text((10, y_text), text, font=font, fill=(255, 255, 255)) # White color text
134
+ y_text += 20
135
+
136
+ # Resize image
137
+ new_size = (image.width * 3, image.height * 3)
138
+ resized_image = image.resize(new_size, Image.LANCZOS)
139
+ resized_image.save(filename)
140
+
141
+ logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
142
+ except Exception as e:
143
+ logging.error(f"Error creating or resizing chart: {e}")
144
+ raise
145
+
146
+ def combine_images(image_paths, output_path='combined_chart.png'):
147
+ try:
148
+ logging.debug(f"Combining images {image_paths} into {output_path}")
149
+ images = [Image.open(path) for path in image_paths]
150
+
151
+ # Calculate total width and max height for combined image
152
+ total_width = sum(img.width for img in images)
153
+ max_height = max(img.height for img in images)
154
+
155
+ combined_image = Image.new('RGB', (total_width, max_height))
156
+ x_offset = 0
157
+ for img in images:
158
+ combined_image.paste(img, (x_offset, 0))
159
+ x_offset += img.width
160
+
161
+ combined_image.save(output_path)
162
+ logging.debug(f"Combined image saved to {output_path}")
163
+ return output_path
164
+ except Exception as e:
165
+ logging.error(f"Error combining images: {e}")
166
+ raise
167
+
168
+ def perform_trend_analysis(data):
169
+ # Perform trend analysis
170
+ close_prices = data['Close']
171
+ time_index = np.arange(len(close_prices))
172
+ slope, intercept, r_value, p_value, std_err = stats.linregress(time_index, close_prices)
173
+
174
+ trend_line = slope * time_index + intercept
175
+
176
+ plt.figure(figsize=(12, 6))
177
+ plt.plot(close_prices.index, close_prices, label='Close Price')
178
+ plt.plot(close_prices.index, trend_line, color='red', label='Trend Line')
179
+ plt.title('Trend Analysis')
180
+ plt.xlabel('Date')
181
+ plt.ylabel('Price')
182
+ plt.legend()
183
+
184
+ trend_filename = 'trend_analysis.png'
185
+ plt.savefig(trend_filename)
186
+ plt.close()
187
+
188
+ trend_strength = abs(r_value)
189
+ trend_direction = "upward" if slope > 0 else "downward"
190
+
191
+ analysis_text = f"Trend Analysis:\n"
192
+ analysis_text += f"The stock shows a {trend_direction} trend.\n"
193
+ analysis_text += f"Trend strength (R-squared): {trend_strength:.2f}\n"
194
+ analysis_text += f"Slope: {slope:.4f}\n"
195
+
196
+ return analysis_text, trend_filename
197
+
198
+ def perform_correlation_analysis(data_dict):
199
+ # Perform correlation analysis
200
+ combined_data = pd.DataFrame({ticker: data['Close'] for ticker, data in data_dict.items()})
201
+ correlation_matrix = combined_data.corr()
202
+
203
+ plt.figure(figsize=(10, 8))
204
+ sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', vmin=-1, vmax=1, center=0)
205
+ plt.title('Correlation Analysis')
206
+
207
+ corr_filename = 'correlation_analysis.png'
208
+ plt.savefig(corr_filename)
209
+ plt.close()
210
+
211
+ analysis_text = f"Correlation Analysis:\n"
212
+ for i in range(len(correlation_matrix.columns)):
213
+ for j in range(i+1, len(correlation_matrix.columns)):
214
+ ticker1 = correlation_matrix.columns[i]
215
+ ticker2 = correlation_matrix.columns[j]
216
+ corr = correlation_matrix.iloc[i, j]
217
+ analysis_text += f"Correlation between {ticker1} and {ticker2}: {corr:.2f}\n"
218
+
219
+ return analysis_text, corr_filename
220
+
221
+ def comprehensive_investment_strategy():
222
+ strategy = """
223
+ Comprehensive Investment Strategy Analysis:
224
+
225
+ 1. Fundamental Analysis:
226
+ - Assess financial health using key metrics like P/E ratio, EPS growth, debt-to-equity ratio, and free cash flow.
227
+ - Analyze quarterly and annual earnings reports, focusing on revenue growth, profit margins, and management guidance.
228
+ - Consider industry trends such as technological disruption, regulatory changes, and shifting consumer preferences.
229
+
230
+ 2. Technical Analysis:
231
+ - Utilize chart patterns like head and shoulders, double tops/bottoms, and cup and handle.
232
+ - Employ technical indicators including Moving Averages, RSI, MACD, and Bollinger Bands.
233
+ - Incorporate volume analysis to confirm trend strength and potential reversals.
234
+
235
+ 3. Macroeconomic Analysis:
236
+ - Monitor key economic indicators: GDP growth, inflation rates, unemployment figures, and consumer sentiment indices.
237
+ - Track central bank policies, particularly interest rate decisions and quantitative easing programs.
238
+ - Evaluate geopolitical events' impact through news analysis and global market correlations.
239
+
240
+ 4. Risk Management:
241
+ - Implement diversification across sectors, geographies, and asset classes.
242
+ - Use position sizing based on account size and individual stock volatility.
243
+ - Set stop-loss orders at key technical levels, typically 5-15% below purchase price depending on stock volatility.
244
+
245
+ 5. Sentiment Analysis:
246
+ - Gauge market sentiment through tools like the VIX, put/call ratio, and investor surveys.
247
+ - Monitor social media trends and financial news sentiment using natural language processing tools.
248
+ - Apply contrarian strategies when extreme bullish or bearish sentiment is detected, supported by fundamental and technical analysis.
249
+
250
+ 6. Options Trading:
251
+ - Employ covered calls for income generation on long-term holds.
252
+ - Use protective puts to hedge downside risk on larger positions.
253
+ - Implement iron condors or credit spreads to take advantage of high implied volatility environments.
254
+
255
+ 7. Long-Term Investing:
256
+ - Identify companies with strong competitive advantages, consistent revenue growth, and solid balance sheets.
257
+ - Focus on businesses with high return on invested capital (ROIC) and effective management teams.
258
+ - Include dividend aristocrats and growth stocks at a reasonable price (GARP) for a balanced approach.
259
+
260
+ 8. Market Psychology:
261
+ - Apply principles of behavioral finance, recognizing common biases like herd mentality and recency bias.
262
+ - Maintain a trading journal to track decisions and emotions, promoting self-awareness and improvement.
263
+ - Develop and stick to a rules-based system to minimize emotional decision-making.
264
+
265
+ This comprehensive strategy aims to balance various analytical approaches, providing a robust framework for investment decisions across different market conditions.
266
+ """
267
+ return strategy
268
+
269
+ def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
270
+ try:
271
+ logging.debug(f"Starting gradio_interface with tickers: {ticker1}, {ticker2}, {ticker3}, {ticker4}, start_date: {start_date}, end_date: {end_date}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
272
+
273
+ if analysis_type == 'Comprehensive Investment Strategy':
274
+ return comprehensive_investment_strategy(), None
275
+
276
+ tickers = [ticker.split(':')[0].strip() for ticker in [ticker1, ticker2, ticker3, ticker4] if ticker]
277
+ chart_paths = []
278
+ data_dict = {}
279
+
280
+ for i, ticker in enumerate(tickers):
281
+ if ticker:
282
+ data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
283
+ if data.empty:
284
+ return f"No data available for {ticker} in the specified date range.", None
285
+ data_dict[ticker] = data
286
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
287
+ chart_path = temp_chart.name
288
+ create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
289
+ chart_paths.append(chart_path)
290
+
291
+ if analysis_type == 'Comparative Analysis' and len(chart_paths) > 1:
292
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_combined_chart:
293
+ combined_chart_path = temp_combined_chart.name
294
+ combine_images(chart_paths, combined_chart_path)
295
+ insights = predict(Image.open(combined_chart_path), query)
296
+ return insights, combined_chart_path
297
+ elif analysis_type == 'Trend Analysis':
298
+ if len(data_dict) > 0:
299
+ first_ticker = list(data_dict.keys())[0]
300
+ analysis_text, trend_chart = perform_trend_analysis(data_dict[first_ticker])
301
+ return analysis_text, trend_chart
302
+ else:
303
+ return "No data available for trend analysis.", None
304
+ elif analysis_type == 'Correlation Analysis':
305
+ if len(data_dict) > 1:
306
+ analysis_text, corr_chart = perform_correlation_analysis(data_dict)
307
+ return analysis_text, corr_chart
308
+ else:
309
+ return "At least two tickers are required for correlation analysis.", None
310
+ else:
311
+
312
+ # Single ticker analysis
313
+ if chart_paths:
314
+ insights = predict(Image.open(chart_paths[0]), query)
315
+ return insights, chart_paths[0]
316
+ else:
317
+ return "No tickers provided.", None
318
+ except Exception as e:
319
+ logging.error(f"Error in Gradio interface: {e}")
320
+ return f"Error processing image or query: {e}", None
321
+
322
  def gradio_app():
323
  with gr.Blocks() as demo:
324
  gr.Markdown("""
325
  ## 📈 Advanced Stock Analysis Dashboard
326
+
327
+ This application provides a comprehensive stock analysis tool that allows users to input up to four stock tickers, specify date ranges, and select various financial indicators. The core functionalities include:
328
+
329
+ 1. **Data Fetching and Chart Creation**: Historical stock data is fetched from Yahoo Finance, and candlestick charts are generated with optional financial indicators like RSI, SMA, VWAP, and Bollinger Bands.
330
+
331
+ 2. **Text Analysis and Insights Generation**: The application uses a pre-trained model based on the Paligema architecture to analyze the input chart and text query, generating insightful analysis based on the provided financial data and context.
332
+
333
+ 3. **Trend Analysis**: Performs trend analysis on a single stock, showing the trend line and providing information about the trend strength and direction.
334
+
335
+ 4. **Correlation Analysis**: Analyzes the correlation between multiple stocks, providing a correlation matrix and heatmap.
336
+
337
+ 5. **Comprehensive Investment Strategy**: Provides a detailed investment strategy based on fundamental analysis, technical analysis, macroeconomic factors, risk management, and more.
338
+
339
+ 6. **User Interface**: Users can interactively select stocks, date ranges, intervals, and indicators. The app supports single ticker analysis, comparative analysis, trend analysis, correlation analysis, and comprehensive investment strategy.
340
+
341
+ 7. **Logging and Debugging**: Detailed logging helps in debugging and tracking the application's processes.
342
+
343
+ 8. **Enhanced Image Processing**: The app adds financial metrics and annotations to the generated charts, ensuring clear presentation of data.
344
+
345
+ This tool leverages various analysis techniques to provide detailed insights into stock market trends, offering an interactive and educational experience for users.
346
  """)
347
 
348
  ticker_options = [
 
491
  ticker3 = gr.Dropdown(label="Third Ticker", choices=ticker_options, value="GOOGL: Alphabet Inc.")
492
  ticker4 = gr.Dropdown(label="Fourth Ticker", choices=ticker_options, value="AMZN: Amazon.com, Inc.")
493
 
494
+ with gr.Row():
495
+ start_date = gr.Textbox(label="Start Date", value="2022-01-01")
496
+ end_date = gr.Textbox(label="End Date", value=datetime.date.today().isoformat())
497
+ interval = gr.Dropdown(label="Interval", choices=['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '1d', '5d', '1wk', '1mo', '3mo'], value='1d')
498
+
499
+ with gr.Row():
500
+ indicators = gr.CheckboxGroup(label="Indicators", choices=['RSI', 'SMA21', 'SMA50', 'SMA200', 'VWAP', 'Bollinger Bands'], value=['SMA21', 'SMA50'])
501
+ analysis_type = gr.Radio(label="Analysis Type", choices=['Single Ticker', 'Comparative Analysis', 'Trend Analysis', 'Correlation Analysis', 'Comprehensive Investment Strategy'], value='Single Ticker')
502
+
503
+ query = gr.Textbox(label="Analysis Query", value="Analyze the price trends.")
504
+ analyze_button = gr.Button("Analyze")
505
+ output_image = gr.Image(label="Analysis Chart")
506
+ output_text = gr.Textbox(label="Generated Insights", lines=10)
507
 
508
  analyze_button.click(
509
  fn=gradio_interface,