mm-arxiv / app.py
anas-awadalla's picture
demo
f947031
import base64
import glob
import io
import json
import os
import re

import streamlit as st
from PIL import Image
def display_image(image_bytes, width=500):
    """Decode a base64-encoded image and render it in the Streamlit app.

    Args:
        image_bytes: Base64-encoded image data (``str`` or ``bytes``), as read
            from the dataset JSON records.
        width: Display width in pixels forwarded to ``st.image``. Defaults to
            500, the value that was previously hard-coded.

    Raises:
        binascii.Error: if *image_bytes* is not valid base64.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image
            format Pillow can open.
    """
    image_data = base64.b64decode(image_bytes)
    image = Image.open(io.BytesIO(image_data))
    st.image(image, width=width)
def load_json(file_path):
    """Read and parse a JSON file.

    The file is decoded explicitly as UTF-8. JSON text is UTF-8, and relying
    on the platform's locale default encoding can mis-decode non-ASCII captions
    or text.

    Args:
        file_path: Path to the ``.json`` file.

    Returns:
        The parsed JSON value (for the dataset files, a dict record).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the contents are not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def display_image_caption_pairs(json_data):
    """Render one image-caption record: its image(s), then its caption.

    Args:
        json_data: Parsed record holding an ``'images_bytes'`` entry (either a
            single base64-encoded image or a list of them) and a ``'caption'``
            entry.

    Raises:
        KeyError: if either key is missing. Both keys are read before anything
            is drawn.
    """
    encoded = json_data['images_bytes']
    caption_text = json_data['caption']
    # Treat a lone image as a one-element list so both shapes share one loop.
    encoded_images = encoded if isinstance(encoded, list) else [encoded]
    for encoded_image in encoded_images:
        display_image(encoded_image)
    st.markdown(f"**Caption:** {caption_text}")
def display_interleaved_text_and_images(text, images_bytes):
    """Render *text* with images interleaved at its ``<img ...>`` placeholders.

    The text is split on every ``<img ...>`` tag. After the i-th text segment,
    the image(s) stored in ``images_bytes[i]`` are displayed.

    Args:
        text: Document text containing ``<img ...>`` placeholder tags.
        images_bytes: Sequence indexed in step with the text segments. Each
            entry is either a list of base64-encoded images or a single encoded
            image, matching the two shapes ``display_image_caption_pairs``
            accepts. ``None`` is treated as "no images".
    """
    if images_bytes is None:
        images_bytes = []
    segments = re.split(r'<img[^>]*>', text)
    # NOTE(review): entries beyond len(segments) are never shown. Confirm that
    # the dataset never has more image slots than text segments.
    for i, segment in enumerate(segments):
        # re.split yields empty strings around leading, trailing, or adjacent
        # tags. Skip those instead of emitting empty markdown elements.
        if segment.strip():
            st.markdown(segment)
        if i < len(images_bytes):
            slot = images_bytes[i]
            # Without this normalisation a bare encoded string would be
            # iterated character by character, and every single character
            # fails to base64-decode.
            for img_bytes in (slot if isinstance(slot, list) else [slot]):
                display_image(img_bytes)
def _dataset_files(mode):
    """Return the JSON files of the dataset selected by *mode*, sorted by path."""
    # Each dataset is a sub-directory of the directory the app is started from.
    subdir = 'image_text' if mode == "Image-Caption Pairs" else 'interleaved'
    # Escape the cwd so characters such as '[' in it are not read as glob
    # patterns. Sort because glob returns files in arbitrary, filesystem-dependent
    # order, and Previous/Next must walk the files in a stable order across reruns.
    pattern = os.path.join(glob.escape(os.getcwd()), subdir, '*.json')
    return sorted(glob.glob(pattern))


def _navigate(num_files):
    """Draw the Previous/Next buttons and return the clamped current file index.

    The index lives in ``st.session_state.file_index`` so it persists across
    Streamlit reruns. *num_files* must be at least 1.
    """
    if 'file_index' not in st.session_state:
        st.session_state.file_index = 0
    col1, col2 = st.columns(2)
    if col1.button("Previous"):
        st.session_state.file_index -= 1
    if col2.button("Next"):
        st.session_state.file_index += 1
    # Clamp on every run, not only after a click, so a stale index (e.g. after
    # files were removed between reruns) can never point past the list.
    st.session_state.file_index = min(max(0, st.session_state.file_index), num_files - 1)
    return st.session_state.file_index


def main():
    """Render the Multimodal Arxiv Viewer.

    The user picks a dataset, steps through its JSON files with Previous/Next,
    and the selected file is rendered either as image-caption pairs or as
    interleaved text and images.
    """
    st.title("🧪 Multimodal Arxiv Viewer")
    # Clearing the session state on a dataset switch drops file_index, so
    # navigation restarts at the first file of the newly selected dataset.
    mode = st.selectbox(
        "Select Dataset",
        ["Image-Caption Pairs", "Interleaved Sequences"],
        on_change=st.session_state.clear,
    )
    json_files = _dataset_files(mode)
    if not json_files:
        # Without this guard, indexing the empty list below raises IndexError.
        st.warning("No JSON files found for the selected dataset.")
        return

    file_index = _navigate(len(json_files))
    st.markdown(f"**File {file_index + 1} of {len(json_files)}**")
    st.text("")
    st.text("")

    json_data = load_json(json_files[file_index])
    if mode == "Image-Caption Pairs":
        display_image_caption_pairs(json_data)
    else:
        images_bytes = json_data['images']
        text = json_data['txt']
        display_interleaved_text_and_images(text, images_bytes)
# Start the viewer when this file runs as the main script
# (presumably launched via `streamlit run` — confirm deployment command).
if __name__ == "__main__":
    main()