mobenta commited on
Commit
44829cc
·
verified ·
1 Parent(s): 7496630

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -176
app.py CHANGED
@@ -2,13 +2,13 @@ import torch
2
  import yfinance as yf
3
  import matplotlib.pyplot as plt
4
  import mplfinance as mpf
5
- from PIL import Image
6
  import gradio as gr
7
  import datetime
8
  import logging
9
  from transformers import AutoProcessor, AutoModelForPreTraining
10
- import spaces
11
- import pandas as pd
12
 
13
  # Configure logging
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -95,30 +95,62 @@ def create_stock_chart(data, ticker, filename='chart.png', timeframe='1d', indic
95
 
96
  fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
97
  fig.suptitle(title, y=0.98)
 
 
98
  fig.savefig(filename, dpi=300)
99
  plt.close(fig)
100
 
 
101
  image = Image.open(filename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  new_size = (image.width * 3, image.height * 3)
103
  resized_image = image.resize(new_size, Image.LANCZOS)
104
  resized_image.save(filename)
 
105
  logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
106
  except Exception as e:
107
  logging.error(f"Error creating or resizing chart: {e}")
108
  raise
109
 
110
- def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
111
  try:
112
- logging.debug(f"Combining images {image1_path} and {image2_path} into {output_path}")
113
- image1 = Image.open(image1_path)
114
- image2 = Image.open(image2_path)
115
-
116
- total_width = image1.width + image2.width
117
- max_height = max(image1.height, image2.height)
118
 
119
  combined_image = Image.new('RGB', (total_width, max_height))
120
- combined_image.paste(image1, (0, 0))
121
- combined_image.paste(image2, (image1.width, 0))
 
 
122
 
123
  combined_image.save(output_path)
124
  logging.debug(f"Combined image saved to {output_path}")
@@ -127,174 +159,69 @@ def combine_images(image1_path, image2_path, output_path='combined_chart.png'):
127
  logging.error(f"Error combining images: {e}")
128
  raise
129
 
130
- def gradio_interface(ticker1, start_date, end_date, ticker2, query, analysis_type, interval, indicators):
131
  try:
132
- logging.debug(f"Starting gradio_interface with ticker1: {ticker1}, start_date: {start_date}, end_date: {end_date}, ticker2: {ticker2}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
133
-
134
- data1 = fetch_stock_data(ticker1, start=start_date, end=end_date, interval=interval)
135
- chart_path1 = '/tmp/chart1.png'
136
- create_stock_chart(data1, ticker1, chart_path1, timeframe=interval, indicators=indicators)
137
-
138
- if analysis_type == 'Comparative Analysis' and ticker2:
139
- data2 = fetch_stock_data(ticker2, start=start_date, end=end_date, interval=interval)
140
- chart_path2 = '/tmp/chart2.png'
141
- create_stock_chart(data2, ticker2, chart_path2, timeframe=interval, indicators=indicators)
142
-
143
- combined_chart_path = combine_images(chart_path1, chart_path2)
144
- insights = predict(Image.open(combined_chart_path), query)
145
- return insights, combined_chart_path
146
-
147
- insights = predict(Image.open(chart_path1), query)
148
- return insights, chart_path1
 
 
 
 
 
 
 
 
 
149
  except Exception as e:
150
- logging.error(f"Error processing image or query: {e}")
151
  return f"Error processing image or query: {e}", None
152
 
153
- def set_query_trend():
154
- return "What are the key trends shown in this chart?"
155
-
156
- def set_query_comparative():
157
- return "How does [First Ticker] compare to [Second Ticker]?"
158
-
159
- def set_query_forecasting():
160
- return "Based on the current data, what are the projected trends?"
161
-
162
- # Default dates
163
- default_start_date = '2010-01-01'
164
- default_end_date = datetime.datetime.now().strftime('%Y-%m-%d')
165
-
166
- with gr.Blocks() as interface:
167
- gr.Markdown(
168
- """
169
- # 📈 Price Market Analysis Tool
170
- Welcome to the Price Market Analysis Tool! This interface helps you generate insightful analyses of market data. Choose between trend analysis, comparative analysis, and forecasting based on your needs.
171
- """
172
- )
173
-
174
- with gr.Row():
175
- ticker1_input = gr.Textbox(
176
- lines=1,
177
- placeholder="Enter first ticker (e.g., TSLA)",
178
- label="First Ticker",
179
- )
180
-
181
- ticker2_input = gr.Textbox(
182
- lines=1,
183
- placeholder="Enter second ticker for comparative analysis (optional)",
184
- label="Second Ticker (Optional)",
185
- )
186
-
187
- with gr.Row():
188
- start_date_input = gr.Textbox(
189
- lines=1,
190
- placeholder="Enter start date (e.g., 2010-01-01)",
191
- value=default_start_date,
192
- label="Start Date",
193
  )
194
 
195
- end_date_input = gr.Textbox(
196
- lines=1,
197
- placeholder=f"Enter end date (default: {default_end_date})",
198
- value=default_end_date,
199
- label="End Date",
200
- )
201
-
202
- with gr.Row():
203
- input_text = gr.Textbox(
204
- lines=3,
205
- placeholder="Enter your input text",
206
- label="Input Text",
207
- )
208
-
209
- analysis_type_input = gr.Textbox(
210
- lines=1,
211
- placeholder="Analysis Type",
212
- label="",
213
- visible=False,
214
- )
215
-
216
- query_input = gr.Textbox(
217
- lines=3,
218
- placeholder="Query",
219
- label="",
220
- visible=False,
221
- )
222
-
223
- with gr.Row():
224
- interval_input = gr.Dropdown(
225
- choices=["1d", "1wk", "1mo"],
226
- value="1d",
227
- label="Select Time Frame",
228
- )
229
-
230
- with gr.Row():
231
- indicator_input = gr.CheckboxGroup(
232
- choices=["VWAP", "Volume", "RSI", "Ichimoku Cloud", "Bollinger Bands", "Pivot Levels", "SMA21", "SMA50", "SMA200"],
233
- label="Select Indicators",
234
- )
235
-
236
- with gr.Row():
237
- trend_button = gr.Button("Trend Analysis")
238
- comparative_button = gr.Button("Comparative Analysis")
239
- forecasting_button = gr.Button("Forecasting")
240
- submit_button = gr.Button("Submit")
241
- clear_button = gr.Button("Clear")
242
-
243
- output_text = gr.Textbox(lines=5, label="Generated Insights")
244
- output_image = gr.Image(type="filepath", label="Price Chart")
245
-
246
- trend_button.click(
247
- fn=lambda: "Trend Analysis",
248
- inputs=[],
249
- outputs=[analysis_type_input],
250
- ).then(
251
- set_query_trend,
252
- inputs=[],
253
- outputs=[query_input],
254
- ).then(
255
- gradio_interface,
256
- inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
257
- outputs=[output_text, output_image],
258
- )
259
-
260
- comparative_button.click(
261
- fn=lambda: "Comparative Analysis",
262
- inputs=[],
263
- outputs=[analysis_type_input],
264
- ).then(
265
- set_query_comparative,
266
- inputs=[],
267
- outputs=[query_input],
268
- ).then(
269
- gradio_interface,
270
- inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
271
- outputs=[output_text, output_image],
272
- )
273
-
274
- forecasting_button.click(
275
- fn=lambda: "Forecasting",
276
- inputs=[],
277
- outputs=[analysis_type_input],
278
- ).then(
279
- set_query_forecasting,
280
- inputs=[],
281
- outputs=[query_input],
282
- ).then(
283
- gradio_interface,
284
- inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, query_input, analysis_type_input, interval_input, indicator_input],
285
- outputs=[output_text, output_image],
286
- )
287
-
288
- submit_button.click(
289
- gradio_interface,
290
- inputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
291
- outputs=[output_text, output_image],
292
- )
293
-
294
- clear_button.click(
295
- fn=lambda: ("", "", "", "", "", "", "", []),
296
- inputs=[],
297
- outputs=[ticker1_input, start_date_input, end_date_input, ticker2_input, input_text, analysis_type_input, interval_input, indicator_input],
298
- )
299
 
300
- interface.launch(debug=True)
 
 
2
  import yfinance as yf
3
  import matplotlib.pyplot as plt
4
  import mplfinance as mpf
5
+ from PIL import Image, ImageDraw, ImageFont
6
  import gradio as gr
7
  import datetime
8
  import logging
9
  from transformers import AutoProcessor, AutoModelForPreTraining
10
+ import tempfile
11
+ import os
12
 
13
  # Configure logging
14
  logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
95
 
96
  fig, axlist = mpf.plot(data, type='candle', style=my_style, volume=True, addplot=addplot, returnfig=True)
97
  fig.suptitle(title, y=0.98)
98
+
99
+ # Save chart image
100
  fig.savefig(filename, dpi=300)
101
  plt.close(fig)
102
 
103
+ # Open and add financial data to the image
104
  image = Image.open(filename)
105
+ draw = ImageDraw.Draw(image)
106
+ font = ImageFont.load_default() # Use default font, you can also use custom fonts if available
107
+
108
+ # Financial metrics to add
109
+ metrics = {
110
+ "Ticker": ticker,
111
+ "Latest Close": f"${data['Close'].iloc[-1]:,.2f}",
112
+ "Volume": f"{data['Volume'].iloc[-1]:,.0f}"
113
+ }
114
+
115
+ # Add additional metrics if indicators are present
116
+ if 'SMA21' in indicators:
117
+ metrics["SMA 21"] = f"${data['Close'].rolling(window=21).mean().iloc[-1]:,.2f}"
118
+ if 'SMA50' in indicators:
119
+ metrics["SMA 50"] = f"${data['Close'].rolling(window=50).mean().iloc[-1]:,.2f}"
120
+ if 'SMA200' in indicators:
121
+ metrics["SMA 200"] = f"${data['Close'].rolling(window=200).mean().iloc[-1]:,.2f}"
122
+
123
+ # Draw metrics on the image
124
+ y_text = image.height - 50 # Starting y position for text
125
+ for key, value in metrics.items():
126
+ text = f"{key}: {value}"
127
+ draw.text((10, y_text), text, font=font, fill=(255, 255, 255)) # White color text
128
+ y_text += 20
129
+
130
+ # Resize image
131
  new_size = (image.width * 3, image.height * 3)
132
  resized_image = image.resize(new_size, Image.LANCZOS)
133
  resized_image.save(filename)
134
+
135
  logging.debug(f"Resized image with timeframe {timeframe} and ticker {ticker} saved to {filename}")
136
  except Exception as e:
137
  logging.error(f"Error creating or resizing chart: {e}")
138
  raise
139
 
140
+ def combine_images(image_paths, output_path='combined_chart.png'):
141
  try:
142
+ logging.debug(f"Combining images {image_paths} into {output_path}")
143
+ images = [Image.open(path) for path in image_paths]
144
+
145
+ # Calculate total width and max height for combined image
146
+ total_width = sum(img.width for img in images)
147
+ max_height = max(img.height for img in images)
148
 
149
  combined_image = Image.new('RGB', (total_width, max_height))
150
+ x_offset = 0
151
+ for img in images:
152
+ combined_image.paste(img, (x_offset, 0))
153
+ x_offset += img.width
154
 
155
  combined_image.save(output_path)
156
  logging.debug(f"Combined image saved to {output_path}")
 
159
  logging.error(f"Error combining images: {e}")
160
  raise
161
 
162
+ def gradio_interface(ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators):
163
  try:
164
+ logging.debug(f"Starting gradio_interface with tickers: {ticker1}, {ticker2}, {ticker3}, {ticker4}, start_date: {start_date}, end_date: {end_date}, query: {query}, analysis_type: {analysis_type}, interval: {interval}")
165
+
166
+ tickers = [ticker1, ticker2, ticker3, ticker4]
167
+ chart_paths = []
168
+
169
+ for i, ticker in enumerate(tickers):
170
+ if ticker:
171
+ data = fetch_stock_data(ticker, start=start_date, end=end_date, interval=interval)
172
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_chart:
173
+ chart_path = temp_chart.name
174
+ create_stock_chart(data, ticker, chart_path, timeframe=interval, indicators=indicators)
175
+ chart_paths.append(chart_path)
176
+
177
+ if analysis_type == 'Comparative Analysis' and len(chart_paths) > 1:
178
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_combined_chart:
179
+ combined_chart_path = temp_combined_chart.name
180
+ combine_images(chart_paths, combined_chart_path)
181
+ insights = predict(Image.open(combined_chart_path), query)
182
+ return insights, combined_chart_path
183
+
184
+ # No comparative analysis, just return the single chart
185
+ if chart_paths:
186
+ insights = predict(Image.open(chart_paths[0]), query)
187
+ return insights, chart_paths[0]
188
+ else:
189
+ return "No tickers provided.", None
190
  except Exception as e:
191
+ logging.error(f"Error in Gradio interface: {e}")
192
  return f"Error processing image or query: {e}", None
193
 
194
+ def gradio_app():
195
+ with gr.Blocks() as demo:
196
+ gr.Markdown("## Stock Analysis Dashboard")
197
+
198
+ with gr.Row():
199
+ ticker1 = gr.Textbox(label="Primary Ticker", value="GC=F")
200
+ ticker2 = gr.Textbox(label="Secondary Ticker", value="CL=F")
201
+ ticker3 = gr.Textbox(label="Third Ticker", value="SPY")
202
+ ticker4 = gr.Textbox(label="Fourth Ticker", value="EURUSD=X")
203
+
204
+ with gr.Row():
205
+ start_date = gr.Textbox(label="Start Date", value="2022-01-01")
206
+ end_date = gr.Textbox(label="End Date", value=datetime.datetime.now().strftime('%Y-%m-%d'))
207
+ interval = gr.Dropdown(label="Interval", choices=['1d', '1wk', '1mo'], value='1d')
208
+
209
+ with gr.Row():
210
+ indicators = gr.CheckboxGroup(label="Indicators", choices=['RSI', 'SMA21', 'SMA50', 'SMA200', 'VWAP', 'Bollinger Bands'], value=['SMA21', 'SMA50'])
211
+ analysis_type = gr.Radio(label="Analysis Type", choices=['Single Ticker', 'Comparative Analysis'], value='Single Ticker')
212
+
213
+ query = gr.Textbox(label="Analysis Query", value="Analyze the price trends.")
214
+ analyze_button = gr.Button("Analyze")
215
+ output_image = gr.Image(label="Stock Chart")
216
+ output_text = gr.Textbox(label="Generated Insights", lines=5)
217
+
218
+ analyze_button.click(
219
+ fn=gradio_interface,
220
+ inputs=[ticker1, ticker2, ticker3, ticker4, start_date, end_date, query, analysis_type, interval, indicators],
221
+ outputs=[output_text, output_image]
 
 
 
 
 
 
 
 
 
 
 
 
222
  )
223
 
224
+ demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ if __name__ == "__main__":
227
+ gradio_app()