gaur3009 commited on
Commit
c36d988
·
verified ·
1 Parent(s): d21edd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -6,7 +6,6 @@ from stable_baselines3 import PPO
6
  from gym import Env
7
  from gym.spaces import Box
8
  import matplotlib.pyplot as plt
9
- import io
10
 
11
  class PortfolioEnv(Env):
12
  def __init__(self, returns, initial_balance=10000):
@@ -46,15 +45,16 @@ class PortfolioEnv(Env):
46
  return self.state
47
 
48
  def fetch_data(tickers, start_date, end_date):
49
- data = yf.download(tickers, start=start_date, end=end_date)['Adj Close']
50
- return data
 
 
 
51
 
52
  def optimize_portfolio(tickers, start_date, end_date, initial_balance):
53
- # Fetch real-time data
54
  data = fetch_data(tickers, start_date, end_date)
55
  returns = data.pct_change().dropna()
56
 
57
- # Define the environment
58
  env = PortfolioEnv(returns, initial_balance=initial_balance)
59
  model = PPO("MlpPolicy", env, verbose=0)
60
  model.learn(total_timesteps=5000)
@@ -76,21 +76,23 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
76
  plt.ylabel("Portfolio Value")
77
  plt.legend()
78
  plt.grid()
79
- plt.savefig("portfolio_chart.png")
 
 
 
80
  plt.close()
81
 
82
- # Prepare the output
83
  weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
84
- return weights, "portfolio_chart.png"
85
 
86
  def run_optimization(tickers, start_date, end_date, initial_balance):
87
  if not tickers or not start_date or not end_date:
88
- return "Error: Please fill all the fields.", None
89
  try:
90
- weights, chart_path = optimize_portfolio(tickers, start_date, end_date, initial_balance)
91
- return weights, chart_path
92
  except Exception as e:
93
- return f"Error: {e}", None
94
 
95
  interface = gr.Interface(
96
  fn=run_optimization,
@@ -105,11 +107,9 @@ interface = gr.Interface(
105
  gr.Image(label="Portfolio Value Chart"),
106
  ],
107
  title="AI-Powered Portfolio Optimization",
108
- description="""
109
- Enter stock tickers (e.g., AAPL, MSFT, TSLA), a date range, and your initial investment amount.
110
- The app fetches real-time historical data, runs AI optimization, and returns the optimized portfolio weights
111
- along with a performance chart.
112
- """
113
  )
114
 
115
  if __name__ == "__main__":
 
6
  from gym import Env
7
  from gym.spaces import Box
8
  import matplotlib.pyplot as plt
 
9
 
10
  class PortfolioEnv(Env):
11
  def __init__(self, returns, initial_balance=10000):
 
45
  return self.state
46
 
47
  def fetch_data(tickers, start_date, end_date):
48
+ try:
49
+ data = yf.download(tickers, start=start_date, end=end_date)['Adj Close']
50
+ return data
51
+ except Exception as e:
52
+ raise ValueError(f"Failed to fetch data: {e}")
53
 
54
  def optimize_portfolio(tickers, start_date, end_date, initial_balance):
 
55
  data = fetch_data(tickers, start_date, end_date)
56
  returns = data.pct_change().dropna()
57
 
 
58
  env = PortfolioEnv(returns, initial_balance=initial_balance)
59
  model = PPO("MlpPolicy", env, verbose=0)
60
  model.learn(total_timesteps=5000)
 
76
  plt.ylabel("Portfolio Value")
77
  plt.legend()
78
  plt.grid()
79
+
80
+ buffer = io.BytesIO()
81
+ plt.savefig(buffer, format="png")
82
+ buffer.seek(0)
83
  plt.close()
84
 
 
85
  weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
86
+ return weights, buffer
87
 
88
  def run_optimization(tickers, start_date, end_date, initial_balance):
89
  if not tickers or not start_date or not end_date:
90
+ return {"error": "Please fill all the fields."}, None
91
  try:
92
+ weights, chart_buffer = optimize_portfolio(tickers, start_date, end_date, initial_balance)
93
+ return weights, chart_buffer
94
  except Exception as e:
95
+ return {"error": str(e)}, None
96
 
97
  interface = gr.Interface(
98
  fn=run_optimization,
 
107
  gr.Image(label="Portfolio Value Chart"),
108
  ],
109
  title="AI-Powered Portfolio Optimization",
110
+ description="""Enter stock tickers (e.g., AAPL, MSFT, TSLA), a date range, and your initial investment amount.
111
+ The app fetches historical data, runs AI optimization, and returns the optimized portfolio weights
112
+ along with a performance chart."""
 
 
113
  )
114
 
115
  if __name__ == "__main__":