Approach

This model uses the Mamba architecture and was pre-trained on approximately 400B tokens of Chinese and English corpora.

Usage

import torch

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Checkpoint identifier as given in this README.
repo_id = 'mamba-1.4b-aquila-400b'
device = "cuda:0"

# Load the weights in bfloat16 on the GPU and switch to inference mode.
model = MambaLMHeadModel.from_pretrained(repo_id, dtype=torch.bfloat16, device=device)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(repo_id)
prompt = "The Spring Festival is"

# Tokenize the prompt into a tensor of shape (1, prompt_length).
tokens = tokenizer(prompt, truncation=False, return_tensors="pt")["input_ids"].to(device)
input_length = tokens.shape[1]

with torch.no_grad():
    # Sample up to 200 new tokens; cg=True enables CUDA-graph decoding.
    out_ids = model.generate(
        input_ids=tokens,
        max_length=input_length + 200,
        temperature=1.0,
        top_k=15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
        cg=True,
    )

# Drop the prompt tokens and decode only the continuation.
out_text = tokenizer.decode(out_ids[0][input_length:].tolist())
print(out_text)

Sample output (generation is sampled, so results will vary):

the most important festival of the year for the Chinese people. It usually comes in January or February and it takes about 15 days to prepare for it.
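The usage snippet passes `temperature`, `top_k`, and `top_p` to `generate`. As a rough illustration of what these parameters control, here is a minimal pure-Python sketch of temperature scaling followed by top-k and then top-p (nucleus) filtering. The helper names `filter_top_k_top_p` and `sample` are hypothetical, and this is not the `mamba_ssm` implementation, only one common way to formulate the idea.

```python
import math
import random


def filter_top_k_top_p(logits, top_k=15, top_p=0.95, temperature=1.0):
    """Return {token_id: probability} for the tokens that survive filtering.

    Illustrative only: hypothetical helper, not part of mamba_ssm.
    """
    # Temperature scaling: values below 1.0 sharpen the distribution.
    scaled = [x / temperature for x in logits]

    # Top-k: keep the k highest-scoring token ids, best first.
    ranked = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:top_k]

    # Softmax over the surviving tokens (subtract the max for stability).
    peak = scaled[ranked[0]]
    exps = {i: math.exp(scaled[i] - peak) for i in ranked}
    total = sum(exps.values())
    probs = {i: exps[i] / total for i in ranked}

    # Top-p: keep the shortest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize so the kept probabilities sum to 1.
    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


def sample(dist, rng=random):
    """Draw one token id from a {token_id: probability} mapping."""
    ids = list(dist)
    return rng.choices(ids, weights=[dist[i] for i in ids], k=1)[0]


if __name__ == "__main__":
    # Toy vocabulary of four tokens with made-up logits.
    toy_logits = [2.0, 1.0, 0.5, -1.0]
    dist = filter_top_k_top_p(toy_logits, top_k=3, top_p=0.9)
    print(dist, sample(dist))
```

With `top_k=15` and `top_p=0.95` as in the snippet above, at most 15 candidate tokens are considered at each step, and low-probability tokens in the tail of that set are dropped before sampling.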

References

The Mamba architecture was introduced in the paper "Mamba: Linear-Time Sequence Modeling with Selective State Spaces".

The official implementation is here: https://github.com/state-spaces/mamba/tree/main