Update app.py
Browse files
app.py
CHANGED
@@ -53,9 +53,11 @@ def fetch_data(tickers, start_date, end_date):
|
|
53 |
raise ValueError(f"Failed to fetch data: {e}")
|
54 |
|
55 |
def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
|
|
56 |
data = fetch_data(tickers, start_date, end_date)
|
57 |
returns = data.pct_change().dropna()
|
58 |
|
|
|
59 |
env = PortfolioEnv(returns, initial_balance=initial_balance)
|
60 |
model = PPO("MlpPolicy", env, verbose=0)
|
61 |
model.learn(total_timesteps=5000)
|
@@ -70,6 +72,7 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
|
70 |
portfolio_weights = action / (np.sum(action) + 1e-8)
|
71 |
portfolio_values.append(reward)
|
72 |
|
|
|
73 |
plt.figure(figsize=(10, 6))
|
74 |
plt.plot(portfolio_values, label="Portfolio Value")
|
75 |
plt.title("Portfolio Value Over Time")
|
@@ -78,13 +81,14 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
|
78 |
plt.legend()
|
79 |
plt.grid()
|
80 |
|
81 |
-
|
82 |
-
plt.savefig(
|
83 |
-
buffer.seek(0)
|
84 |
plt.close()
|
85 |
|
|
|
86 |
weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
|
87 |
-
return weights,
|
|
|
88 |
|
89 |
def run_optimization(tickers, start_date, end_date, initial_balance):
|
90 |
if not tickers or not start_date or not end_date:
|
|
|
53 |
raise ValueError(f"Failed to fetch data: {e}")
|
54 |
|
55 |
def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
56 |
+
# Fetch real-time data
|
57 |
data = fetch_data(tickers, start_date, end_date)
|
58 |
returns = data.pct_change().dropna()
|
59 |
|
60 |
+
# Define the environment
|
61 |
env = PortfolioEnv(returns, initial_balance=initial_balance)
|
62 |
model = PPO("MlpPolicy", env, verbose=0)
|
63 |
model.learn(total_timesteps=5000)
|
|
|
72 |
portfolio_weights = action / (np.sum(action) + 1e-8)
|
73 |
portfolio_values.append(reward)
|
74 |
|
75 |
+
# Save chart
|
76 |
plt.figure(figsize=(10, 6))
|
77 |
plt.plot(portfolio_values, label="Portfolio Value")
|
78 |
plt.title("Portfolio Value Over Time")
|
|
|
81 |
plt.legend()
|
82 |
plt.grid()
|
83 |
|
84 |
+
chart_path = "portfolio_chart.png"
|
85 |
+
plt.savefig(chart_path)
|
|
|
86 |
plt.close()
|
87 |
|
88 |
+
# Prepare the output
|
89 |
weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
|
90 |
+
return weights, chart_path
|
91 |
+
|
92 |
|
93 |
def run_optimization(tickers, start_date, end_date, initial_balance):
|
94 |
if not tickers or not start_date or not end_date:
|