File size: 4,401 Bytes
c2ba4d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# query를 자동으로 읽고 쓰는 container를 정의

from __future__ import annotations

import re
from typing import Callable, TypeVar

import streamlit as st

# Public names exported by ``from <module> import *``.
__all__ = ["QueryWrapper", "get_base_url"]

# Type variable for widget option values, used in _QueryWrapper's annotations.
T = TypeVar("T")


import hashlib
import urllib.parse


def SHA1(msg: str) -> str:
    """Return a short fingerprint of *msg*: the first 8 hex digits of its SHA-1.

    The text is UTF-8 encoded before hashing. The result is used in this
    module as a compact token standing in for an option value in the URL.
    """
    encoded = msg.encode("utf-8")
    hex_digest = hashlib.sha1(encoded).hexdigest()
    return hex_digest[:8]


def get_base_url():
    """Return ``"<protocol>://<host>"`` for the browser session running this app.

    The protocol and host are taken from the ``request`` object attached to
    the Streamlit session client; path, query and fragment are left empty.

    The session running the current script is preferred. Only when it cannot
    be identified does this fall back to the first active session. With
    several browsers connected, that first session may belong to another user.

    NOTE(review): relies on Streamlit internals (``Runtime.get_client``,
    ``_session_mgr``, ``client.request``) that may change between releases.

    Raises:
        IndexError: if the current session cannot be identified and there are
            no active sessions to fall back on.
    """
    runtime = st.runtime.get_instance()
    client = None
    try:
        # Local import: the module path of this helper differs across
        # Streamlit versions, so failure falls back to the old behaviour.
        from streamlit.runtime.scriptrunner import get_script_run_ctx

        ctx = get_script_run_ctx()
        if ctx is not None:
            client = runtime.get_client(ctx.session_id)
    except (ImportError, AttributeError):
        client = None
    if client is None:
        client = runtime._session_mgr.list_active_sessions()[0].client
    request = client.request
    return urllib.parse.urlunparse(
        [request.protocol, request.host, "", "", "", ""]
    )


