import keras
import keras_hub

model_presets = [
    # 8B params models
    "hf://google/gemma-2-instruct-9b-keras",
    "hf://meta-llama/Llama-3.1-8B-Instruct",
    "hf://google/codegemma-7b-it-keras",
    "hf://keras/mistral_instruct_7b_en",
    "hf://keras/vicuna_1.5_7b_en",
    # "keras/gemma_1.1_instruct_7b_en", # won't fit?
    # 1-3B params models
    "hf://meta-llama/Llama-3.2-1B-Instruct",
    "hf://google/gemma-2b-it-keras",
    "hf://meta-llama/Llama-3.2-3B-Instruct",
]

# Short display labels: strip the "hf://" prefix and the organization name.
model_labels = [
    preset.removeprefix("hf://")
    .removeprefix("google/")
    .removeprefix("keras/")
    .removeprefix("meta-llama/")
    for preset in model_presets
]


def preset_to_website_url(preset):
    preset = preset.removeprefix("hf://")
    return "https://huggingface.co/" + preset


def get_appropriate_chat_template(preset):
    return "Vicuna" if "vicuna" in preset else "auto"


def get_default_layout_map(preset_name, device_mesh):
    # Llama's default layout map also works for Mistral and Vicuna because
    # their transformer layers use the same variable names.
    if (
        "Llama" in preset_name
        or "mistral" in preset_name
        or "vicuna" in preset_name
    ):
        layout_map = keras_hub.models.Llama3Backbone.get_layout_map(device_mesh)
        # This entry is missing for some Llama presets (TODO: fix this in keras_hub).
        layout_map["token_embedding/reverse_embeddings"] = ("batch", "model")
        return layout_map
    elif "gemma" in preset_name:
        return keras_hub.models.GemmaBackbone.get_layout_map(device_mesh)
    else:
        raise ValueError(f"No default layout map for preset: {preset_name}")


def log_applied_layout_map(model):
    model_class = type(model).__name__
    print("Model class:", model_class)

    if "Gemma" in model_class:
        transformer_decoder_block_name = "decoder_block_1"
    elif "Llama" in model_class or "Mistral" in model_class:
        # Llama (Vicuna), Llama3 and Mistral use the same layer naming.
        transformer_decoder_block_name = "transformer_layer_1"
    else:
        print("Unknown architecture. Cannot display the applied layout.")
        return

    # See how layer sharding was applied
    embedding_layer = model.backbone.get_layer("token_embedding")
    print(embedding_layer)
    decoder_block = model.backbone.get_layer(transformer_decoder_block_name)
    print(type(decoder_block))
    for variable in embedding_layer.weights + decoder_block.weights:
        print(
            f"{variable.path:<58}  "
            f"{str(variable.shape):<16}  "
            f"{str(variable.value.sharding.spec):<35}  "
            f"{str(variable.dtype)}"
        )


def load_model(preset):
    """Load a preset in bfloat16, sharded across all available accelerators."""
    devices = keras.distribution.list_devices()
    # 1D model-parallel mesh: weights are sharded across all devices along the
    # "model" axis; the "batch" axis is not sharded (size 1).
    device_mesh = keras.distribution.DeviceMesh(
        shape=(1, len(devices)), axis_names=["batch", "model"], devices=devices
    )
    model_parallel = keras.distribution.ModelParallel(
        layout_map=get_default_layout_map(preset, device_mesh),
        batch_dim_name="batch",
    )

    with model_parallel.scope():
        # These two presets need a workaround to load in bfloat16: build the
        # backbone and preprocessor explicitly instead of calling
        # CausalLM.from_preset directly.
        if "google/gemma-2-instruct-9b-keras" in preset:
            model = keras_hub.models.GemmaCausalLM(
                backbone=keras_hub.models.GemmaBackbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        elif "meta-llama/Llama-3.1-8B-Instruct" in preset:
            model = keras_hub.models.Llama3CausalLM(
                backbone=keras_hub.models.Llama3Backbone.from_preset(
                    preset, dtype="bfloat16"
                ),
                preprocessor=keras_hub.models.Llama3CausalLMPreprocessor.from_preset(
                    preset
                ),
            )
        else:
            model = keras_hub.models.CausalLM.from_preset(
                preset, dtype="bfloat16"
            )

    log_applied_layout_map(model)
    return model
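

# Example usage (a minimal sketch, not part of the utilities above): assumes
# the JAX backend (set KERAS_BACKEND="jax" before importing keras) and enough
# accelerator memory for the chosen preset; the prompt and max_length are
# arbitrary illustrative values.
if __name__ == "__main__":
    model = load_model("hf://meta-llama/Llama-3.2-1B-Instruct")
    # CausalLM.generate() handles tokenization, sampling and detokenization.
    print(model.generate("What is Keras?", max_length=64))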