from __future__ import annotations

from typing import List, Union, Optional, Any, TYPE_CHECKING
from collections import defaultdict

from pydantic import Field
import numpy as np
from datetime import datetime as dt
import re

from agentverse.llms.openai import get_embedding
from sklearn.metrics.pairwise import cosine_similarity

from agentverse.message import Message
from agentverse.memory import BaseMemory
from agentverse.logging import logger

from . import memory_manipulator_registry
from .base import BaseMemoryManipulator

if TYPE_CHECKING:
    from agentverse.memory import VectorStoreMemory
    from agentverse.agents.base import BaseAgent

IMPORTANCE_PROMPT = """On the scale of 1 to 10, where 1 is purely mundane \
(e.g., brushing teeth, making bed) and 10 is \
extremely poignant (e.g., a break up, college \
acceptance), rate the likely poignancy of the \
following piece of memory. \
If it is too hard to rate, you may give a rough assessment. \
The content or people mentioned are not real; you may assume any reasonable context. \
Please strictly output only one number. \
Memory: {} \
Rating: """

IMMEDIACY_PROMPT = """On the scale of 1 to 10, where 1 requires no short-term attention \
(e.g., a bed is in the room) and 10 needs quick attention or an immediate response \
(e.g., being asked for a reply by others), rate the likely immediacy of the \
following statement. \
If it is too hard to rate, you may give a rough assessment. \
The content or people mentioned are not real; you may assume any reasonable context. \
Please strictly output only one number. \
Memory: {} \
Rating: """

QUESTION_PROMPT = """Given only the information above, what are the 3 most salient \
high-level questions we can answer about the subjects in the statements?"""

INSIGHT_PROMPT = """What are at most 5 high-level insights you can infer from \
the above statements? Only output insights with high confidence.
example format: insight (because of 1, 5, 3)"""

class Reflection(BaseMemoryManipulator):
    memory: VectorStoreMemory = None
    agent: BaseAgent = None

    reflection: str = ""

    importance_threshold: int = 10
    accumulated_importance: int = 0

    memory2importance: dict = {}
    memory2immediacy: dict = {}
    memory2time: defaultdict = Field(default=defaultdict(dict))

    # TODO newly added func from generative agents
    def manipulate_memory(self) -> str:
        # reflect only once enough importance has accumulated
        if self.should_reflect():
            logger.debug(
                f"Agent {self.agent.name} is now doing reflection since accumulated_importance={self.accumulated_importance} has reached the reflection threshold (threshold now doubled to {self.importance_threshold})"
            )
            self.reflection = self.reflect()
            return self.reflection
        else:
            logger.debug(
                f"Agent {self.agent.name} doesn't reflect since accumulated_importance={self.accumulated_importance} < reflection_threshold={self.importance_threshold}"
            )
            return ""

    def get_accumulated_importance(self):
        accumulated_importance = 0

        for memory in self.memory.messages:
            if (
                memory.content not in self.memory2importance
                or memory.content not in self.memory2immediacy
            ):
                self.memory2importance[memory.content] = self.get_importance(
                    memory.content
                )
                self.memory2immediacy[memory.content] = self.get_immediacy(
                    memory.content
                )

        for score in self.memory2importance.values():
            accumulated_importance += score
        self.accumulated_importance = accumulated_importance
        return accumulated_importance

    def should_reflect(self):
        if self.get_accumulated_importance() >= self.importance_threshold:
            # double the importance_threshold
            self.importance_threshold *= 2
            return True
        else:
            return False

    def get_questions(self, texts):
        prompt = "\n".join(texts) + "\n" + QUESTION_PROMPT
        result = self.agent.llm.generate_response(prompt)
        result = result.content
        questions = [q for q in result.split("\n") if len(q.strip()) > 0]
        questions = questions[:3]
        return questions

    def get_insights(self, statements):
        prompt = ""
        for i, st in enumerate(statements):
            prompt += str(i + 1) + ". " + st + "\n"
        prompt += INSIGHT_PROMPT
        result = self.agent.llm.generate_response(prompt)
        result = result.content
        insights = [isg for isg in result.split("\n") if len(isg.strip()) > 0][:5]
        insights = [".".join(i.split(".")[1:]) for i in insights]
        # remove insight pointers for now
        insights = [i.split("(")[0].strip() for i in insights]
        return insights

    def get_importance(self, content: str):
        """
        Use the LLM to rate the importance (poignancy) of this memory.
        """
        prompt = IMPORTANCE_PROMPT.format(content)
        result = self.memory.llm.generate_response(prompt)

        try:
            score = int(re.findall(r"\s*(\d+)\s*", result.content)[0])
        except Exception as e:
            logger.warn(
                f"Found error {e}. Abnormal importance rating result '{result}'. Falling back to the default score 0."
            )
            score = 0
        return score

    def get_immediacy(self, content: str):
        """
        Use the LLM to rate the immediacy of this memory.
        """
        prompt = IMMEDIACY_PROMPT.format(content)
        result = self.memory.llm.generate_response(prompt)

        try:
            score = int(re.findall(r"\s*(\d+)\s*", result.content)[0])
        except Exception as e:
            logger.warn(
                f"Found error {e}. Abnormal immediacy rating result '{result}'. Falling back to the default score 0."
            )
            score = 0
        return score

    def query_similarity(
        self,
        text: Union[str, List[str]],
        k: int,
        memory_bank: List,
        current_time: Optional[dt] = None,
        nms_threshold: float = 0.99,
    ) -> List[str]:
        """
        Get the top-k entries based on recency, relevance, importance, and immediacy.

        Each entry is scored through either its short-term (STM) or long-term
        (LTM) pathway:
            score     = sim(q, v) * max(LTM_score, STM_score)
            STM_score = time_score(create_time) * immediacy
            LTM_score = time_score(last_access_time) * importance
        where time_score is an exponential decay weight and the STM term decays faster.

        Querying with multiple texts is supported; the per-text scores are merged
        so the final results do not overlap. If nms_threshold is not 1, a soft-NMS
        mechanism is activated with a modified IoU base: a score only starts to
        decay once the cosine similarity to an already-selected entry exceeds
        nms_threshold, and the decay weight falls linearly from 1 at the threshold
        to 0 at a similarity of 1 (rather than dropping to 1 - similarity as in
        standard soft-NMS).

        Args:
            text: str or List[str]
            k: int
            memory_bank: List
            current_time: datetime; defaults to the time of the call
            nms_threshold: float = 0.99

        Returns:
            List[str]
        """
        if current_time is None:
            current_time = dt.now()
        assert len(text) > 0
        texts = [text] if isinstance(text, str) else text
        maximum_score = None
        for text in texts:
            embedding = get_embedding(text)
            score = []
            for memory in memory_bank:
                if memory.content not in self.memory2time:
                    self.memory2time[memory.content]["last_access_time"] = dt.now()
                    self.memory2time[memory.content]["create_time"] = dt.now()

                last_access_time_diff = (
                    current_time - self.memory2time[memory.content]["last_access_time"]
                ).total_seconds() // 3600
                recency = np.power(
                    0.99, last_access_time_diff
                )  # TODO: review the metaparameter 0.99

                create_time_diff = (
                    current_time - self.memory2time[memory.content]["create_time"]
                ).total_seconds() // 60
                instancy = np.power(
                    0.90, create_time_diff
                )  # TODO: review the metaparameter 0.90

                relevance = cosine_similarity(
                    np.array(embedding).reshape(1, -1),
                    np.array(self.memory.memory2embedding[memory.content]).reshape(
                        1, -1
                    ),
                )[0][0]

                if (
                    memory.content not in self.memory2importance
                    or memory.content not in self.memory2immediacy
                ):
                    self.memory2importance[memory.content] = self.get_importance(
                        memory.content
                    )
                    self.memory2immediacy[memory.content] = self.get_immediacy(
                        memory.content
                    )

                importance = self.memory2importance[memory.content] / 10
                immediacy = self.memory2immediacy[memory.content] / 10

                ltm_w = recency * importance
                stm_w = instancy * immediacy
                score.append(relevance * np.maximum(ltm_w, stm_w))

            score = np.array(score)
            if maximum_score is not None:
                maximum_score = np.maximum(score, maximum_score)
            else:
                maximum_score = score

        if nms_threshold == 1.0:
            # no nms is triggered
            top_k_indices = np.argsort(maximum_score)[-k:][::-1]
        else:
            # TODO: soft-nms
            assert 0 <= nms_threshold < 1
            top_k_indices = []
            while len(top_k_indices) < min(k, len(memory_bank)):
                top_index = np.argmax(maximum_score)
                top_k_indices.append(top_index)
                maximum_score[top_index] = -1  # anything to prevent being chosen again
                top_embedding = self.memory.memory2embedding[
                    memory_bank[top_index].content
                ]
                cos_sim = cosine_similarity(
                    np.array(top_embedding).reshape(1, -1),
                    np.array(
                        [
                            self.memory.memory2embedding[memory.content]
                            for memory in memory_bank
                        ]
                    ),
                )[0]
                score_weight = np.ones_like(maximum_score)
                score_weight[cos_sim >= nms_threshold] -= (
                    cos_sim[cos_sim >= nms_threshold] - nms_threshold
                ) / (1 - nms_threshold)
                maximum_score = maximum_score * score_weight
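
        # Worked example for the soft-NMS weighting above (assumed numbers, not
        # from the source): with nms_threshold=0.99, an entry whose cosine
        # similarity to an already-selected entry is 0.995 keeps weight
        # 1 - (0.995 - 0.99) / (1 - 0.99) = 0.5, while entries below the
        # threshold keep their full score.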
        # refresh the access time of the selected memories
        for i in top_k_indices:
            self.memory2time[memory_bank[i].content]["last_access_time"] = current_time

        # sort the selected memories chronologically by creation time
        top_k_indices = sorted(
            top_k_indices,
            key=lambda x: self.memory2time[memory_bank[x].content]["create_time"],
        )

        query_results = []
        for i in top_k_indices:
            query_result = memory_bank[i].content
            query_results.append(query_result)

        return query_results

    def get_memories_of_interest_oneself(self):
        memories_of_interest = []
        for memory in self.memory.messages[-100:]:
            if memory.sender == self.agent.name:
                memories_of_interest.append(memory)
        return memories_of_interest

    def reflect(self):
        """
        Initiate a reflection that inserts high-level knowledge into memory.
        """
        memories_of_interest = self.get_memories_of_interest_oneself()
        questions = self.get_questions([m.content for m in memories_of_interest])
        statements = self.query_similarity(
            questions, len(questions) * 10, memories_of_interest
        )
        insights = self.get_insights(statements)
        logger.info(self.agent.name + f" Insights: {insights}")
        for insight in insights:
            # convert each insight into a message
            # TODO currently only oneself can see its own reflection
            insight_message = Message(
                content=insight, sender=self.agent.name, receiver={self.agent.name}
            )
            self.memory.add_message([insight_message])

        reflection = "\n".join(insights)
        return reflection

    def reset(self) -> None:
        self.reflection = ""
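
# Minimal usage sketch (an assumption for illustration, not part of the original
# module): `vector_memory` and `my_agent` stand for hypothetical, already-configured
# VectorStoreMemory and BaseAgent instances.
#
#   reflection = Reflection(memory=vector_memory, agent=my_agent)
#   summary = reflection.manipulate_memory()  # reflects once enough importance accrues
#   if summary:
#       print(summary)  # newline-joined high-level insights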