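# NOTE: the lines below are page residue captured along with this source file;
# they are commented out so the module parses as Python.
# Stock-opt / app.py
# gaur3009's picture
# Update app.py
# 3519170 verified
# raw
# history blame
# 4.32 kB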
import gradio as gr
import pandas as pd
import numpy as np
import yfinance as yf
from stable_baselines3 import PPO
from gym import Env
from gym.spaces import Box
import matplotlib.pyplot as plt
import io
class PortfolioEnv(Env):
    """Trading environment that walks once through a table of per-period asset returns.

    On each step the agent submits one score per asset. The scores are clipped to
    [0, 1] and normalized into portfolio weights. The weights are applied to that
    period's returns, and the running balance is compounded. One episode applies
    every row of ``returns`` exactly once.

    NOTE: the reward is the running *balance* itself, not the per-step return.
    ``optimize_portfolio`` relies on this to chart portfolio value, so any change
    to the reward must be made together with that caller.

    Args:
        returns: DataFrame of per-period returns, one row per period and one column
            per asset (presumably simple returns from ``pct_change()``; confirm at callers).
        initial_balance: starting portfolio value.
    """

    def __init__(self, returns, initial_balance=10000):
        super().__init__()
        self.returns = returns
        self.n_assets = returns.shape[1]
        self.initial_balance = initial_balance
        self.current_balance = initial_balance
        self.action_space = Box(low=0, high=1, shape=(self.n_assets,), dtype=np.float32)
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(self.n_assets,), dtype=np.float32)
        # Most recently applied weights; starts as an equal-weight portfolio.
        self.weights = np.ones(self.n_assets) / self.n_assets
        self.current_step = 0
        self.state = self._observation(0)

    def _observation(self, step):
        """Return row ``step`` of the returns table as a float32 vector matching ``observation_space``."""
        return np.asarray(self.returns.iloc[step].values, dtype=np.float32)

    def step(self, action):
        """Apply one period of returns using ``action`` as unnormalized weights.

        Returns the old-style 4-tuple ``(observation, reward, done, info)``.
        ``reward`` is the balance after this step, and ``info`` is always empty.
        """
        action = np.clip(action, 0, 1)
        weights = action / (np.sum(action) + 1e-8)  # Prevent division by zero
        self.weights = weights
        portfolio_return = np.dot(weights, self.returns.iloc[self.current_step].values)
        self.current_balance *= (1 + portfolio_return)
        self.current_step += 1
        # Finish only after the last row has been applied. The former check
        # (``>= len - 1``) stopped one row early and never used the final period.
        done = self.current_step >= len(self.returns)
        if not done:
            self.state = self._observation(self.current_step)
        reward = self.current_balance
        return self.state, reward, done, {}

    def reset(self):
        """Restore the starting balance, weights and first observation; return that observation."""
        self.current_balance = self.initial_balance
        self.current_step = 0
        self.weights = np.ones(self.n_assets) / self.n_assets
        self.state = self._observation(0)
        return self.state
def fetch_data(tickers, start_date, end_date):
    """Download adjusted closing prices for ``tickers`` between the given dates.

    Args:
        tickers: ticker symbols accepted by ``yf.download`` (e.g. "AAPL, MSFT").
        start_date: start date forwarded to ``yf.download``.
        end_date: end date forwarded to ``yf.download``.

    Returns:
        A DataFrame of adjusted close prices with one column per ticker.

    Raises:
        ValueError: if the download or column lookup fails, or no prices are returned.
    """
    try:
        # auto_adjust=False keeps the 'Adj Close' column. Newer yfinance releases
        # default to auto_adjust=True, which omits it and makes this lookup fail.
        data = yf.download(tickers, start=start_date, end=end_date, auto_adjust=False)['Adj Close']
    except Exception as e:
        raise ValueError(f"Failed to fetch data: {e}") from e
    if isinstance(data, pd.Series):
        # Some yfinance versions return a Series for a single ticker. Downstream
        # code needs a 2-D table, so convert it and name the column after the ticker.
        symbols = tickers.replace(",", " ").split() if isinstance(tickers, str) else list(tickers)
        data = data.to_frame(name=symbols[0].upper() if symbols else data.name)
    if data.empty:
        raise ValueError("Failed to fetch data: no prices returned for the given tickers and dates.")
    return data
def _simulate_policy(model, env, initial_balance):
    """Run one deterministic episode; return the final weights and the value after each step."""
    state = env.reset()
    done = False
    portfolio_weights = []
    portfolio_values = [initial_balance]
    while not done:
        # deterministic=True reports the policy's preferred action instead of a random sample,
        # so the reported allocation is reproducible for a trained model.
        action, _ = model.predict(state, deterministic=True)
        state, reward, done, _ = env.step(action)
        # Mirror the env's normalization so the reported weights are the ones applied.
        clipped = np.clip(action, 0, 1)
        portfolio_weights = clipped / (np.sum(clipped) + 1e-8)
        # PortfolioEnv's reward is its running balance, so it doubles as the portfolio value.
        portfolio_values.append(reward)
    return portfolio_weights, portfolio_values


def _save_value_chart(portfolio_values):
    """Plot the portfolio value series to a new, uniquely named PNG and return its path."""
    import tempfile

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(portfolio_values, label="Portfolio Value")
        ax.set_title("Portfolio Value Over Time")
        ax.set_xlabel("Time Steps")
        ax.set_ylabel("Portfolio Value")
        ax.legend()
        ax.grid()
        # One file per call. A fixed "portfolio_chart.png" would be overwritten
        # by concurrent requests to the web app.
        with tempfile.NamedTemporaryFile(prefix="portfolio_chart_", suffix=".png", delete=False) as handle:
            chart_path = handle.name
        fig.savefig(chart_path)
    finally:
        plt.close(fig)
    return chart_path


def optimize_portfolio(tickers, start_date, end_date, initial_balance):
    """Train a PPO agent on historical returns and report its final allocation.

    Args:
        tickers: comma- or space-separated ticker symbols, passed to ``fetch_data``.
        start_date: start date forwarded to ``fetch_data``.
        end_date: end date forwarded to ``fetch_data``.
        initial_balance: starting portfolio value for the simulation.

    Returns:
        ``(weights, chart_path)``. ``weights`` maps ``"Asset_<i> (<ticker>)"`` to the
        agent's final normalized weight. ``chart_path`` is the path of a PNG chart of
        the simulated portfolio value.

    Raises:
        ValueError: if the price data yields no usable returns (or ``fetch_data`` fails).
    """
    data = fetch_data(tickers, start_date, end_date)
    if isinstance(data, pd.Series):
        # A single ticker can arrive as a Series; the environment needs 2-D data.
        symbols = tickers.replace(",", " ").split()
        data = data.to_frame(name=symbols[0].upper() if symbols else data.name)
    returns = data.pct_change().dropna()
    if returns.empty:
        raise ValueError("Not enough price data to compute returns for the given tickers and dates.")

    env = PortfolioEnv(returns, initial_balance=initial_balance)
    model = PPO("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=5000)

    portfolio_weights, portfolio_values = _simulate_policy(model, env, initial_balance)
    chart_path = _save_value_chart(portfolio_values)

    # Label weights by the DataFrame's own columns. yfinance orders them itself, not
    # in the order the user typed, so indexing the raw input string mislabels assets
    # (and fails outright for space-separated input).
    weights = {
        f"Asset_{i + 1} ({column})": float(weight)
        for i, (column, weight) in enumerate(zip(returns.columns, portfolio_weights))
    }
    return weights, chart_path
def run_optimization(tickers, start_date, end_date, initial_balance):
    """Gradio callback: validate the form inputs, then run the optimization.

    Returns:
        ``(result, chart_path)``. On success these are the weights dict and chart
        path from ``optimize_portfolio``. On failure they are ``{"error": message}``
        and ``None``, so the UI shows the message instead of crashing.
    """
    # Whitespace-only text boxes and a cleared number box (None) count as missing.
    text_fields = (tickers, start_date, end_date)
    if initial_balance is None or any(not value or not str(value).strip() for value in text_fields):
        return {"error": "Please fill all the fields."}, None
    if initial_balance <= 0:
        return {"error": "Initial investment amount must be positive."}, None
    try:
        weights, chart_path = optimize_portfolio(tickers, start_date, end_date, initial_balance)
    except Exception as e:
        # UI boundary: report any failure (bad ticker, network, training) as a message.
        return {"error": str(e)}, None
    return weights, chart_path
# Gradio UI: three text boxes plus a number box feed run_optimization. Its
# (weights dict, chart path) result is rendered as JSON and an image.
interface = gr.Interface(
    title="AI-Powered Portfolio Optimization",
    description=(
        "Enter stock tickers (e.g., AAPL, MSFT, TSLA), a date range, and your initial investment amount.\n"
        "The app fetches historical data, runs AI optimization, and returns the optimized portfolio weights\n"
        "along with a performance chart."
    ),
    fn=run_optimization,
    inputs=[
        gr.Textbox(label="Enter Stock Tickers (comma-separated)", placeholder="AAPL, MSFT, TSLA"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="2023-01-01"),
        gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="2023-12-31"),
        gr.Number(label="Initial Investment Amount", value=10000),
    ],
    outputs=[
        gr.JSON(label="Optimized Portfolio Weights"),
        gr.Image(label="Portfolio Value Chart"),
    ],
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()