mgfrantz commited on
Commit
f9eb567
·
verified ·
1 Parent(s): 7441e5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -63,6 +63,18 @@ async def zero_shot_predict(text):
63
  response = await structured_llm.achat(messages)
64
  return response.raw.rating
65
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  few_shot_prompt_tmpl_str = """\
67
  The review text is below.
68
  ---------------------
@@ -83,18 +95,6 @@ few_shot_prompt_tmpl = PromptTemplate(
83
  function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
84
  )
85
 
86
- rng = np.random.Generator(np.random.PCG64(1234))
87
- def random_few_shot_examples_fn(**kwargs):
88
- if n_samples:=kwargs.get('n_samples'):
89
- random_examples = train.shuffle(generator=rng)[:n_samples]
90
- else:
91
- random_examples = train.shuffle(generator=rng)[:5]
92
-
93
- result_strs = []
94
- for text, rating in zip(random_examples['text'], random_examples['label']):
95
- result_strs.append(f"Text: {text}\nRating: {rating}")
96
- return "\n\n".join(result_strs)
97
-
98
  async def random_few_shot_predict(text, n_examples=5):
99
  tasks = []
100
  for _ in range(3):
 
63
  response = await structured_llm.achat(messages)
64
  return response.raw.rating
65
 
66
+ rng = np.random.Generator(np.random.PCG64(1234))
67
+ def random_few_shot_examples_fn(**kwargs):
68
+ if n_samples:=kwargs.get('n_samples'):
69
+ random_examples = train.shuffle(generator=rng)[:n_samples]
70
+ else:
71
+ random_examples = train.shuffle(generator=rng)[:5]
72
+
73
+ result_strs = []
74
+ for text, rating in zip(random_examples['text'], random_examples['label']):
75
+ result_strs.append(f"Text: {text}\nRating: {rating}")
76
+ return "\n\n".join(result_strs)
77
+
78
  few_shot_prompt_tmpl_str = """\
79
  The review text is below.
80
  ---------------------
 
95
  function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
96
  )
97
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  async def random_few_shot_predict(text, n_examples=5):
99
  tasks = []
100
  for _ in range(3):