yakine commited on
Commit
25b7ae1
·
verified ·
1 Parent(s): e22ecd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -3,12 +3,17 @@ from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
 
6
  app = FastAPI()
7
 
8
  # Load your fine-tuned model and tokenizer
9
- MODEL_NAME = "aubmindlab/aragpt2-medium"
10
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
 
 
 
 
12
 
13
  # Define the general prompt template
14
  general_prompt_template = """
@@ -48,33 +53,37 @@ def generate_text(request: GenerateRequest):
48
  المادة = request.المادة
49
  المستوى = request.المستوى
50
 
51
- if not المادة or not المستوى:
52
- raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان.")
 
 
 
 
53
 
54
- # Format the prompt with user inputs
55
- arabic_prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
56
 
57
- # Tokenize the prompt
58
- inputs = tokenizer(arabic_prompt, return_tensors="pt", max_length=512, truncation=True)
 
 
 
 
 
 
 
 
59
 
60
- # Generate text
61
- with torch.no_grad():
62
- outputs = model.generate(
63
- inputs.input_ids,
64
- max_length=300, # Adjust as needed
65
- num_return_sequences=1,
66
- temperature=0.1, # Adjust for creativity
67
- top_p=0.9, # Adjust for diversity
68
- do_sample=True,
69
- )
70
 
71
- # Decode the generated text
72
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
73
 
74
- # Remove the prompt from the generated text
75
- generated_text = generated_text.replace(arabic_prompt, "").strip()
76
 
77
- return {"generated_text": generated_text}
 
78
 
79
  @app.get("/")
80
  def read_root():
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
+
7
  app = FastAPI()
8
 
9
  # Load your fine-tuned model and tokenizer
10
+ MODEL_NAME = "aubmindlab/aragpt2-medium"
11
+
12
+ try:
13
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
15
+ except Exception as e:
16
+ raise RuntimeError(f"Failed to load model or tokenizer: {str(e)}")
17
 
18
  # Define the general prompt template
19
  general_prompt_template = """
 
53
  المادة = request.المادة
54
  المستوى = request.المستوى
55
 
56
+ if not المادة or not المستوى or not isinstance(المادة, str) or not isinstance(المستوى, str):
57
+ raise HTTPException(status_code=400, detail="المادة والمستوى مطلوبان ويجب أن يكونا نصًا.")
58
+
59
+ try:
60
+ # Format the prompt with user inputs
61
+ arabic_prompt = general_prompt_template.format(المادة=المادة, المستوى=المستوى)
62
 
63
+ # Tokenize the prompt
64
+ inputs = tokenizer(arabic_prompt, return_tensors="pt", max_length=512, truncation=True)
65
 
66
+ # Generate text
67
+ with torch.no_grad():
68
+ outputs = model.generate(
69
+ inputs.input_ids,
70
+ max_length=300,
71
+ num_return_sequences=1,
72
+ temperature=0.1,
73
+ top_p=0.9,
74
+ do_sample=True,
75
+ )
76
 
77
+ # Decode the generated text
78
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
79
 
80
+ # Remove the prompt from the generated text
81
+ generated_text = generated_text.replace(arabic_prompt, "").strip()
82
 
83
+ return {"generated_text": generated_text}
 
84
 
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=f"Error during text generation: {str(e)}")
87
 
88
  @app.get("/")
89
  def read_root():