Seems like metadata is not in the safetensors files

#2
by tanyongkeat - opened

Running AutoModel.from_pretrained("SeaLLMs/SeaLLM-7B-Hybrid") gets the following error messages:

File /usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py:566, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    564 elif type(config) in cls._model_mapping.keys():
    565     model_class = _get_model_class(config, cls._model_mapping)
--> 566     return model_class.from_pretrained(
    567         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    568     )
    569 raise ValueError(
    570     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    571     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    572 )

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:3480, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3471     if dtype_orig is not None:
   3472         torch.set_default_dtype(dtype_orig)
   3473     (
   3474         model,
   3475         missing_keys,
   3476         unexpected_keys,
   3477         mismatched_keys,
   3478         offload_index,
   3479         error_msgs,
-> 3480     ) = cls._load_pretrained_model(
   3481         model,
   3482         state_dict,
   3483         loaded_state_dict_keys,  # XXX: rename?
   3484         resolved_archive_file,
   3485         pretrained_model_name_or_path,
   3486         ignore_mismatched_sizes=ignore_mismatched_sizes,
   3487         sharded_metadata=sharded_metadata,
   3488         _fast_init=_fast_init,
   3489         low_cpu_mem_usage=low_cpu_mem_usage,
   3490         device_map=device_map,
   3491         offload_folder=offload_folder,
   3492         offload_state_dict=offload_state_dict,
   3493         dtype=torch_dtype,
   3494         is_quantized=(getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES),
   3495         keep_in_fp32_modules=keep_in_fp32_modules,
   3496     )
   3498 model.is_loaded_in_4bit = load_in_4bit
   3499 model.is_loaded_in_8bit = load_in_8bit

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:3856, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, is_quantized, keep_in_fp32_modules)
   3854 if shard_file in disk_only_shard_files:
   3855     continue
-> 3856 state_dict = load_state_dict(shard_file)
   3858 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
   3859 # matching the weights in the model.
   3860 mismatched_keys += _find_mismatched_keys(
   3861     state_dict,
   3862     model_state_dict,
   (...)
   3866     ignore_mismatched_sizes,
   3867 )

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py:469, in load_state_dict(checkpoint_file)
    467 with safe_open(checkpoint_file, framework="pt") as f:
    468     metadata = f.metadata()
--> 469 if metadata.get("format") not in ["pt", "tf", "flax"]:
    470     raise OSError(
    471         f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
    472         "you save your model with the `save_pretrained` method."
    473     )
    474 return safe_load_file(checkpoint_file)

AttributeError: 'NoneType' object has no attribute 'get'

I saw another user with the same problem (https://huggingface.co/SeaLLMs/SeaLLM-7B-Hybrid/discussions/1#6571a97563aa14eb56fc6f73), but that thread was closed, so I am reopening the issue here.

SeaLLMs - Language Models for Southeast Asian Languages org

Thanks for reporting. I will add the missing metadata and upload a fix. In the meantime, you can bypass this by re-saving the safetensors files with the metadata `format` set to `pt` (PyTorch).

Thank you so much for the prompt reply!

In the meantime, this can be worked around by running the following on each of the 2 safetensors files (with `safetensors_path` set to that file's path):

import safetensors
from safetensors.torch import save_file

# Load every tensor from the checkpoint into memory, and keep whatever
# metadata the header already carries. f.metadata() returns None when the
# header has no metadata (the cause of the AttributeError above).
with safetensors.safe_open(safetensors_path, framework="pt") as f:
    metadata = dict(f.metadata() or {})
    tensors = {key: f.get_tensor(key) for key in f.keys()}

# transformers' load_state_dict only accepts "pt", "tf" or "flax" here.
metadata["format"] = "pt"

# Rewrite the file in place with the same tensors plus the merged metadata,
# so any pre-existing metadata entries are preserved instead of dropped.
save_file(tensors, safetensors_path, metadata=metadata)

Sign up or log in to comment