kvaishnavi commited on
Commit
5e2952a
·
verified ·
1 Parent(s): 1b16cfc

Update onnx/builder.py

Browse files
Files changed (1) hide show
  1. onnx/builder.py +4 -4
onnx/builder.py CHANGED
@@ -17,7 +17,7 @@ def build_vision(args):
17
  prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}"
18
  url = "https://www.ilankelman.org/stopsigns/australia.jpg"
19
  image = Image.open(requests.get(url, stream=True).raw)
20
- inputs = processor(prompt, image, return_tensors="pt").to(args.execution_provider)
21
  inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)
22
 
23
  # TorchScript export
@@ -214,8 +214,8 @@ def get_args():
214
  "-e",
215
  "--execution_provider",
216
  required=True,
217
- choices=["cpu", "cuda"],
218
- help="Device to export Phi-3 vision components with",
219
  )
220
 
221
  parser.add_argument(
@@ -238,7 +238,7 @@ if __name__ == "__main__":
238
  args = get_args()
239
  config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
240
  processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
241
- model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.execution_provider)
242
 
243
  # Build model components
244
  build_vision(args)
 
17
  prompt = f"{user_prompt}<|image_1|>\nWhat is shown in this image?{prompt_suffix}{assistant_prompt}"
18
  url = "https://www.ilankelman.org/stopsigns/australia.jpg"
19
  image = Image.open(requests.get(url, stream=True).raw)
20
+ inputs = processor(prompt, image, return_tensors="pt").to(args.execution_provider.replace("dml", "cuda"))
21
  inputs["pixel_values"] = inputs["pixel_values"].to(args.precision)
22
 
23
  # TorchScript export
 
214
  "-e",
215
  "--execution_provider",
216
  required=True,
217
+ choices=["cpu", "cuda", "dml"],
218
+ help="Execution provider for Phi-3 vision components",
219
  )
220
 
221
  parser.add_argument(
 
238
  args = get_args()
239
  config = AutoConfig.from_pretrained(args.input, trust_remote_code=True)
240
  processor = AutoProcessor.from_pretrained(args.input, trust_remote_code=True)
241
+ model = AutoModelForCausalLM.from_pretrained(args.input, trust_remote_code=True, torch_dtype=args.precision).to(args.execution_provider.replace("dml", "cuda"))
242
 
243
  # Build model components
244
  build_vision(args)