Lagent / lagent /utils /util.py
yanyoyo
update
ec878fd
import asyncio
import importlib
import inspect
import logging
import os
import os.path as osp
import sys
import time
from functools import partial
from logging.handlers import RotatingFileHandler
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
def load_class_from_string(class_path: str, path=None):
    """Import and return the object named by a dotted path.

    Args:
        class_path (str): dotted path such as ``package.module.Name``. The
            part before the last dot is imported as a module and the part
            after it is looked up as an attribute of that module.
        path (str, optional): directory that is temporarily placed at the
            front of ``sys.path`` for the duration of the import. It is
            removed again afterwards, unless it was already listed in
            ``sys.path`` before the call (in which case nothing is touched).

    Returns:
        The attribute found on the imported module (typically a class).

    Raises:
        ValueError: if ``class_path`` contains no dot.
        ImportError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_name, dot, class_name = class_path.rpartition('.')
    if not dot:
        # Fail with an explicit message instead of an opaque unpacking error.
        raise ValueError(
            f'class_path must look like "module.ClassName", got {class_path!r}'
        )

    # Only undo what this call did: a path the caller already had in
    # ``sys.path`` must stay there.
    inserted = bool(path) and path not in sys.path
    if inserted:
        sys.path.insert(0, path)
    try:
        module = importlib.import_module(module_name)
        return getattr(module, class_name)
    finally:
        if inserted:
            sys.path.remove(path)
def create_object(config: Union[Dict, Any] = None):
    """Build an object from a configuration dictionary.

    The reserved key ``'type'`` names what to build: either a dotted import
    path (string), a class, or any other callable. All remaining keys are
    used as keyword arguments.

    * If ``'type'`` resolves to a class, it is instantiated with the
      remaining keys and the instance is returned.
    * Otherwise it must be callable, and a ``functools.partial`` binding the
      remaining keys is returned (the callable is not invoked here).

    Anything that is not a dictionary (including ``None``) is returned
    unchanged. The input dictionary itself is never modified.
    """
    # Identity for non-dict input; ``None`` is covered by this check too.
    if not isinstance(config, dict):
        return config
    assert 'type' in config

    # Work on a shallow copy so the caller's dictionary keeps its 'type' key.
    kwargs = dict(config)
    target = kwargs.pop('type')
    if isinstance(target, str):
        target = load_class_from_string(target)

    if inspect.isclass(target):
        return target(**kwargs)
    assert callable(target)
    return partial(target, **kwargs)
async def async_as_completed(futures: Iterable[asyncio.Future]):
    """An asynchronous wrapper for `asyncio.as_completed`.

    Iterate with ``async for``; each item yielded is one of the input
    futures (the future object itself, not its result), delivered as the
    futures finish.

    Every element of ``futures`` must be an ``asyncio.Future``.
    """
    event_loop = asyncio.get_event_loop()
    proxies = []
    for original in futures:
        assert isinstance(original, asyncio.Future)
        # The done-callback receives the finished future as its argument, so
        # the proxy's result becomes the original future itself.
        proxy = event_loop.create_future()
        original.add_done_callback(proxy.set_result)
        proxies.append(proxy)

    for pending in asyncio.as_completed(proxies):
        finished = await pending
        yield finished
def filter_suffix(
        response: Union[str, List[str]],
        suffixes: Optional[List[str]] = None) -> Union[str, List[str]]:
    """Filter response with suffixes.

    Each response is cut at the first occurrence of every suffix marker, in
    the order the markers are given; the marker and everything after it are
    dropped.

    Args:
        response (Union[str, List[str]]): generated responses by LLMs.
        suffixes (Optional[List[str]]): a list of suffixes to be deleted.
            ``None`` disables filtering. Empty strings are ignored.

    Return:
        Union[str, List[str]]: a clean response. A single string is returned
        for string input, a list of strings for batched input.
    """
    if suffixes is None:
        return response

    batched = not isinstance(response, str)
    responses = response if batched else [response]

    # An empty marker is "contained" in every string but cannot be used as a
    # separator, so it is skipped rather than allowed to raise ValueError.
    markers = [item for item in suffixes if item]

    processed = []
    for resp in responses:
        for item in markers:
            # ``partition`` returns the whole string when the marker is absent.
            resp = resp.partition(item)[0]
        processed.append(resp)

    return processed if batched else processed[0]
def get_logger(
    name: str = 'lagent',
    level: str = 'debug',
    fmt:
    str = '%(asctime)s %(levelname)8s %(filename)20s %(lineno)4s - %(message)s',
    add_file_handler: bool = False,
    log_dir: str = 'log',
    log_file: str = time.strftime('%Y-%m-%d.log', time.localtime()),
    max_bytes: int = 5 * 1024 * 1024,
    backup_count: int = 3,
):
    """Return a configured logger, creating its handlers on first use.

    The logger named ``name`` gets a console handler and, optionally, a
    rotating file handler. Calling this function repeatedly for the same
    name is safe: a handler is only attached when an equivalent one is not
    already present, so log records are not emitted multiple times.
    Handlers that are already attached are left as they are (including
    their formatter).

    Args:
        name (str): logger name passed to ``logging.getLogger``.
        level (str): level name, case-insensitive (e.g. ``'info'``). Unknown
            names fall back to ``DEBUG``. Applied on every call.
        fmt (str): format string for handlers created by this call.
        add_file_handler (bool): also log to ``log_dir/log_file``.
        log_dir (str): directory of the log file; created if missing.
        log_file (str): log file name. NOTE: the default is computed once,
            when this module is imported, not on each call.
        max_bytes (int): size at which the log file is rotated.
        backup_count (int): number of rotated files to keep.

    Returns:
        logging.Logger: the configured logger (propagation disabled).
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(getattr(logging, level.upper(), logging.DEBUG))
    formatter = logging.Formatter(fmt)

    # FileHandler derives from StreamHandler, so exclude it explicitly when
    # looking for an existing console handler.
    has_console = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers)
    if not has_console:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    if add_file_handler:
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(log_dir, exist_ok=True)
        log_file_path = osp.join(log_dir, log_file)
        # File handlers store an absolute path in ``baseFilename``.
        target_path = osp.abspath(log_file_path)
        has_file = any(
            isinstance(handler, RotatingFileHandler)
            and handler.baseFilename == target_path
            for handler in logger.handlers)
        if not has_file:
            file_handler = RotatingFileHandler(
                log_file_path,
                maxBytes=max_bytes,
                backupCount=backup_count,
                encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
    return logger
class GeneratorWithReturn:
    """Iterable wrapper that remembers a generator's return value.

    Iterating over the wrapper yields the items of the wrapped generator.
    Once the generator finishes, the value it returned is stored in ``ret``
    (which is ``None`` until then).
    """

    def __init__(self, generator: Generator):
        self.ret = None
        self.generator = generator

    def __iter__(self):
        # ``yield from`` evaluates to the value the inner generator returned.
        returned = yield from self.generator
        self.ret = returned
        return returned