gaur3009 commited on
Commit
3519170
·
verified ·
1 Parent(s): df11b5e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -53,9 +53,11 @@ def fetch_data(tickers, start_date, end_date):
53
  raise ValueError(f"Failed to fetch data: {e}")
54
 
55
  def optimize_portfolio(tickers, start_date, end_date, initial_balance):
 
56
  data = fetch_data(tickers, start_date, end_date)
57
  returns = data.pct_change().dropna()
58
 
 
59
  env = PortfolioEnv(returns, initial_balance=initial_balance)
60
  model = PPO("MlpPolicy", env, verbose=0)
61
  model.learn(total_timesteps=5000)
@@ -70,6 +72,7 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
70
  portfolio_weights = action / (np.sum(action) + 1e-8)
71
  portfolio_values.append(reward)
72
 
 
73
  plt.figure(figsize=(10, 6))
74
  plt.plot(portfolio_values, label="Portfolio Value")
75
  plt.title("Portfolio Value Over Time")
@@ -78,13 +81,14 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
78
  plt.legend()
79
  plt.grid()
80
 
81
- buffer = io.BytesIO()
82
- plt.savefig(buffer, format="png")
83
- buffer.seek(0)
84
  plt.close()
85
 
 
86
  weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
87
- return weights, buffer
 
88
 
89
  def run_optimization(tickers, start_date, end_date, initial_balance):
90
  if not tickers or not start_date or not end_date:
 
53
  raise ValueError(f"Failed to fetch data: {e}")
54
 
55
  def optimize_portfolio(tickers, start_date, end_date, initial_balance):
56
+ # Fetch real-time data
57
  data = fetch_data(tickers, start_date, end_date)
58
  returns = data.pct_change().dropna()
59
 
60
+ # Define the environment
61
  env = PortfolioEnv(returns, initial_balance=initial_balance)
62
  model = PPO("MlpPolicy", env, verbose=0)
63
  model.learn(total_timesteps=5000)
 
72
  portfolio_weights = action / (np.sum(action) + 1e-8)
73
  portfolio_values.append(reward)
74
 
75
+ # Save chart
76
  plt.figure(figsize=(10, 6))
77
  plt.plot(portfolio_values, label="Portfolio Value")
78
  plt.title("Portfolio Value Over Time")
 
81
  plt.legend()
82
  plt.grid()
83
 
84
+ chart_path = "portfolio_chart.png"
85
+ plt.savefig(chart_path)
 
86
  plt.close()
87
 
88
+ # Prepare the output
89
  weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
90
+ return weights, chart_path
91
+
92
 
93
  def run_optimization(tickers, start_date, end_date, initial_balance):
94
  if not tickers or not start_date or not end_date: