TraceVLA-7B

TraceVLA-7B is a vision-language-action model obtained by finetuning the base OpenVLA model with the visual trace prompting technique, which overlays historical traces of the robot end effector and moving objects onto the visual observation to improve spatial-temporal awareness.
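
To give a concrete sense of what a visual trace overlay looks like, here is a toy sketch that draws a tracked point history onto a frame with PIL. This is purely illustrative; the released `TraceProcessor` (shown in the inference code below) uses a CoTracker checkpoint to produce the actual traces, and the `overlay_trace` helper and track coordinates here are hypothetical.

```python
from PIL import Image, ImageDraw

def overlay_trace(image: Image.Image, points: list[tuple[float, float]]) -> Image.Image:
    """Draw a polyline through historical 2D track points (toy illustration)."""
    out = image.copy()
    draw = ImageDraw.Draw(out)
    if len(points) >= 2:
        draw.line(points, fill=(255, 0, 0), width=3)  # trace path
    for x, y in points:
        draw.ellipse((x - 3, y - 3, x + 3, y + 3), fill=(255, 0, 0))  # waypoints
    return out

# Hypothetical end-effector track over the last few frames:
frame = Image.new("RGB", (256, 256), "gray")
trace = [(60.0, 200.0), (90.0, 170.0), (125.0, 150.0), (160.0, 140.0)]
overlaid = overlay_trace(frame, trace)
```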

Results on SimplerEnv (Fractal + Bridge):

Fractal:

| Policy (Setting) | Pick up Coke | Move Near | Open/Close Drawer | Put in Drawer | Average Success Rate |
|---|---|---|---|---|---|
| OpenVLA-7B (Visual Matching) | 23.7% | 65.0% | 57.4% | 0.0% | 36.5% |
| TraceVLA-7B (Visual Matching) | 45.0% | 63.8% | 63.1% | 11.1% | 45.8% |
| OpenVLA-7B (Variant Aggregation) | 61.3% | 55.8% | 24.9% | 1.0% | 35.8% |
| TraceVLA-7B (Variant Aggregation) | 64.3% | 60.6% | 61.6% | 12.5% | 49.8% |

Bridge:

| Policy | Put Spoon | Put Carrot | Stack Block | Put Eggplant | Average Success Rate |
|---|---|---|---|---|---|
| OpenVLA-7B | 8.3% | 8.3% | 4.2% | 45.8% | 16.7% |
| TraceVLA-7B | 12.5% | 16.6% | 16.6% | 65.0% | 27.7% |

Sample Inference Code

Here is sample inference code for the TraceVLA-7B model.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "furonghuang-lab/tracevla_7b"

# Load Processor & VLA
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
    num_crops=1,
)

vla = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    _attn_implementation='flash_attention_2',
    use_cache=True
).to(device='cuda')
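
# NOTE (assumption): flash_attention_2 requires the flash-attn package and a
# supported GPU; if it is unavailable, transformers' 'sdpa' or 'eager'
# attention backends may work as slower drop-in alternatives, though this is
# untested against the released code.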

# Load Visual Trace Processor
# cotracker_model_path is the path to your downloaded scaled_offline.pth
# CoTracker checkpoint.
from prismatic.eval.trace_processor import TraceProcessor

cotracker_model_path = "scaled_offline.pth"  # replace with your local checkpoint path
trace_processor = TraceProcessor(cotracker_model_path)
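
# (Assumption) The CoTracker3 release hosts scaled_offline.pth on the
# Hugging Face Hub, so the checkpoint used above can also be fetched
# programmatically, e.g.:
#   from huggingface_hub import hf_hub_download
#   cotracker_model_path = hf_hub_download(
#       repo_id="facebook/cotracker3", filename="scaled_offline.pth"
#   )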

# Grab image input & format prompt
# If the visual trace returned by CoTracker is not valid, fall back to the default OpenVLA prompt.
openvla_prompt_template = "In: What action should the robot take to {task_description}?\nOut:"
tracevla_prompt_template = "In: You are given two images: one with the original robot observation, and another one marked with historical traces of the robot end effector and moving objects, separated by a special separator token. What action should the robot take to {task_description}?\nOut:"
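
# For illustration, with the hypothetical task_description "pick up the coke
# can", the fallback prompt formats to:
#   "In: What action should the robot take to pick up the coke can?\nOut:"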

image: Image.Image = get_from_camera(...)
image_overlaid, has_trace = trace_processor.process_image(image)

if not has_trace:
    prompt = openvla_prompt_template.format(task_description=task_description)
    inputs = processor(prompt, [image, image]).to(device='cuda', dtype=torch.bfloat16)
else:
    prompt = tracevla_prompt_template.format(task_description=task_description)
    inputs = processor(prompt, [image, image_overlaid]).to(device='cuda', dtype=torch.bfloat16)

# Predict the action
with torch.inference_mode():
    action = vla.predict_action(**inputs)

# Execute the action
robot.act(action, ...)
```
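
Because the visual traces summarize recent motion, inference is typically run in a closed loop, re-running the trace processor on each new frame. Below is a minimal rollout sketch built from the pieces above; `get_from_camera`, `robot.act`, `max_steps`, and the task string are placeholders, and the assumption that `TraceProcessor` accumulates frame history internally follows from its single-image interface rather than from documented behavior.

```python
max_steps = 100  # hypothetical episode horizon
task_description = "pick up the coke can"  # hypothetical task string

for _ in range(max_steps):
    image = get_from_camera(...)  # placeholder camera hook from the snippet above
    image_overlaid, has_trace = trace_processor.process_image(image)
    if has_trace:
        prompt = tracevla_prompt_template.format(task_description=task_description)
        inputs = processor(prompt, [image, image_overlaid])
    else:
        prompt = openvla_prompt_template.format(task_description=task_description)
        inputs = processor(prompt, [image, image])
    inputs = inputs.to(device='cuda', dtype=torch.bfloat16)
    with torch.inference_mode():
        action = vla.predict_action(**inputs)
    robot.act(action, ...)  # placeholder robot interface from the snippet above
```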

For more examples, including scripts for finetuning TraceVLA models on your own robot demonstration datasets, check out our repository.

Citation

If you find our code or models useful in your work, please cite our paper:

@misc{zheng2024tracevlavisualtraceprompting,
      title={TraceVLA: Visual Trace Prompting Enhances Spatial-Temporal Awareness for Generalist Robotic Policies}, 
      author={Ruijie Zheng and Yongyuan Liang and Shuaiyi Huang and Jianfeng Gao and Hal Daumé III and Andrey Kolobov and Furong Huang and Jianwei Yang},
      year={2024},
      eprint={2412.10345},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2412.10345}, 
}