Update README.md
README.md CHANGED
@@ -58,13 +58,13 @@ def predict_with_vllm(prompts: List[str], model_name: str, max_context_length: i
     return predictions


-def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str,
-                                        batch_size: int = 2):
+def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str, max_context_length: int = 4096,
+                                        batch_size: int = 2) -> List[str]:
     text_generation_pipeline = pipeline("text-generation", model=model_name,
                                         model_kwargs={"torch_dtype": torch.float16}, device_map="auto",
                                         batch_size=batch_size)

-    batch_output = text_generation_pipeline(prompts, truncation=True,
+    batch_output = text_generation_pipeline(prompts, truncation=True, max_length=max_context_length,
                                             return_full_text=False)
     predictions = [result[0]['generated_text'] for result in batch_output]
     return predictions
@@ -89,7 +89,8 @@ Satz: {sentence}
 ### Erklärung und Label:"""

 prompts = generate_prompts_for_generation(prompt_template=prompt_template, article=article, summary_sentences=summary_sentences)
-predictions = predict_with_hf_generation_pipeline(prompts=prompts, model_name=model_name,
+predictions = predict_with_hf_generation_pipeline(prompts=prompts, model_name=model_name,
+                                                   max_context_length=max_context_length, batch_size=batch_size)
 print(predictions)

 # Uncomment the following lines to use vllm for prediction
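For reference, the function as it stands after this change can be read as the self-contained sketch below. The imports (`typing.List`, `torch`, `transformers.pipeline`) and the comments are assumptions filled in from context elsewhere in the README; they are not part of the changed lines themselves.

```python
# Hedged sketch of the updated function, assembled from the diff above.
from typing import List

import torch
from transformers import pipeline


def predict_with_hf_generation_pipeline(prompts: List[str], model_name: str, max_context_length: int = 4096,
                                         batch_size: int = 2) -> List[str]:
    # Build an fp16 text-generation pipeline; device_map="auto" lets the model be
    # placed automatically, and batch_size controls how many prompts run per batch.
    text_generation_pipeline = pipeline("text-generation", model=model_name,
                                        model_kwargs={"torch_dtype": torch.float16}, device_map="auto",
                                        batch_size=batch_size)

    # truncation=True with max_length trims over-long prompts to the context window;
    # return_full_text=False returns only the newly generated continuation.
    batch_output = text_generation_pipeline(prompts, truncation=True, max_length=max_context_length,
                                            return_full_text=False)
    predictions = [result[0]['generated_text'] for result in batch_output]
    return predictions
```

Passing `max_length=max_context_length` together with `truncation=True` keeps prompts within the model's context window instead of letting generation fail on over-long inputs, which is the point of this change.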