# Author: visor841
# Last commit: bc64b44 "Update the labels"
import functools

import gradio as gr
import numpy as np
from transformers import pipeline
# The six HEXACO personality dimensions. They are passed to netScores as the
# zero-shot candidate labels, where each one is substituted into the
# hypothesis templates ("This example is {}" / "This example is not {}").
HEXACO = [
    "honesty-humility",
    "emotionality",
    "extraversion",
    "agreeableness",
    "conscientiousness",
    "openness to experience",
]
@functools.lru_cache(maxsize=2)
def _load_classifier(model_name: str):
    """Return a zero-shot-classification pipeline for *model_name*, built once per name.

    Building a pipeline loads the model weights. That is far too slow to repeat
    on every call, so the result is cached. The small bound caps how many
    models stay resident at once.
    """
    return pipeline("zero-shot-classification", model=model_name)


def netScores(
    tagList: list,
    sequence_to_classify: str,
    modelName: str,
    *,
    classifier=None,
) -> dict:
    """Score a text against each tag as P("is tag") minus P("is not tag").

    The zero-shot classifier runs twice in multi-label mode, first with the
    hypothesis "This example is {tag}" and then with "This example is not {tag}".
    Each tag's net score is its positive score minus its negative score.

    Args:
        tagList: Candidate labels to score.
        sequence_to_classify: The text to classify.
        modelName: Model identifier passed to ``transformers.pipeline``. The
            resulting pipeline is cached per name and reused across calls.
        classifier: Optional ready-made callable with the zero-shot pipeline's
            call signature. When given, it is used instead of loading
            ``modelName``.

    Returns:
        Dict mapping each tag to its net score, ordered from highest to
        lowest score.
    """
    if classifier is None:
        classifier = _load_classifier(modelName)

    output_pos = classifier(
        sequence_to_classify,
        tagList,
        hypothesis_template="This example is {}",
        multi_label=True,
    )
    output_neg = classifier(
        sequence_to_classify,
        tagList,
        hypothesis_template="This example is not {}",
        multi_label=True,
    )

    # The pipeline returns labels sorted by score rather than in input order,
    # so re-key each score by its label instead of by position.
    positive_scores = dict(zip(output_pos["labels"], output_pos["scores"]))
    negative_scores = dict(zip(output_neg["labels"], output_neg["scores"]))

    net_scores = {tag: positive_scores[tag] - negative_scores[tag] for tag in tagList}
    return dict(sorted(net_scores.items(), key=lambda item: item[1], reverse=True))
def scoresMatch(tagList: list, scoresA: dict, scoresB: dict) -> float:
    """Return how closely two score dicts agree, from 0.0 (opposite) to 1.0 (identical).

    The Euclidean distance between the two score vectors (taken over
    ``tagList``) is divided by ``2 * sqrt(len(tagList))``. That is the largest
    possible distance when every score lies in [-1, 1]. The normalized
    distance is then subtracted from 1.

    Args:
        tagList: Tags to compare. Each must be a key in both dicts.
        scoresA: Mapping of tag to score for the first text.
        scoresB: Mapping of tag to score for the second text.

    Returns:
        The similarity, 1 minus the normalized distance.

    Raises:
        ValueError: If ``tagList`` is empty, since the normalization would be 0/0.
        KeyError: If a tag is missing from either dict.
    """
    if not tagList:
        # Otherwise maxDistance is 0 and the result silently becomes NaN.
        raise ValueError("tagList must contain at least one tag")
    differences = np.array([scoresA[tag] - scoresB[tag] for tag in tagList], dtype=float)
    distance = np.sqrt(np.sum(differences * differences))
    maxDistance = 2 * np.sqrt(len(tagList))
    return 1 - distance / maxDistance
def compareDocuments(userText1, userText2):
    """Return the HEXACO similarity of two texts.

    Both texts are scored with netScores against the HEXACO labels using the
    same zero-shot model. The two score dicts are then compared with
    scoresMatch.
    """
    model = 'akhtet/mDeBERTa-v3-base-myXNLI'
    # Scored in order: userText1 first, then userText2.
    first, second = (netScores(HEXACO, text, model) for text in (userText1, userText2))
    return scoresMatch(HEXACO, first, second)
# Web UI: two free-text inputs, one similarity value shown as text.
demo = gr.Interface(
    fn=compareDocuments,
    inputs=[gr.Textbox(label="Text 1"), gr.Textbox(label="Text 2")],
    outputs=[gr.Textbox(label="HEXACO match")],
)

# Launch only when the file is run as a script (e.g. `python app.py`), so that
# importing this module, for tests or reuse, does not start a web server.
if __name__ == "__main__":
    demo.launch()