# ReWOO-Demo / algos / PWS.py
# (repository page residue: billxbf's picture | init | 926675f | raw | history blame | 6.63 kB)
# main class chaining Planner, Worker and Solver.
import re
import time
from nodes.Planner import Planner
from nodes.Solver import Solver
from nodes.Worker import *
from utils.util import *
class PWS:
    """Chain a Planner, a set of Workers and a Solver to answer one question.

    The planner's output is scanned for ``Plan:`` lines and for evidence lines
    of the form ``#E<n> = Tool[input]``. Each evidence line whose tool is in
    ``self.workers`` is executed through ``WORKER_REGISTRY``; the plans and the
    collected evidence are then handed to the solver.
    """

    # Placeholder stored whenever an evidence slot cannot be filled.
    _NO_EVIDENCE = "No evidence found"
    # Reference to an earlier evidence slot inside a tool input, e.g. "#E2".
    _EVIDENCE_VAR = re.compile(r"#E\d+")

    def __init__(self, available_tools=["Google", "LLM"], fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner and solver and initialise the per-run state.

        Args:
            available_tools: Names of the tools the workers may execute.
            fewshot: Few-shot text forwarded to the planner.
            planner_model: Model name forwarded to the planner and used to
                look up the planner token price.
            solver_model: Model name forwarded to the solver and used to look
                up the solver token price.
        """
        # Copy so that neither the shared default list nor a caller-owned list
        # is aliased by this instance.
        self.workers = list(available_tools)
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        # Tool token usage is always priced with the text-davinci-003 rate.
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        self.google_unit_price = 0.01

    def run(self, input):
        """Plan, gather evidence and solve for a single question.

        Args:
            input: The question line, e.g.
                "Question: What is the capital of France?". (The name shadows
                the builtin but is kept for caller compatibility.)

        Returns:
            A dict with the keys ``wall_time``, ``input``, ``output``,
            ``planner_log``, ``worker_log``, ``solver_log``, ``tool_usage``,
            ``steps``, ``total_tokens``, ``token_cost``, ``tool_cost`` and
            ``total_cost``.
        """
        # Each run starts from a clean slate, so results never leak between calls.
        self._reinitialize()
        result = {}
        start = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        log_parts = []
        for index, plan_line in enumerate(self.plans, start=1):
            # The planner may emit more plans than evidence lines; fall back to
            # the placeholder instead of raising KeyError.
            evidence = self.worker_evidences.get(f"#E{index}", self._NO_EVIDENCE)
            log_parts.append(f"{plan_line}\nEvidence:\n{evidence}\n")
        worker_log = "".join(log_parts)

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]
        solver_log = solver_response["input"] + solver_response["output"]

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)

        result["wall_time"] = time.time() - start
        result["input"] = input
        result["output"] = output
        result["planner_log"] = planner_log
        result["worker_log"] = worker_log
        result["solver_log"] = solver_log
        result["tool_usage"] = self.tool_counter
        # One step per plan plus the final solve.
        result["steps"] = len(self.plans) + 1
        result["total_tokens"] = planner_tokens + solver_tokens + tool_tokens
        result["token_cost"] = self.planner_token_unit_price * planner_tokens \
                               + self.solver_token_unit_price * solver_tokens \
                               + self.tool_token_unit_price * tool_tokens
        result["tool_cost"] = self.tool_counter.get("Google", 0) * self.google_unit_price
        result["total_cost"] = result["token_cost"] + result["tool_cost"]
        return result

    def _parse_plans(self, response):
        """Return every line of ``response`` that starts with ``Plan:``."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Extract ``#E<n> = tool_call`` assignments from the planner output.

        Lines that do not start with ``#E<digit>`` or that contain no ``=`` are
        ignored. A label that is not exactly ``#E`` followed by digits is
        recorded with the "No evidence found" placeholder.

        Returns:
            A dict mapping each evidence label to its tool call text.
        """
        evidences = {}
        for line in response.splitlines():
            # Slicing instead of indexing keeps short lines such as "#" or
            # "#E" from raising IndexError.
            if not (line.startswith("#E") and line[2:3].isdigit()):
                continue
            label, separator, tool_call = line.partition("=")
            if not separator:
                # No assignment on this line, so there is nothing to record.
                continue
            label, tool_call = label.strip(), tool_call.strip()
            if label[2:].isdigit():
                # Accepts multi-digit labels such as "#E10" as well.
                evidences[label] = tool_call
            else:
                evidences[label] = self._NO_EVIDENCE
        return evidences

    @staticmethod
    def _split_tool_call(tool_call):
        """Split ``Tool[input]`` into ``(tool, input)``.

        Only a closing bracket that is actually the last character is removed,
        so an unterminated call does not lose a character of its input.
        """
        tool, _, tool_input = tool_call.partition("[")
        if tool_input.endswith("]"):
            tool_input = tool_input[:-1]
        return tool.strip(), tool_input

    def _substitute_evidences(self, tool_input):
        """Replace known ``#E<n>`` references with ``[evidence]``.

        The substitution is a single regex pass, so ``#E1`` never clobbers the
        prefix of ``#E10`` and evidence text that was just inserted is not
        scanned again. References without collected evidence are left as is.
        """
        def _replace(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        return self._EVIDENCE_VAR.sub(_replace, tool_input)

    def _get_worker_evidences(self):
        """Use the planner evidences to assign tasks to the respective workers.

        Fills ``self.worker_evidences`` and updates ``self.tool_counter``
        (query count for Google, estimated token counts for LLM/Calculator).
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                # Not a tool invocation: keep the text itself as the evidence.
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = self._split_tool_call(tool_call)
            tool_input = self._substitute_evidences(tool_input)
            if tool not in self.workers:
                self.worker_evidences[e] = self._NO_EVIDENCE
                continue
            evidence = WORKER_REGISTRY[tool].run(tool_input)
            self.worker_evidences[e] = evidence
            if tool == "Google":
                # Number of queries.
                self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1
            elif tool == "LLM":
                # Token usage is estimated as one token per four characters.
                self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) \
                                                 + len(tool_input + evidence) // 4
            elif tool == "Calculator":
                # NOTE(review): a chain is instantiated only to read its prompt
                # template length; LLMMathChain/OpenAI presumably come from the
                # wildcard imports -- confirm.
                template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
                self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                                                        + len(template + tool_input + evidence) // 4

    def _reinitialize(self):
        """Reset all per-run state to fresh, empty containers."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
class PWS_Base(PWS):
    """PWS preset whose defaults are the Wikipedia and LLM tools.

    The default few-shot text is ``fewshots.HOTPOTQA_PWS_BASE``
    (``fewshots`` is presumably provided by one of the wildcard imports --
    confirm).
    """

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=["Wikipedia", "LLM"]):
        """Forward every argument to ``PWS.__init__``.

        Args:
            fewshot: Few-shot text forwarded to the parent.
            planner_model: Planner model name forwarded to the parent.
            solver_model: Solver model name forwarded to the parent.
            available_tools: Names of the tools the workers may execute.
        """
        # Pass a private copy so the default list, which is created once at
        # definition time, is never shared between instances.
        super().__init__(available_tools=list(available_tools),
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
class PWS_Extra(PWS):
    """PWS preset whose defaults are the Google, Calculator and LLM tools.

    The default few-shot text is ``fewshots.HOTPOTQA_PWS_EXTRA``
    (``fewshots`` is presumably provided by one of the wildcard imports --
    confirm).
    """

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=["Google", "Calculator", "LLM"]):
        """Forward every argument to ``PWS.__init__``.

        Args:
            fewshot: Few-shot text forwarded to the parent.
            planner_model: Planner model name forwarded to the parent.
            solver_model: Solver model name forwarded to the parent.
            available_tools: Names of the tools the workers may execute.
        """
        # Pass a private copy so the default list, which is created once at
        # definition time, is never shared between instances.
        super().__init__(available_tools=list(available_tools),
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)