class QueryWrapper:
    """Public facade binding a URL query-parameter name to a Streamlit widget.

    Each instance creates a ``_QueryWrapper`` and registers it in the shared
    class-level ``queries`` registry under its query name. The class-level
    ``get_sharable_link`` can then serialise every registered widget into a
    single query string.
    """

    # Registry shared by all instances: query name -> inner wrapper.
    # Registering the same query name again replaces the earlier wrapper.
    queries: dict[str, _QueryWrapper] = {}

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        wrapper = _QueryWrapper(query, label, use_hash)
        QueryWrapper.queries[query] = wrapper
        self.__wrapper = wrapper

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the inner wrapper and return its result."""
        return self.__wrapper(*args, **kwargs)

    @classmethod
    def get_sharable_link(cls):
        """Return the query string (without ``?``) for all registered wrappers.

        Each wrapper contributes ``str(wrapper)``, which may be empty. Runs of
        ``&`` are therefore collapsed to one, and leading/trailing ``&`` are
        stripped.
        """
        fragments = [str(wrapper) for wrapper in cls.queries.values()]
        joined = "&".join(fragments)
        collapsed = re.sub("&+", "&", joined)
        return collapsed.strip("&")


class _QueryWrapper:
    """Bind one URL query parameter to one Streamlit selection widget.

    Option values travel in the URL as ``<query>=<token>`` pairs, where the
    token is ``SHA1(str(value))``. Calling the wrapper does three things:
    - seeds the widget from the URL, then ``default``, then the widget's
      previous session-state value;
    - renders the widget;
    - returns the widget's current value.

    ``str(wrapper)`` serialises that value back into query-string pairs.
    """

    # NOTE(review): not referenced in this class; presumably meant for a
    # non-hashed mode where these characters would be unsafe in a raw value.
    ILLEGAL_CHARS = "&/=?"

    def __init__(self, query: str, label: str | None = None, use_hash: bool = True):
        """Create a wrapper for the URL query parameter *query*.

        Args:
            query: Query-parameter name (read and written lower-cased).
            label: Widget label and default session-state key; defaults to
                *query*.
            use_hash: Stored only. NOTE(review): values are always hashed
                regardless of this flag.
        """
        self.query = query
        self.label = label or query
        self.use_hash = use_hash
        # token -> legal option value; rebuilt on every call.
        self.hash_table = {}
        # Session-state key of the rendered widget; None until first call.
        self.key = None

    def __call__(
        self,
        base_container: Callable,
        legal_list: list[T],
        default: T | list[T] | None = None,
        *,
        key: str | None = None,
        **kwargs,
    ) -> T | list[T] | None:
        """Render *base_container* seeded from the URL and return its value.

        Args:
            base_container: Streamlit widget function. ``st.multiselect``,
                ``st.radio``, ``st.selectbox`` and ``st.checkbox`` are seeded
                with the resolved selection. Any other callable is invoked as
                ``base_container(label, legal_list, key=..., **kwargs)``.
            legal_list: Allowed option values. URL tokens that match none of
                them are ignored.
            default: Selection used when the URL carries no legal value.
                ``None`` and ``[]`` mean "no default".
            key: Session-state key for the widget; defaults to the label.
            **kwargs: Forwarded to *base_container*.

        Returns:
            ``st.session_state[key]`` after rendering.
        """
        self.key = key or self.label
        self.hash_table = {SHA1(str(v)): v for v in legal_list}

        selected = self._initial_selection(legal_list, default)
        # Single-value widgets take a scalar, so unwrap a one-element list.
        # Only lists are unwrapped: scalars (ints, strings, tuples) reach the
        # widget unchanged.
        if (
            isinstance(selected, list)
            and len(selected) == 1
            and base_container in (st.selectbox, st.radio, st.checkbox)
        ):
            selected = selected[0]
        self._render(base_container, legal_list, selected, kwargs)
        return st.session_state[self.key]

    def _initial_selection(self, legal_list: list[T], default):
        """Resolve the seed value: URL > default > previous state > ``[]``."""
        tokens = st.query_params.get_all(self.query.lower())
        # Always filter through hash_table, so unknown tokens (or an empty
        # legal_list) are dropped instead of raising KeyError.
        from_query = [self.hash_table[t] for t in tokens if t in self.hash_table]
        if from_query:
            return from_query
        # Falsy scalars such as 0 or False are honoured as defaults.
        if default is not None and not (isinstance(default, list) and not default):
            return default
        # The widget's state lives under its key, which may differ from the label.
        if self.key in st.session_state:
            previous = st.session_state[self.key]
            if not legal_list:
                return previous
            if isinstance(previous, list):
                return [v for v in previous if v in legal_list]
            return previous if previous in legal_list else []
        return []

    def _render(
        self, base_container: Callable, legal_list: list[T], selected, kwargs: dict
    ) -> None:
        """Invoke *base_container* under ``self.key`` seeded with *selected*."""
        if base_container == st.multiselect:
            base_container(
                self.label, legal_list, default=selected, key=self.key, **kwargs
            )
        elif base_container in (st.radio, st.selectbox):
            index = legal_list.index(selected) if selected in legal_list else None
            base_container(
                self.label, legal_list, index=index, key=self.key, **kwargs
            )
        elif base_container == st.checkbox:
            # st.checkbox takes a boolean ``value`` rather than an options
            # list or ``index``. NOTE(review): the seed is the truthiness of
            # the resolved selection.
            base_container(self.label, value=bool(selected), key=self.key, **kwargs)
        else:
            base_container(self.label, legal_list, key=self.key, **kwargs)

    def __str__(self):
        """Serialise the widget's current value as ``query=<token>`` pairs.

        A list yields one pair per element joined by ``&``. Any other non-None
        value yields a single pair. ``None``, or a wrapper never called, yields
        ``""``.
        """
        if self.key is None:
            return ""
        selected = st.session_state.get(self.key, None)
        if selected is None:
            return ""
        name = self.query.lower()
        values = selected if isinstance(selected, list) else [selected]
        # Hashing str(v) makes tokens match hash_table keys for any option type.
        return "&".join(f"{name}={SHA1(str(v))}" for v in values)