modify ComputeDtype params
Browse files- src/display/utils.py +8 -3
src/display/utils.py
CHANGED
@@ -230,17 +230,17 @@ class QuantType(Enum):
|
|
230 |
return QuantType.Unknown
|
231 |
|
232 |
|
|
|
233 |
class WeightDtype(Enum):
|
|
|
234 |
int2 = ModelDetails("int2")
|
235 |
-
int3 = ModelDetails("int3")
|
236 |
int4 = ModelDetails("int4")
|
237 |
nf4 = ModelDetails("nf4")
|
238 |
fp4 = ModelDetails("fp4")
|
239 |
|
240 |
-
|
241 |
Unknown = ModelDetails("?")
|
242 |
|
243 |
-
all = ModelDetails("All")
|
244 |
|
245 |
|
246 |
def from_str(weight_dtype):
|
@@ -259,11 +259,13 @@ class WeightDtype(Enum):
|
|
259 |
return WeightDtype.Unknown
|
260 |
|
261 |
class ComputeDtype(Enum):
|
|
|
262 |
fp16 = ModelDetails("float16")
|
263 |
bf16 = ModelDetails("bfloat16")
|
264 |
int8 = ModelDetails("int8")
|
265 |
fp32 = ModelDetails("float32")
|
266 |
|
|
|
267 |
Unknown = ModelDetails("?")
|
268 |
|
269 |
def from_str(compute_dtype):
|
@@ -275,8 +277,11 @@ class ComputeDtype(Enum):
|
|
275 |
return ComputeDtype.int8
|
276 |
if compute_dtype in ["float32"]:
|
277 |
return ComputeDtype.fp32
|
|
|
|
|
278 |
return ComputeDtype.Unknown
|
279 |
|
|
|
280 |
class GroupDtype(Enum):
|
281 |
group_1 = ModelDetails("-1")
|
282 |
group_1024 = ModelDetails("1024")
|
|
|
230 |
return QuantType.Unknown
|
231 |
|
232 |
|
233 |
+
|
234 |
class WeightDtype(Enum):
|
235 |
+
all = ModelDetails("All")
|
236 |
int2 = ModelDetails("int2")
|
237 |
+
int3 = ModelDetails("int3")
|
238 |
int4 = ModelDetails("int4")
|
239 |
nf4 = ModelDetails("nf4")
|
240 |
fp4 = ModelDetails("fp4")
|
241 |
|
|
|
242 |
Unknown = ModelDetails("?")
|
243 |
|
|
|
244 |
|
245 |
|
246 |
def from_str(weight_dtype):
|
|
|
259 |
return WeightDtype.Unknown
|
260 |
|
261 |
class ComputeDtype(Enum):
|
262 |
+
all = ModelDetails("All")
|
263 |
fp16 = ModelDetails("float16")
|
264 |
bf16 = ModelDetails("bfloat16")
|
265 |
int8 = ModelDetails("int8")
|
266 |
fp32 = ModelDetails("float32")
|
267 |
|
268 |
+
|
269 |
Unknown = ModelDetails("?")
|
270 |
|
271 |
def from_str(compute_dtype):
|
|
|
277 |
return ComputeDtype.int8
|
278 |
if compute_dtype in ["float32"]:
|
279 |
return ComputeDtype.fp32
|
280 |
+
if compute_dtype in ["All"]:
|
281 |
+
return ComputeDtype.all
|
282 |
return ComputeDtype.Unknown
|
283 |
|
284 |
+
|
285 |
class GroupDtype(Enum):
|
286 |
group_1 = ModelDetails("-1")
|
287 |
group_1024 = ModelDetails("1024")
|