Trained on 30B bytes. Model size: 353M parameters. See Table 2 of the MambaByte paper.

Usage:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

import numpy as np

# Load the pretrained byte-level checkpoint onto the GPU in float32.
model = MambaLMHeadModel.from_pretrained(
    "JunxiongWang/MambaByte_Code", device="cuda", dtype=torch.float32
)

# The prompt is fed as raw UTF-8 bytes: each byte value is used directly as a
# token id, with a leading batch axis of size 1.
text = "import torch"
text_byte = np.frombuffer(text.encode("utf-8"), dtype=np.uint8)
# np.frombuffer over a bytes object is read-only; .copy() makes a writable
# array before handing it to torch.from_numpy.
input_ids = torch.from_numpy(text_byte[None, :].copy()).long().cuda()

sample = model.generate(
    input_ids=input_ids,
    max_length=2048,
    cg=True,
    return_dict_in_generate=True,
    output_scores=True,
    enable_timing=True,
    temperature=1,
    top_k=256,
    top_p=0.9,
)

# Sampled bytes are not guaranteed to be valid UTF-8 (e.g. a multi-byte
# character can be cut off when generation stops at max_length), so decode
# with errors="replace" rather than raising UnicodeDecodeError and losing
# the whole output.
print(bytes(sample.sequences[0].tolist()).decode("utf-8", errors="replace"))

Example output:

import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable

from networkx.states import TransientState

def extract_data(num_epochs, epochs, is_last_epoch):

    def get_data(num_features, num_classes):
        data_features = num_features
        data_classes = num_classes
        data_labels = num_epochs

        if num_features == 0 or num_classes == 0:
            return data_features, data_classes
        if is_last_epoch:
            data_features = num_features
            data_classes = num_classes
            data_labels = num_epochs
        return data_features, data_classes

    data_features, data_classes = get_data(num_epochs, epochs, is_last_epoch)
    data_labels = num_epochs * 2
    return data_features, data_classes


class NumChannel:
    def __init__(self, x, y, dx=1, dy=1, idx=1, data_size=2, epoch=None):
        """idx is the channel index with data feature in the first epoch.
        x is the channel of the input data.
        y is the element of the input data.
        dx is the element of the data feature of the input data.
        data_size is the size of the element of the data.
        epoch is the channel of the element of the data.
        """
        self.x = x
        self.y = y
        self.dx = dx
        self.data_size = data_size
        self.epoch = epoch
        self.reference_count = 0
        self.data_features = {}
        self.data_classes = {}

        self._initialize()
        if idx is not None:
            self._start_time = time.time()

    def _initialize(self):
        """idx is the channel index with data feature in the first epoch.
        x is the channel of the input data.
        y is the element of the input data.
        dx is the element of the data feature of the input data.
        data_size is the size of the element of the data.
        epoch is the channel of the element of the data.
        """
        self.idx = idx
Downloads last month
21
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Collection including JunxiongWang/MambaByte_Code