File size: 5,323 Bytes
b31f748 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import math
from base64 import b32encode, b32decode
from pybase64 import urlsafe_b64encode, urlsafe_b64decode
from loguru import logger as log
import os
import time
from pathlib import Path
from urllib.request import urlretrieve
from blake3 import blake3
from platformdirs import PlatformDirs
APP_NAME = "iscc-sct"
APP_AUTHOR = "iscc"

# Per-user application directories resolved by platformdirs. The data directory
# is created eagerly at import time because MODEL_PATH below lives inside it.
dirs = PlatformDirs(appname=APP_NAME, appauthor=APP_AUTHOR)
Path(dirs.user_data_dir).mkdir(parents=True, exist_ok=True)

# Names exported by ``from ... import *``. Note that decode_base32,
# decode_base64 and check_integrity are module-level but not listed here.
__all__ = [
    "timer",
    "get_model",
    "encode_base32",
    "encode_base64",
    "hamming_distance",
    "iscc_distance",
    "MODEL_PATH",
]

# Download location, local cache path and expected blake3 digest of the ONNX
# model file that get_model() fetches and verifies.
BASE_VERSION = "1.0.0"
BASE_URL = f"https://github.com/iscc/iscc-binaries/releases/download/v{BASE_VERSION}"
MODEL_FILENAME = "iscc-sct-v0.1.0.onnx"
MODEL_URL = f"{BASE_URL}/{MODEL_FILENAME}"
MODEL_PATH = Path(dirs.user_data_dir, MODEL_FILENAME)
MODEL_CHECKSUM = "ff254d62db55ed88a1451b323a66416f60838dd2f0338dba21bc3b8822459abc"
class timer:
    """
    Context manager that measures and logs the run time of its ``with`` body.

    On exit the elapsed wall-clock time (measured with ``time.perf_counter``) is
    logged at DEBUG level as ``"<message> <seconds> seconds"`` and also stored on
    the instance as ``elapsed_time``. ``__enter__`` returns the timer itself, so
    the value can be read after the block::

        with timer("Embedding time") as t:
            ...
        t.elapsed_time

    The block is timed and logged even if it raises; the exception propagates.

    :param message: Prefix for the logged timing line.
    """

    def __init__(self, message: str):
        self.message = message
        # Filled in by __exit__; stays None until the timed block has finished.
        self.elapsed_time = None

    def __enter__(self):
        # perf_counter is the stdlib clock intended for measuring short intervals.
        self.start_time = time.perf_counter()
        # Return self so callers can bind ``with timer(...) as t``.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed_time = time.perf_counter() - self.start_time
        log.debug(f"{self.message} {self.elapsed_time:.4f} seconds")
        # Falsy return value: never suppress an exception raised in the block.
        return False
def get_model():  # pragma: no cover
    """
    Return the path of the verified local model file, downloading it if needed.

    If MODEL_PATH is missing it is downloaded from MODEL_URL. If it exists but
    fails the blake3 check, it is downloaded again. Either way, the file is
    verified with check_integrity before its path is returned.

    :return: MODEL_PATH (as returned by check_integrity).
    :raises RuntimeError: If the freshly downloaded file still fails verification.
    Errors raised by ``urlretrieve`` (network/IO) are not caught and propagate.
    """
    if not MODEL_PATH.exists():
        log.info("Downloading embedding model ...")
        urlretrieve(MODEL_URL, filename=MODEL_PATH)
        return check_integrity(MODEL_PATH, MODEL_CHECKSUM)

    # A cached copy exists: trust it only if the checksum still matches.
    try:
        return check_integrity(MODEL_PATH, MODEL_CHECKSUM)
    except RuntimeError:
        log.warning("Model file integrity error - redownloading ...")
        urlretrieve(MODEL_URL, filename=MODEL_PATH)
    return check_integrity(MODEL_PATH, MODEL_CHECKSUM)
