Spaces:
Running
Running
File size: 4,401 Bytes
c2ba4d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# query를 자동으로 읽고 쓰는 container를 정의
from __future__ import annotations
import re
from typing import Callable, TypeVar
import streamlit as st
__all__ = ["QueryWrapper", "get_base_url"]
T = TypeVar("T")
import hashlib
import urllib.parse
def SHA1(msg: str) -> str:
    """Return the first eight hex characters of the SHA-1 digest of *msg*.

    The text is encoded with ``str.encode()``'s default encoding before it
    is hashed.
    """
    digest = hashlib.sha1(msg.encode())
    hex_digest = digest.hexdigest()
    return hex_digest[:8]
def get_base_url():
    """Return ``protocol://host`` built from the first active session's request.

    Only the scheme and network location are filled in; path, params, query
    and fragment are left empty.
    """
    # NOTE(review): ``_session_mgr`` is a private attribute of the runtime
    # object, and index 0 is simply the first session it lists.
    runtime = st.runtime.get_instance()
    first_session = runtime._session_mgr.list_active_sessions()[0]
    request = first_session.client.request
    components = (request.protocol, request.host, "", "", "", "")
    return urllib.parse.urlunparse(components)
class QueryWrapper:
    """Public handle that creates a ``_QueryWrapper`` and records it by name.

    Every instance registers its underlying ``_QueryWrapper`` in the
    class-level ``queries`` dict so that ``get_sharable_link`` can later
    combine the query-string fragment of each registered wrapper.
    """

    # Registry of created wrappers, keyed by the ``query`` name passed to
    # ``__init__``; a later instance with the same name replaces the earlier one.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Create the underlying wrapper and register it under *query*.

        Args:
            query: Name of the wrapper; also used as the registry key.
            label: Forwarded unchanged to ``_QueryWrapper``.
            use_hash: Forwarded unchanged to ``_QueryWrapper``.
        """
        wrapper = _QueryWrapper(query, label, use_hash)
        QueryWrapper.queries[query] = wrapper
        self.__wrapper = wrapper

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the underlying ``_QueryWrapper``."""
        return self.__wrapper(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls):
        """Return the ``&``-joined query-string fragments of all wrappers.

        Each registered wrapper contributes ``str(wrapper)``. Wrappers whose
        fragment is empty are skipped before joining, so they add no empty
        segment to the result. Only the query-string portion is produced;
        no base URL is prepended.

        Returns:
            The combined fragment string, or ``""`` when no wrapper
            contributes anything.
        """
        fragments = (str(wrapper) for wrapper in cls.queries.values())
        return "&".join(fragment for fragment in fragments if fragment)
class _QueryWrapper:
    """Bind one URL query parameter to one Streamlit widget.

    Calling the instance renders a widget whose initial selection is taken,
    in order of precedence, from the URL query parameter, from the
    ``default`` argument, or from the value already stored in
    ``st.session_state``. ``str(instance)`` yields the query-string fragment
    for the widget's current value.
    """

    # Not referenced anywhere inside this class.
    ILLEGAL_CHARS = "&/=?"

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Store the configuration; nothing is rendered until the call.

        Args:
            query: Query-parameter name (lower-cased whenever it is used).
            label: Widget label; falls back to *query* when empty or ``None``.
            use_hash: Stored on the instance. This class does not consult
                it; values are always written and read as hashes.
        """
        self.query = query
        self.label = label or query
        self.use_hash = use_hash
        # Maps SHA1(str(option)) -> option; rebuilt on every call.
        self.hash_table = {}
        # Session-state key of the widget; assigned on each call.
        self.key = None

    def __call__(
        self,
        base_container: Callable,
        legal_list: list[T],
        default: T | list[T] | None = None,
        *,
        key: str | None = None,
        **kwargs,
    ) -> T | list[T] | None:
        """Render *base_container* pre-populated from the URL, default or state.

        Args:
            base_container: Widget factory such as ``st.selectbox``,
                ``st.radio`` or ``st.multiselect``. Any other callable is
                invoked as ``base_container(label, legal_list, key=...)``.
            legal_list: Options offered by the widget.
            default: Selection used when the URL carries no value for this
                query and *default* is truthy.
            key: Widget key; defaults to the label.
            **kwargs: Passed through to *base_container*.

        Returns:
            The value stored in ``st.session_state`` under the widget key
            after the widget has been rendered.
        """
        hashes_from_query = st.query_params.get_all(self.query.lower())
        self.key = key or self.label
        self.hash_table = {SHA1(str(option)): option for option in legal_list}

        if hashes_from_query:
            # Unknown hashes are dropped unconditionally, so an empty
            # ``legal_list`` yields an empty selection instead of a KeyError.
            selected = [
                self.hash_table[hashed]
                for hashed in hashes_from_query
                if hashed in self.hash_table
            ]
        elif default:
            selected = default
        elif self.key in st.session_state:
            # The widget stores its value under ``self.key``, so that is the
            # entry to restore from (it equals the label when no key is given).
            selected = st.session_state[self.key]
            if legal_list:
                if isinstance(selected, list):
                    selected = [value for value in selected if value in legal_list]
                elif selected not in legal_list:
                    selected = []
        else:
            selected = []

        # Single-choice widgets take one value rather than a one-element
        # sequence. The type check keeps scalars without a length (e.g. an
        # int restored from session state) away from ``len()``.
        if (
            base_container in (st.selectbox, st.radio)
            and isinstance(selected, (list, tuple))
            and len(selected) == 1
        ):
            selected = selected[0]

        if base_container in (st.checkbox, st.radio, st.selectbox):
            # NOTE(review): st.checkbox does not appear to accept an ``index``
            # argument or an options list - confirm this branch is intended
            # for checkboxes.
            index = legal_list.index(selected) if selected in legal_list else None
            base_container(
                self.label,
                legal_list,
                index=index,
                key=self.key,
                **kwargs,
            )
        elif base_container == st.multiselect:
            base_container(
                self.label, legal_list, default=selected, key=self.key, **kwargs
            )
        else:
            base_container(self.label, legal_list, key=self.key, **kwargs)

        return st.session_state[self.key]

    def __str__(self):
        """Return the ``name=hash`` fragment(s) for the widget's current value.

        A list value produces one ``name=hash`` pair per element joined with
        ``&``; any other non-``None`` value produces a single pair. Every
        value is hashed via ``SHA1(str(value))``, the same scheme used to
        build ``hash_table``. ``None`` or a missing entry yields ``""``.
        """
        selected = st.session_state.get(self.key, None)
        if selected is None:
            return ""
        values = selected if isinstance(selected, list) else [selected]
        name = self.query.lower()
        return "&".join(f"{name}={SHA1(str(value))}" for value in values)
|