Ryukijano commited on
Commit
7a18ab3
Β·
verified Β·
1 Parent(s): 38adfc7

Create llama_mesh.py

Browse files
Files changed (1) hide show
  1. llama_mesh.py +44 -0
llama_mesh.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # timeforge/llama_mesh.py
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
+ from threading import Thread
5
+
6
+ class LLaMAMesh:
7
+ def __init__(self, model_path="Zhengyi/LLaMA-Mesh", device="cuda"):
8
+ self.model_path = model_path
9
+ self.device = device
10
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
11
+ self.model = AutoModelForCausalLM.from_pretrained(self.model_path, device_map=self.device)
12
+ self.terminators = [
13
+ self.tokenizer.eos_token_id,
14
+ self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
15
+ ]
16
+
17
+
18
+ def generate_mesh(self, prompt, temperature=0.9, max_new_tokens=4096):
19
+ input_ids = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], return_tensors="pt").to(self.model.device)
20
+ streamer = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
21
+ generate_kwargs = dict(
22
+ input_ids= input_ids,
23
+ streamer=streamer,
24
+ max_new_tokens=max_new_tokens,
25
+ do_sample=True,
26
+ temperature=temperature,
27
+ eos_token_id=self.terminators,
28
+ )
29
+ if temperature == 0:
30
+ generate_kwargs['do_sample'] = False
31
+
32
+ t = Thread(target=self.model.generate, kwargs=generate_kwargs)
33
+ t.start()
34
+
35
+ outputs = []
36
+ for text in streamer:
37
+ outputs.append(text)
38
+ return "".join(outputs)
39
+
40
+ if __name__ == "__main__":
41
+ llama_mesh = LLaMAMesh()
42
+ prompt = "Create a 3D model of a futuristic chair."
43
+ mesh = llama_mesh.generate_mesh(prompt)
44
+ print(mesh)