def check_integrity(file_path, checksum):
    # type: (str|Path, str) -> Path
    """
    Verify a file against an expected blake3 hex digest.

    The file is hashed through a memory map with ``max_threads=blake3.AUTO``
    (letting blake3 pick the thread count), and the hashing time is logged via
    ``timer``.

    :param file_path: Path of the file to verify.
    :param checksum: Expected blake3 digest as a hex string. It is compared with
        ``==`` against ``hexdigest()`` without any case normalization.
    :return: ``file_path`` converted to a ``Path``.
    :raises RuntimeError: If the computed digest does not equal ``checksum``.
    """
    path = Path(file_path)
    hasher = blake3(max_threads=blake3.AUTO)
    with timer("INTEGRITY check time"):
        hasher.update_mmap(path)
        digest = hasher.hexdigest()
    if checksum == digest:
        return path
    msg = f"Failed integrity check for {path.name}"
    log.error(msg)
    raise RuntimeError(msg)
def encode_base32(data):
    # type: (bytes) -> str
    """
    Encode bytes as RFC4648 base32 text with the trailing ``=`` padding removed.

    :param bytes data: Bytes to encode.
    :return: Unpadded, upper-case base32 string.
    """
    padded = b32encode(data).decode("ascii")
    unpadded = padded.rstrip("=")
    return unpadded
def decode_base32(code):
    # type: (str) -> bytes
    """
    Decode RFC4648 base32 text that has had its ``=`` padding stripped.

    Lower-case input is accepted (``casefold=True``).

    :param code: Unpadded base32 string.
    :return: Decoded bytes.
    """
    # b32decode requires padded input, so restore enough "=" characters to
    # reach the next multiple of 8 (zero if the length already is one).
    missing = -len(code) % 8
    return b32decode(code + "=" * missing, casefold=True)
def encode_base64(data):
    # type: (bytes) -> str
    """
    Encode bytes as RFC4648 base64url text with the trailing ``=`` padding removed.

    :param data: Bytes to encode.
    :return: Unpadded URL-safe base64 string (``-`` and ``_`` alphabet).
    """
    return urlsafe_b64encode(data).decode("ascii").rstrip("=")
def decode_base64(code):
    # type: (str) -> bytes
    """
    Decode RFC4648 base64url text that has had its ``=`` padding stripped.

    :param code: Unpadded URL-safe base64 string.
    :return: Decoded bytes.
    """
    # Pad up to the next multiple of 4. The formula ``4 - len % 4`` would give 4
    # (appending "====") when the length is already a multiple of 4, relying on
    # the decoder silently tolerating excess padding; ``-len % 4`` gives 0 there.
    padding = -len(code) % 4
    return urlsafe_b64decode(code + "=" * padding)
def hamming_distance(a, b):
    # type: (bytes, bytes) -> int
    """
    Count the bit positions at which two equally long bytes objects differ.

    :param a: First bytes object.
    :param b: Second bytes object.
    :return: Bitwise Hamming distance between ``a`` and ``b``.
    :raise ValueError: If ``a`` and ``b`` have different lengths.
    """
    if len(a) == len(b):
        # XOR sets a bit exactly where the two bytes disagree; count those bits.
        return sum(bin(x ^ y).count("1") for x, y in zip(a, b))
    raise ValueError("The lengths of the two bytes objects must be the same")
def iscc_distance(iscc1, iscc2):
    # type: (str, str) -> int
    """
    Compute the bitwise Hamming distance between two ISCC Semantic Text Codes.

    An optional leading "ISCC:" prefix (upper case) is removed from each code, the
    remainder is base32-decoded, and the bodies after the 2-byte header are
    compared.

    :param iscc1: First ISCC Semantic Text Code.
    :param iscc2: Second ISCC Semantic Text Code.
    :return: Number of differing bits between the two code bodies.
    :raise ValueError: If a code is not valid base32 (binascii.Error is a
        ValueError subclass) or the decoded codes differ in length.
    """

    def _decode(code):
        # Drop the optional "ISCC:" prefix, then base32-decode the rest.
        if code.startswith("ISCC:"):
            code = code[5:]
        return decode_base32(code)

    raw1 = _decode(iscc1)
    raw2 = _decode(iscc2)
    if len(raw1) != len(raw2):
        raise ValueError("The input ISCCs must have the same length")
    # Skip the 2-byte header so only the code bodies are compared.
    return hamming_distance(raw1[2:], raw2[2:])
|