Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from typing import Any, List, Optional | |
from langchain.schema import AIMessage, HumanMessage | |
from langchain_core.pydantic_v1 import SecretStr | |
from src.models.generator import ImgGenerator, QuestionGenerator | |
from src.models.lc_base_model import ChainGenerator, ContentGenerator, EvaluationChatModel | |
from src.models.lc_img_desc_model import EvaluationChatModelImg | |
from src.models.lc_qa_model import EvaluationChatModelQA | |
class ModelFactory(ABC):
    """Abstract base class for the model factories defined in this module."""

    @abstractmethod
    def create_model(self, *args: Any, **kwargs: Any) -> ChainGenerator:
        """
        Create and return a model instance.

        Subclasses must override this method; because it is declared abstract,
        ``ModelFactory`` itself (and any subclass that does not override it)
        cannot be instantiated.

        The concrete factories alongside this class return
        ``EvaluationChatModel`` or ``ContentGenerator`` instances (presumably
        subtypes of ``ChainGenerator`` -- verify in ``lc_base_model``).
        """
class EvaluationChatModelFactory(ModelFactory):
    def create_model(self, model_class: str, openai_api_key: SecretStr, **kwargs: Any) -> EvaluationChatModel:
        """
        Build the evaluation chat model selected by ``model_class``.

        Args:
            model_class (str): Key of the evaluator to build: ``"qa"`` or ``"img_desc"``.
            openai_api_key (SecretStr): OpenAI API key passed to the model constructor.
            **kwargs (Any): Extra keyword arguments passed unchanged to the model constructor.

        Returns:
            EvaluationChatModel: An ``EvaluationChatModelQA`` for ``"qa"``, or an
            ``EvaluationChatModelImg`` for ``"img_desc"``.

        Raises:
            ValueError: If ``model_class`` is neither ``"qa"`` nor ``"img_desc"``.
        """
        # Resolve the concrete class first, then construct it in one place.
        if model_class == "qa":
            evaluator_cls = EvaluationChatModelQA
        elif model_class == "img_desc":
            evaluator_cls = EvaluationChatModelImg
        else:
            raise ValueError("Invalid model class provided")
        return evaluator_cls(openai_api_key=openai_api_key, **kwargs)
class GeneratorModelFactory(ModelFactory):
    def create_model(
        self,
        model_class: str,
        openai_api_key: SecretStr,
        history_chat: Optional[List[HumanMessage | AIMessage]] = None,
        img_size: str = "256x256",
        **kwargs: Any,
    ) -> ContentGenerator:
        """
        Build the content generator selected by ``model_class``.

        Parameters:
            model_class (str): Key of the generator to build: ``"qa"`` or ``"img_desc"``.
            openai_api_key (SecretStr): OpenAI API key passed to the generator constructor.
            history_chat (Optional[list], optional): Prior chat messages; only used for
                ``"qa"``. ``None`` or an empty value becomes a fresh empty list.
                Defaults to None.
            img_size (str, optional): Image size; only used for ``"img_desc"``.
                Defaults to "256x256".
            **kwargs (Any): Extra keyword arguments passed unchanged to the generator constructor.

        Returns:
            ContentGenerator: A ``QuestionGenerator`` for ``"qa"``, or an
            ``ImgGenerator`` for ``"img_desc"``.

        Raises:
            ValueError: If ``model_class`` is neither ``"qa"`` nor ``"img_desc"``.
        """
        if model_class == "qa":
            # A new list per call, so no history state is shared between generators.
            chat_history = history_chat or []
            return QuestionGenerator(openai_api_key=openai_api_key, history_chat=chat_history, **kwargs)
        if model_class == "img_desc":
            return ImgGenerator(openai_api_key=openai_api_key, img_size=img_size, **kwargs)
        raise ValueError("Invalid model class provided")