---
license: mit
datasets:
- sail/regmix-data
- sail/regmix-data-sample
language:
- en
---
# Models Trained with Random Mixture
This is a collection of 64 language models, each with approximately 1B parameters, trained on different random mixtures of data. The project aims to validate that the RegMix approach (https://huggingface.co/papers/2407.01492) generalizes from small-scale (e.g., 1M-parameter) to large-scale (e.g., 1B-parameter) models.
## Key Features
- **Model Size**: 64 separate models, each with ~1B parameters
- **Training Data**: Random data mixtures drawn from the [RegMix-Data](https://huggingface.co/datasets/sail/regmix-data) dataset
- **Purpose**: To validate the effectiveness of RegMix in identifying high-performing data mixtures
## Dataset
The models were trained on the [RegMix-Data](https://huggingface.co/datasets/sail/regmix-data) dataset, which is derived from The Pile and split into per-domain subsets.
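To inspect the training data itself, you can load the dataset with the `datasets` library. A minimal sketch, using the smaller `sail/regmix-data-sample` variant; the split name is an assumption here, so check the dataset card for the exact configurations:
```python
from datasets import load_dataset

# Stream the sample variant to avoid downloading the full corpus.
# The "train" split name is an assumption; see the dataset card.
ds = load_dataset("sail/regmix-data-sample", split="train", streaming=True)
print(next(iter(ds)))  # First example of the sample variant.
```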
## Training Hyperparameters
| Hyperparameter | Value |
|:---------------|:------|
| Batch Size | 1M tokens |
| Learning Rate | 4e-4 |
| Minimum Learning Rate | 1e-5 |
| Learning Rate Schedule | Cosine |
| Warmup Ratio | 4% |
| Total Tokens | 25B |
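The schedule above can be written out as a short function. This is an illustrative sketch rather than the actual training code; it assumes linear warmup over the first 4% of steps, followed by cosine decay from the peak learning rate (4e-4) to the minimum (1e-5):
```python
import math

def learning_rate(step: int, total_steps: int,
                  peak_lr: float = 4e-4, min_lr: float = 1e-5,
                  warmup_ratio: float = 0.04) -> float:
    """Cosine schedule with linear warmup, matching the table above."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Linear warmup from 0 to the peak learning rate.
        return peak_lr * step / max(warmup_steps, 1)
    # Cosine decay from peak_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# With a 1M-token batch and 25B total tokens, training runs ~25,000 steps.
print(learning_rate(step=1_000, total_steps=25_000))  # peak: 4e-4
```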
## How to Load a Model
You can load any model using the corresponding branch with the Hugging Face Transformers library:
```python
from transformers import AutoModel, AutoTokenizer

# All 64 models share one repository; each lives on its own branch,
# selected via the `revision` argument.
model = AutoModel.from_pretrained("sail/data-mixture-random-1b", revision="model-index-1")
tokenizer = AutoTokenizer.from_pretrained("sail/data-mixture-random-1b", revision="model-index-1")
```
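For text generation, loading through the causal-LM class is usually more convenient. This is a hedged sketch: it assumes the checkpoints include a language-modeling head compatible with `AutoModelForCausalLM`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "sail/data-mixture-random-1b"
tokenizer = AutoTokenizer.from_pretrained(repo, revision="model-index-1")
model = AutoModelForCausalLM.from_pretrained(repo, revision="model-index-1")

# Greedy decoding of a short continuation.
inputs = tokenizer("The Pile is a dataset", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```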
## Data Mixture
The specific data mixture used for training each 1B model can be found in the file `train_config.yaml` in each corresponding model branch.
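For example, you can inspect a branch's mixture without downloading the full checkpoint. This sketch assumes the `huggingface_hub` and `pyyaml` packages are installed:
```python
import yaml
from huggingface_hub import hf_hub_download

# Download only train_config.yaml from the branch of model index 1.
path = hf_hub_download(
    repo_id="sail/data-mixture-random-1b",
    filename="train_config.yaml",
    revision="model-index-1",
)
with open(path) as f:
    config = yaml.safe_load(f)
print(config)  # Contains the data mixture used for this run.
```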
## Model Variants
To access a different model variant, change the `revision` parameter in the `from_pretrained` method to the desired model index (e.g., "model-index-2", "model-index-3"); the maximum index is 64.
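For instance, a sweep over all 64 branches might look like the following sketch (the evaluation logic is left as a placeholder):
```python
from transformers import AutoModel

repo = "sail/data-mixture-random-1b"
for i in range(1, 65):  # Branches are named model-index-1 through model-index-64.
    model = AutoModel.from_pretrained(repo, revision=f"model-index-{i}")
    # ... run your evaluation or analysis here ...
    del model  # Release memory before loading the next variant.
```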
## Usage Notes
- These models are primarily intended for research purposes.
- Performance may vary depending on the specific task and domain.
## Citation
If you use these models in your research, please cite the RegMix paper:
```
@article{liu2024regmix,
  title={RegMix: Data Mixture as Regression for Language Model Pre-training},
  author={Liu, Qian and Zheng, Xiaosen and Muennighoff, Niklas and Zeng, Guangtao and Dou, Longxu and Pang, Tianyu and Jiang, Jing and Lin, Min},
  journal={arXiv preprint arXiv:2407.01492},
  year={2024}
}
```
For more information about the RegMix methodology and its applications, please refer to the [original paper](https://huggingface.co/papers/2407.01492).
## Performance
We evaluated each model using [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness). The performance metric for each task is the average of the 0-shot through 5-shot `acc_norm` (normalized accuracy, when available) or `acc` (accuracy) scores.
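Concretely, the per-task numbers below correspond to an aggregation like the following sketch, where `results` is a hypothetical mapping from shot count to the harness's per-task metrics:
```python
def task_score(results: dict[int, dict[str, float]]) -> float:
    """Average the 0-shot through 5-shot scores for one task.

    `results` maps a shot count to that run's metrics, e.g.
    {0: {"acc": 0.31, "acc_norm": 0.40}, ..., 5: {...}}.
    Prefer acc_norm when the task reports it; otherwise fall back to acc.
    """
    scores = [m.get("acc_norm", m["acc"]) for _, m in sorted(results.items())]
    return 100 * sum(scores) / len(scores)

# Hypothetical example: one task evaluated at 0-5 shots.
example = {k: {"acc": 0.30 + 0.01 * k, "acc_norm": 0.40 + 0.01 * k} for k in range(6)}
print(round(task_score(example), 2))  # 42.5
```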
### Table 1: Model Index 1-8
| Task | Model 1 | Model 2 | Model 3 | Model 4 | Model 5 | Model 6 | Model 7 | Model 8 |
|---------------|---------|---------|---------|---------|---------|---------|---------|---------|
| Social IQA | 33.27 | 33.33 | 33.62 | 33.53 | 33.49 | 33.56 | 33.62 | 33.55 |
| HellaSwag | 40.58 | 36.86 | 40.58 | 36.06 | 40.07 | 37.85 | 37.93 | 39.59 |
| PiQA | 67.29 | 65.14 | 67.97 | 64.66 | 67.03 | 65.36 | 66.00 | 66.55 |
| OpenBookQA | 28.63 | 27.87 | 29.33 | 29.10 | 29.23 | 28.33 | 29.13 | 28.73 |
| Lambada | 29.17 | 26.86 | 31.55 | 27.11 | 29.16 | 28.92 | 31.53 | 30.92 |
| SciQ | 80.68 | 79.98 | 81.05 | 80.80 | 82.40 | 79.88 | 78.67 | 79.70 |
| COPA | 70.50 | 63.83 | 69.17 | 65.00 | 67.50 | 66.00 | 66.67 | 68.67 |
| RACE | 29.47 | 30.00 | 32.11 | 28.82 | 31.13 | 30.06 | 29.90 | 30.75 |
| ARC Easy | 50.03 | 48.72 | 50.01 | 46.64 | 51.06 | 47.46 | 46.75 | 48.39 |
| LogiQA | 23.76 | 24.17 | 25.29 | 25.29 | 24.55 | 25.96 | 25.45 | 26.32 |
| QQP | 55.71 | 55.90 | 54.84 | 56.52 | 54.01 | 56.34 | 52.35 | 54.20 |
| WinoGrande | 51.54 | 51.59 | 51.39 | 50.91 | 53.13 | 52.26 | 51.26 | 51.45 |
| MultiRC | 52.65 | 53.39 | 51.89 | 50.92 | 49.03 | 53.09 | 53.64 | 50.23 |
| **Average** | **47.18** | **45.97** | **47.60** | **45.80** | **47.06** | **46.54** | **46.38** | **46.85** |
### Table 2: Model Index 9-16
| Task | Model 9 | Model 10 | Model 11 | Model 12 | Model 13 | Model 14 | Model 15 | Model 16 |
|---------------|---------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.43 | 33.21 | 33.31 | 33.17 | 33.28 | 32.43 | 33.57 | 33.70 |
| HellaSwag | 40.05 | 35.89 | 39.55 | 39.89 | 38.63 | 36.18 | 39.52 | 35.94 |
| PiQA | 66.60 | 64.74 | 66.29 | 66.27 | 66.90 | 64.05 | 66.70 | 64.51 |
| OpenBookQA | 28.87 | 26.60 | 29.33 | 28.73 | 29.40 | 27.87 | 29.67 | 27.83 |
| Lambada | 31.39 | 27.37 | 30.32 | 30.31 | 31.38 | 26.25 | 29.86 | 26.95 |
| SciQ | 81.10 | 79.12 | 79.97 | 82.85 | 79.42 | 81.40 | 81.38 | 81.23 |
| COPA | 67.00 | 64.50 | 66.83 | 69.50 | 67.33 | 65.83 | 69.50 | 66.33 |
| RACE | 30.57 | 29.63 | 30.49 | 30.85 | 30.35 | 28.66 | 31.21 | 29.57 |
| ARC Easy | 50.66 | 47.74 | 47.47 | 50.18 | 49.92 | 49.52 | 50.73 | 48.65 |
| LogiQA | 23.60 | 25.65 | 26.37 | 23.81 | 25.58 | 26.29 | 25.86 | 25.12 |
| QQP | 54.89 | 54.79 | 54.20 | 55.23 | 53.69 | 57.09 | 53.95 | 54.24 |
| WinoGrande | 50.83 | 51.84 | 51.05 | 51.83 | 52.12 | 52.00 | 51.01 | 51.82 |
| MultiRC | 54.18 | 54.48 | 50.17 | 52.12 | 51.42 | 52.69 | 51.87 | 53.48 |
| **Average** | **47.17** | **45.81** | **46.57** | **47.29** | **46.88** | **46.17** | **47.30** | **46.11** |
### Table 3: Model Index 17-24
| Task | Model 17 | Model 18 | Model 19 | Model 20 | Model 21 | Model 22 | Model 23 | Model 24 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.89 | 33.31 | 33.53 | 33.38 | 33.75 | 33.24 | 33.56 | 33.71 |
| HellaSwag | 38.68 | 39.90 | 34.67 | 37.12 | 37.44 | 36.07 | 42.15 | 34.67 |
| PiQA | 66.83 | 67.39 | 63.33 | 64.83 | 65.00 | 63.68 | 67.80 | 62.99 |
| OpenBookQA | 28.13 | 30.67 | 28.03 | 29.40 | 27.67 | 27.77 | 29.37 | 25.83 |
| Lambada | 28.78 | 28.56 | 24.13 | 29.41 | 27.67 | 28.03 | 33.47 | 24.04 |
| SciQ | 79.60 | 78.83 | 77.42 | 78.98 | 78.95 | 78.72 | 81.83 | 79.12 |
| COPA | 65.17 | 68.17 | 65.33 | 67.33 | 67.67 | 62.67 | 69.83 | 65.83 |
| RACE | 28.74 | 30.03 | 29.76 | 29.49 | 30.77 | 29.76 | 31.21 | 27.91 |
| ARC Easy | 48.86 | 49.42 | 47.90 | 48.30 | 47.88 | 46.68 | 50.92 | 45.24 |
| LogiQA | 25.91 | 26.34 | 26.24 | 25.76 | 26.11 | 26.24 | 24.17 | 25.91 |
| QQP | 53.35 | 53.18 | 50.61 | 51.49 | 54.27 | 54.99 | 52.77 | 55.19 |
| WinoGrande | 52.54 | 51.17 | 52.01 | 51.09 | 52.13 | 52.03 | 52.50 | 50.28 |
| MultiRC | 51.49 | 52.45 | 55.40 | 54.87 | 51.73 | 49.49 | 50.61 | 50.29 |
| **Average** | **46.30** | **46.88** | **45.26** | **46.27** | **46.23** | **45.34** | **47.71** | **44.69** |
### Table 4: Model Index 25-32
| Task | Model 25 | Model 26 | Model 27 | Model 28 | Model 29 | Model 30 | Model 31 | Model 32 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.51 | 33.40 | 33.59 | 33.52 | 33.53 | 33.49 | 33.16 | 33.56 |
| HellaSwag | 36.75 | 36.97 | 40.81 | 38.25 | 40.28 | 35.71 | 37.37 | 37.39 |
| PiQA | 64.09 | 64.74 | 67.97 | 66.15 | 66.88 | 63.84 | 64.47 | 65.05 |
| OpenBookQA | 29.47 | 28.70 | 29.57 | 29.77 | 29.50 | 29.13 | 29.47 | 28.00 |
| Lambada | 26.69 | 33.00 | 31.60 | 33.08 | 31.49 | 27.69 | 26.99 | 29.54 |
| SciQ | 80.03 | 79.17 | 80.12 | 80.22 | 81.92 | 78.23 | 77.42 | 80.87 |
| COPA | 67.67 | 65.50 | 69.00 | 65.67 | 68.33 | 63.33 | 64.67 | 67.17 |
| RACE | 30.05 | 30.19 | 30.96 | 30.37 | 30.08 | 29.62 | 30.13 | 29.92 |
| ARC Easy | 47.50 | 46.90 | 50.26 | 48.57 | 50.55 | 46.96 | 48.77 | 48.79 |
| LogiQA | 27.24 | 25.55 | 25.86 | 24.37 | 25.32 | 25.12 | 26.40 | 24.30 |
| QQP | 49.68 | 55.43 | 50.94 | 50.91 | 51.99 | 53.53 | 49.53 | 51.36 |
| WinoGrande | 51.68 | 52.12 | 51.93 | 51.50 | 52.32 | 51.67 | 52.13 | 52.63 |
| MultiRC | 51.24 | 51.91 | 50.33 | 52.42 | 52.52 | 54.04 | 52.05 | 53.04 |
| **Average** | **45.82** | **46.43** | **47.15** | **46.52** | **47.29** | **45.57** | **45.58** | **46.28** |
### Table 5: Model Index 33-40
| Task | Model 33 | Model 34 | Model 35 | Model 36 | Model 37 | Model 38 | Model 39 | Model 40 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.48 | 33.28 | 33.35 | 33.29 | 33.63 | 33.61 | 33.21 | 33.61 |
| HellaSwag | 38.00 | 40.18 | 43.37 | 37.69 | 32.96 | 32.98 | 37.31 | 37.79 |
| PiQA | 65.30 | 66.68 | 69.04 | 66.46 | 62.25 | 60.17 | 65.24 | 65.32 |
| OpenBookQA | 29.43 | 30.37 | 30.43 | 27.63 | 26.43 | 26.83 | 27.97 | 28.70 |
| Lambada | 26.59 | 31.46 | 31.71 | 30.21 | 18.92 | 20.29 | 28.10 | 28.58 |
| SciQ | 79.82 | 80.58 | 82.13 | 80.83 | 76.73 | 77.90 | 79.12 | 79.60 |
| COPA | 64.33 | 69.33 | 67.00 | 67.83 | 61.50 | 62.67 | 64.67 | 66.00 |
| RACE | 30.03 | 30.16 | 32.47 | 30.49 | 29.27 | 28.12 | 30.11 | 30.21 |
| ARC Easy | 48.86 | 49.88 | 52.22 | 48.32 | 44.86 | 45.54 | 48.15 | 48.86 |
| LogiQA | 25.91 | 24.30 | 23.35 | 24.96 | 26.19 | 27.68 | 25.47 | 25.37 |
| QQP | 56.06 | 56.56 | 52.57 | 56.70 | 52.54 | 48.04 | 49.81 | 57.12 |
| WinoGrande | 50.92 | 50.97 | 52.39 | 52.70 | 52.30 | 51.68 | 51.42 | 52.80 |
| MultiRC | 53.09 | 49.97 | 52.18 | 49.05 | 53.78 | 52.27 | 51.45 | 55.68 |
| **Average** | **46.29** | **47.21** | **47.86** | **46.63** | **43.95** | **43.67** | **45.54** | **46.90** |
### Table 6: Model Index 41-48
| Task | Model 41 | Model 42 | Model 43 | Model 44 | Model 45 | Model 46 | Model 47 | Model 48 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.49 | 33.43 | 33.07 | 33.28 | 33.44 | 33.08 | 33.78 | 33.17 |
| HellaSwag | 34.51 | 37.59 | 42.69 | 37.37 | 38.31 | 38.30 | 39.67 | 41.07 |
| PiQA | 62.24 | 65.58 | 68.05 | 66.62 | 66.54 | 65.52 | 66.98 | 67.21 |
| OpenBookQA | 27.10 | 28.77 | 28.90 | 28.07 | 28.07 | 27.60 | 31.17 | 29.73 |
| Lambada | 22.78 | 26.99 | 31.34 | 29.51 | 27.87 | 29.47 | 30.34 | 32.71 |
| SciQ | 77.78 | 80.25 | 79.47 | 80.25 | 80.70 | 79.72 | 81.35 | 81.77 |
| COPA | 64.00 | 66.33 | 67.00 | 67.00 | 67.33 | 68.33 | 67.17 | 67.67 |
| RACE | 28.33 | 28.82 | 30.78 | 30.80 | 30.08 | 30.24 | 30.24 | 30.67 |
| ARC Easy | 45.48 | 48.64 | 51.49 | 46.99 | 48.79 | 48.05 | 49.58 | 49.49 |
| LogiQA | 24.83 | 24.96 | 24.76 | 23.25 | 26.06 | 25.55 | 24.32 | 24.68 |
| QQP | 50.27 | 54.73 | 53.96 | 57.00 | 53.73 | 51.19 | 57.52 | 56.91 |
| WinoGrande | 51.79 | 51.63 | 51.32 | 50.76 | 53.18 | 52.45 | 50.72 | 52.24 |
| MultiRC | 54.03 | 53.96 | 48.91 | 50.74 | 53.01 | 50.89 | 47.63 | 53.84 |
| **Average** | **44.35** | **46.28** | **47.06** | **46.28** | **46.70** | **46.18** | **46.96** | **47.78** |
### Table 7: Model Index 49-56
| Task | Model 49 | Model 50 | Model 51 | Model 52 | Model 53 | Model 54 | Model 55 | Model 56 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.53 | 33.74 | 33.37 | 33.41 | 32.96 | 33.88 | 33.75 | 33.79 |
| HellaSwag | 39.09 | 35.65 | 38.68 | 36.07 | 37.68 | 38.53 | 35.40 | 40.50 |
| PiQA | 66.81 | 64.58 | 65.68 | 63.99 | 65.85 | 65.76 | 64.51 | 66.89 |
| OpenBookQA | 29.13 | 27.57 | 28.27 | 29.10 | 29.43 | 28.73 | 28.30 | 29.87 |
| Lambada | 30.23 | 26.19 | 30.29 | 30.84 | 29.76 | 29.03 | 28.63 | 30.74 |
| SciQ | 79.90 | 80.83 | 78.40 | 80.03 | 81.38 | 80.92 | 77.75 | 82.07 |
| COPA | 68.17 | 61.83 | 67.00 | 66.00 | 66.17 | 63.17 | 66.33 | 64.00 |
| RACE | 31.42 | 29.35 | 30.41 | 31.08 | 30.77 | 29.73 | 30.80 | 31.42 |
| ARC Easy | 49.54 | 47.71 | 49.02 | 47.64 | 48.38 | 49.36 | 46.96 | 51.22 |
| LogiQA | 24.99 | 24.58 | 25.32 | 24.91 | 25.17 | 26.22 | 24.63 | 24.91 |
| QQP | 54.06 | 56.48 | 50.96 | 56.62 | 56.45 | 53.86 | 53.85 | 53.26 |
| WinoGrande | 50.51 | 50.26 | 51.83 | 51.33 | 52.18 | 51.89 | 51.59 | 50.50 |
| MultiRC | 50.25 | 54.37 | 50.94 | 52.38 | 51.21 | 55.34 | 54.52 | 50.50 |
| **Average** | **46.74** | **45.63** | **46.17** | **46.42** | **46.72** | **46.65** | **45.92** | **46.90** |
### Table 8: Model Index 57-64
| Task | Model 57 | Model 58 | Model 59 | Model 60 | Model 61 | Model 62 | Model 63 | Model 64 |
|---------------|----------|----------|----------|----------|----------|----------|----------|----------|
| Social IQA | 33.24 | 33.30 | 33.56 | 33.54 | 33.42 | 33.84 | 33.32 | 33.55 |
| HellaSwag | 41.74 | 39.63 | 35.36 | 38.83 | 38.53 | 36.46 | 38.80 | 36.43 |
| PiQA | 68.07 | 67.31 | 64.44 | 66.38 | 66.50 | 64.74 | 66.54 | 64.87 |
| OpenBookQA | 29.20 | 29.50 | 28.10 | 27.97 | 27.83 | 27.37 | 28.83 | 27.87 |
| Lambada | 31.79 | 31.11 | 27.32 | 30.17 | 28.75 | 26.22 | 30.38 | 26.25 |
| SciQ | 80.42 | 79.83 | 80.85 | 79.60 | 78.93 | 80.05 | 79.50 | 78.65 |
| COPA | 66.17 | 69.00 | 64.00 | 64.83 | 67.00 | 64.00 | 66.00 | 66.83 |
| RACE | 31.39 | 29.82 | 29.67 | 30.08 | 29.98 | 29.46 | 30.37 | 29.19 |
| ARC Easy | 51.14 | 49.24 | 47.13 | 47.88 | 48.20 | 47.09 | 49.09 | 46.90 |
| LogiQA | 25.19 | 25.93 | 23.68 | 25.17 | 25.70 | 25.52 | 26.50 | 26.65 |
| QQP | 55.37 | 54.46 | 52.73 | 53.17 | 59.65 | 58.15 | 57.50 | 55.31 |
| WinoGrande | 53.21 | 51.46 | 50.83 | 52.16 | 52.37 | 51.41 | 51.63 | 51.85 |
| MultiRC | 53.58 | 52.31 | 52.22 | 53.03 | 50.41 | 52.17 | 52.27 | 51.50 |
| **Average** | **47.73** | **47.15** | **45.38** | **46.37** | **46.71** | **45.88** | **46.98** | **45.84** |