import pickle

import gradio as gr
import keras
import keras.ops as ops
import tensorflow as tf
from keras import layers
from keras.layers import TextVectorization

# from gradio_webrtc import WebRTC


# Registered so that saved configs referencing these layers can be
# deserialized by name.
@keras.saving.register_keras_serializable()
class TextVectorization(keras.layers.TextVectorization):
    pass


@keras.saving.register_keras_serializable()
class StringLookup(keras.layers.StringLookup):
    pass


@keras.saving.register_keras_serializable(package="Transformer")
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


@keras.saving.register_keras_serializable(package="Transformer")
class PositionalEmbedding(layers.Layer):
    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # Token id 0 is reserved for padding.
        return ops.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config
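
# A minimal, commented-out sanity check for PositionalEmbedding. The dims below
# are illustrative assumptions, not the ones used by the saved model; uncomment
# locally to verify shapes and masking:
#
#   emb = PositionalEmbedding(sequence_length=20, vocab_size=15000, embed_dim=256)
#   x = ops.convert_to_tensor([[5, 9, 2, 0]])   # one padded sequence
#   emb(x).shape                                # (1, 4, 256): token + position embeddings
#   emb.compute_mask(x)                         # [[True, True, True, False]]: id 0 is padding
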
@keras.saving.register_keras_serializable(package="Transformer")
class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        inputs, encoder_outputs = inputs
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is None:
            inputs_padding_mask, encoder_outputs_padding_mask = None, None
        else:
            inputs_padding_mask, encoder_outputs_padding_mask = mask

        # Masked self-attention over the target sequence.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            query_mask=inputs_padding_mask,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        # Cross-attention over the encoder outputs.
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            query_mask=inputs_padding_mask,
            key_mask=encoder_outputs_padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        # Lower-triangular mask so position i can only attend to positions <= i.
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config


# Restore the saved TextVectorization layers. The dummy adapt() call builds the
# layer's internal lookup tables so the pickled weights and vocabulary can be set.
with open("id_vectorization_transformer.pickle", "rb") as file:
    from_disk = pickle.load(file)
id_vectorization = TextVectorization.from_config(from_disk["config"])
id_vectorization.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
id_vectorization.set_weights(from_disk["weights"])
id_vectorization.set_vocabulary(from_disk["vocab"])

with open("en_vectorization_transformer.pickle", "rb") as file:
    from_disk = pickle.load(file)
en_vectorization = TextVectorization.from_config(from_disk["config"])
en_vectorization.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
en_vectorization.set_weights(from_disk["weights"])
en_vectorization.set_vocabulary(from_disk["vocab"])

transformer = keras.models.load_model(
    "transformer_keras.keras",
    custom_objects={
        "TransformerEncoder": TransformerEncoder,
        "TransformerDecoder": TransformerDecoder,
        "PositionalEmbedding": PositionalEmbedding,
    },
)

id_vocab = id_vectorization.get_vocabulary()
id_index_lookup = dict(zip(range(len(id_vocab)), id_vocab))
max_decoded_sentence_length = 20


def decode_sequence(input_sentence):
    """Greedy decoding: feed the tokens generated so far back into the decoder
    and take the most likely next token, one step at a time."""
    tokenized_input_sentence = en_vectorization([input_sentence])
    decoded_sentence = "[start]"
    for i in range(max_decoded_sentence_length):
        tokenized_target_sentence = id_vectorization([decoded_sentence])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = id_index_lookup[sampled_token_index]
        decoded_sentence += " " + sampled_token
        # The saved target vocabulary appears to store the end-of-sequence
        # token as "end" (no brackets), unlike the literal "[start]" prefix.
        if sampled_token == "end":
            break
    return decoded_sentence.replace("[start]", "").replace("end", "").strip()


# image = WebRTC(label="Stream")
desc=("