Jae-Won Chung
New leaderboard prototype
b10121d
raw
history blame
1.53 kB
import json
import random
import base64
# Fixed RNG seed so that the same subset is drawn from the input on every run.
SEED = 68
# Number of (image, prompt) pairs to write; also embedded in the output file name.
NUM_SAMPLES = 500

# Placeholder that may appear once in the first human message; removed from the prompt.
_IMAGE_TAG = "<image>"


def _build_prompt(sample: dict) -> str:
    """Join the human turns of one sample into a single prompt string.

    Even-index messages of ``sample["conversations"]`` must come from the
    human. The first of them may contain at most one ``"<image>"`` placeholder,
    which is removed. The stripped messages are then joined with single spaces.

    Raises:
        AssertionError: If the sample violates the expected format. Raised
            explicitly (rather than via ``assert``) so the validation still
            runs under ``python -O``.
    """
    human_turns = []
    for turn in sample["conversations"][::2]:
        if turn["from"] != "human":
            raise AssertionError(sample)
        human_turns.append(turn["value"])
    if not human_turns:
        # Without this guard, an empty conversation would surface as an
        # opaque IndexError on the first-message access below.
        raise AssertionError(sample)

    # Drop only the first placeholder; any occurrence left after that is a
    # format violation.
    human_turns[0] = human_turns[0].replace(_IMAGE_TAG, "", 1)
    if any(_IMAGE_TAG in turn for turn in human_turns):
        raise AssertionError(sample)

    return " ".join(turn.strip() for turn in human_turns).strip()


def main() -> None:
    """Build a fixed-size, reproducible subset of image/prompt pairs.

    Reads a list of samples from ``full.json``, draws up to
    ``2 * NUM_SAMPLES`` candidates with a seeded RNG, and keeps the first
    ``NUM_SAMPLES`` whose image can be read from ``train2017/``. Each kept
    sample is stored as ``{"image": <base64 of the file>, "prompt": <text>}``
    and the list is written to ``llava_conversation_{NUM_SAMPLES}.json``.

    Raises:
        AssertionError: If a candidate with a readable image has a malformed
            conversation (see ``_build_prompt``).
        RuntimeError: If fewer than ``NUM_SAMPLES`` usable samples were found
            among the drawn candidates.
    """
    random.seed(SEED)

    with open("full.json", encoding="utf-8") as f:
        data = json.load(f)

    # Oversample by 2x so that entries with a missing image can be skipped.
    # Clamp to the input size so a small input does not make random.sample fail.
    candidates = random.sample(data, min(len(data), NUM_SAMPLES * 2))

    dataset = []
    for sample in candidates:
        if len(dataset) >= NUM_SAMPLES:
            break

        # 1. The image should exist; otherwise move on to the next candidate.
        image_name = sample.get("image")
        if image_name is None:
            continue
        try:
            with open("train2017/" + image_name, "rb") as image_file:
                image_bytes = image_file.read()
        except FileNotFoundError:
            continue

        # 2-4. Validate the conversation and build the prompt from human turns.
        dataset.append(
            dict(
                image=base64.b64encode(image_bytes).decode("utf-8"),
                prompt=_build_prompt(sample),
            ),
        )

    if len(dataset) < NUM_SAMPLES:
        raise RuntimeError(
            f"Only {len(dataset)} usable samples found among "
            f"{len(candidates)} candidates; need {NUM_SAMPLES}."
        )

    with open(f"llava_conversation_{NUM_SAMPLES}.json", "w", encoding="utf-8") as f:
        json.dump(dataset, f)


if __name__ == "__main__":
    main()