wenjiao commited on
Commit
b8e5d23
·
verified ·
1 Parent(s): 111e33c

modify ComputeDtype params

Browse files
Files changed (1) hide show
  1. src/display/utils.py +8 -3
src/display/utils.py CHANGED
@@ -230,17 +230,17 @@ class QuantType(Enum):
230
  return QuantType.Unknown
231
 
232
 
 
233
  class WeightDtype(Enum):
 
234
  int2 = ModelDetails("int2")
235
- int3 = ModelDetails("int3")
236
  int4 = ModelDetails("int4")
237
  nf4 = ModelDetails("nf4")
238
  fp4 = ModelDetails("fp4")
239
 
240
-
241
  Unknown = ModelDetails("?")
242
 
243
- all = ModelDetails("All")
244
 
245
 
246
  def from_str(weight_dtype):
@@ -259,11 +259,13 @@ class WeightDtype(Enum):
259
  return WeightDtype.Unknown
260
 
261
  class ComputeDtype(Enum):
 
262
  fp16 = ModelDetails("float16")
263
  bf16 = ModelDetails("bfloat16")
264
  int8 = ModelDetails("int8")
265
  fp32 = ModelDetails("float32")
266
 
 
267
  Unknown = ModelDetails("?")
268
 
269
  def from_str(compute_dtype):
@@ -275,8 +277,11 @@ class ComputeDtype(Enum):
275
  return ComputeDtype.int8
276
  if compute_dtype in ["float32"]:
277
  return ComputeDtype.fp32
 
 
278
  return ComputeDtype.Unknown
279
 
 
280
  class GroupDtype(Enum):
281
  group_1 = ModelDetails("-1")
282
  group_1024 = ModelDetails("1024")
 
230
  return QuantType.Unknown
231
 
232
 
233
+
234
  class WeightDtype(Enum):
235
+ all = ModelDetails("All")
236
  int2 = ModelDetails("int2")
237
+ int3 = ModelDetails("int3")
238
  int4 = ModelDetails("int4")
239
  nf4 = ModelDetails("nf4")
240
  fp4 = ModelDetails("fp4")
241
 
 
242
  Unknown = ModelDetails("?")
243
 
 
244
 
245
 
246
  def from_str(weight_dtype):
 
259
  return WeightDtype.Unknown
260
 
261
  class ComputeDtype(Enum):
262
+ all = ModelDetails("All")
263
  fp16 = ModelDetails("float16")
264
  bf16 = ModelDetails("bfloat16")
265
  int8 = ModelDetails("int8")
266
  fp32 = ModelDetails("float32")
267
 
268
+
269
  Unknown = ModelDetails("?")
270
 
271
  def from_str(compute_dtype):
 
277
  return ComputeDtype.int8
278
  if compute_dtype in ["float32"]:
279
  return ComputeDtype.fp32
280
+ if compute_dtype in ["All"]:
281
+ return ComputeDtype.all
282
  return ComputeDtype.Unknown
283
 
284
+
285
  class GroupDtype(Enum):
286
  group_1 = ModelDetails("-1")
287
  group_1024 = ModelDetails("1024")