File size: 4,161 Bytes
e67043b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import requests
import json
from ..tool import Tool
import os

from steamship import Block, Steamship
import uuid
from enum import Enum
import re
from IPython import display
from IPython.display import Image


class ModelName(str, Enum):
    """Image-generation back ends this tool can target.

    Each member's value is the Steamship plugin handle used when the plugin
    is loaded.
    """

    DALL_E = "dall-e"
    STABLE_DIFFUSION = "stable-diffusion"


# Square "WIDTHxHEIGHT" size strings accepted by each model, keyed by ModelName.
SUPPORTED_IMAGE_SIZES = {
    ModelName.DALL_E: tuple(f"{side}x{side}" for side in (256, 512, 1024)),
    ModelName.STABLE_DIFFUSION: tuple(f"{side}x{side}" for side in (512, 768)),
}


def make_image_public(client: Steamship, block: Block) -> str:
    """Upload a block's raw bytes to workspace storage and return a read URL.

    A random UUID is used as the storage path in the PLUGIN_DATA bucket.
    A WRITE-signed URL is used to upload ``block.raw()``; a READ-signed URL
    for the same path is returned so the content can be fetched afterwards.

    Args:
        client: Steamship client whose workspace issues the signed URLs.
        block: Block whose raw content is uploaded.

    Returns:
        The READ-signed URL for the uploaded content.

    Raises:
        ValueError: if the steamship signed-URL helpers cannot be imported.
    """
    try:
        from steamship.data.workspace import SignedUrl
        from steamship.utils.signed_urls import upload_to_signed_url
    except ImportError as err:
        raise ValueError(
            "The make_image_public function requires the steamship"
            " package to be installed. Please install steamship"
            " with `pip install --upgrade steamship`"
        ) from err

    filepath = str(uuid.uuid4())
    # Fetch the workspace once and reuse it for both signed URLs.
    workspace = client.get_workspace()

    def _signed_url(operation) -> str:
        """Create a signed URL for ``filepath`` in the PLUGIN_DATA bucket."""
        return workspace.create_signed_url(
            SignedUrl.Request(
                bucket=SignedUrl.Bucket.PLUGIN_DATA,
                filepath=filepath,
                operation=operation,
            )
        ).signed_url

    write_url = _signed_url(SignedUrl.Operation.WRITE)
    read_url = _signed_url(SignedUrl.Operation.READ)
    upload_to_signed_url(write_url, block.raw())
    return read_url


def show_output(output):
    """Display an agent's multi-modal output in a notebook.

    The text is split around UUID-shaped tokens. Each UUID is treated as a
    Steamship block id: the block is fetched and its raw bytes are rendered
    with IPython's ``display``/``Image``. Every other non-empty segment is
    printed, followed by a blank line.

    Args:
        output: Text that may contain Steamship block UUIDs.
    """
    UUID_PATTERN = re.compile(
        r"([0-9A-Za-z]{8}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{4}-[0-9A-Za-z]{12})"
    )

    # The capturing group makes split() keep each UUID as its own segment.
    # Leading non-word characters (e.g. ". " left after a UUID) are stripped;
    # trailing characters are kept.
    segments = [re.sub(r"^\W+", "", part) for part in UUID_PATTERN.split(output)]

    client = None
    for segment in segments:
        if not segment:
            # Empty segments come from text edges, adjacent UUIDs, or text that
            # was only punctuation; printing them would emit stray blank lines.
            continue
        maybe_block_id = UUID_PATTERN.search(segment)
        if maybe_block_id:
            # The module-level ``display`` name is the IPython.display *module*
            # and is not callable, so import the function explicitly. It is
            # imported here so that text-only output does not require IPython.
            from IPython.display import Image as IPyImage, display as ipy_display

            if client is None:
                client = Steamship()  # created once, on the first image
            block = Block.get(client, _id=maybe_block_id.group())
            ipy_display(IPyImage(block.raw()))
        else:
            print(segment, end="\n\n")


def build_tool(config) -> Tool:
    """Build the "Image Generator" tool backed by a Steamship image plugin.

    Reads the Steamship API key from the ``STEAMSHIP_API_KEY`` environment
    variable and registers a ``/generate_image`` endpoint on the returned tool.

    Args:
        config: Accepted for interface compatibility; not read by this function.

    Returns:
        The configured ``Tool``.

    Raises:
        ValueError: if the configured size is not supported by the chosen model.
        RuntimeError: if ``STEAMSHIP_API_KEY`` is unset or empty.
    """
    tool = Tool(
        "Image Generator",
        "Tool that can generate image based on text description.",
        name_for_model="Image Generator",
        # Fragments end with ". " so the concatenated description does not
        # run sentences together.
        description_for_model=(
            "Useful for when you need to generate an image. "
            "Input: A detailed text-2-image prompt describing an image. "
            "Output: the UUID of a generated image"
        ),
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="[email protected]",
        legal_info_url="[email protected]",
    )

    # Generation settings; currently fixed here rather than read from ``config``.
    model_name: ModelName = ModelName.DALL_E
    size: str = "512x512"
    # When True, the endpoint returns a signed read URL instead of the block id.
    return_urls: bool = False

    supported_sizes = SUPPORTED_IMAGE_SIZES[model_name]
    if size not in supported_sizes:
        raise ValueError(
            f"Image size {size!r} is not supported by {model_name.value}; "
            f"expected one of {supported_sizes}."
        )

    steamship_api_key = os.environ.get("STEAMSHIP_API_KEY", "")
    if not steamship_api_key:
        raise RuntimeError(
            "STEAMSHIP_API_KEY is not provided. Please sign up for a free account at https://steamship.com/account/api, create a new API key, and add it to environment variables."
        )

    steamship = Steamship(api_key=steamship_api_key)

    @tool.get("/generate_image")
    def generate_image(query: str):
        """Generate one image for ``query``.

        Returns:
            The id of the first output block, or a signed read URL for it when
            ``return_urls`` is enabled.

        Raises:
            RuntimeError: if the plugin task produced no output blocks.
        """
        image_generator = steamship.use_plugin(
            plugin_handle=model_name.value, config={"n": 1, "size": size}
        )

        task = image_generator.generate(text=query, append_output_to_file=True)
        task.wait()
        blocks = task.output.blocks
        # Check for output before indexing, so an empty result raises the
        # intended RuntimeError rather than an IndexError.
        if not blocks:
            raise RuntimeError("Tool unable to generate image!")

        first_block = blocks[0]
        if return_urls:
            return make_image_public(steamship, first_block)
        return first_block.id

    return tool