# Source: Hugging Face Space by umarigan — commit da41238 ("Update app.py", verified).
import streamlit as st
import pandas as pd
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
import PyPDF2
import docx
import io
import re
def chunk_text(text, chunk_size=128):
    """Split ``text`` into whitespace-separated chunks of at most ``chunk_size`` characters.

    Words are never split: a single word longer than ``chunk_size`` becomes a
    chunk of its own. Runs of whitespace (spaces, tabs, newlines) are collapsed,
    so ``" ".join(chunk_text(t))`` equals ``" ".join(t.split())``.

    Args:
        text: input string.
        chunk_size: maximum chunk length in characters (not model tokens).

    Returns:
        List of non-empty chunk strings; an empty list for blank input.
    """
    chunks = []
    current_chunk = []
    current_length = 0  # always equals len(" ".join(current_chunk))
    for word in text.split():
        # A separating space is only needed once the chunk already has a word.
        added = len(word) + (1 if current_chunk else 0)
        # Only flush a non-empty chunk; otherwise an over-long first word would
        # emit an empty "" chunk.
        if current_chunk and current_length + added > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(word)
        else:
            current_chunk.append(word)
            current_length += added
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
# Use the full browser width for the app. Streamlit documents that
# st.set_page_config must be the first Streamlit command the script runs,
# so keep this call above every other st.* call.
st.set_page_config(layout="wide")
# Function to read text from uploaded file
_DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
# MIME types reported for uploads, mapped to the file kinds handled below.
_MIME_TO_KIND = {
    "text/plain": "txt",
    "application/pdf": "pdf",
    _DOCX_MIME: "docx",
}


def read_file(file):
    """Extract plain text from an uploaded TXT, PDF, or DOCX file.

    The file kind comes from the reported MIME type (``file.type``). When that
    is missing or not one of the known types (browsers sometimes report a
    generic type), the file-name extension is used instead.

    Args:
        file: uploaded-file object exposing ``type``, ``name`` and
            ``getvalue()`` returning raw bytes (e.g. a Streamlit UploadedFile).

    Returns:
        The extracted text. Returns ``None`` after showing an ``st.error`` if
        the type is unsupported or a text file is not valid UTF-8.
    """
    name = getattr(file, "name", "") or ""
    extension = name.rpartition(".")[2].lower() if "." in name else ""
    kind = _MIME_TO_KIND.get(getattr(file, "type", None))
    if kind is None and extension in _MIME_TO_KIND.values():
        kind = extension

    if kind == "txt":
        try:
            # "utf-8-sig" also drops a leading byte-order mark, which some
            # editors write and which would otherwise stick to the first word.
            return file.getvalue().decode("utf-8-sig")
        except UnicodeDecodeError:
            st.error("Text file is not valid UTF-8")
            return None
    if kind == "pdf":
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(file.getvalue()))
        # Defensive: treat a page without extractable text as empty.
        return " ".join(page.extract_text() or "" for page in pdf_reader.pages)
    if kind == "docx":
        doc = docx.Document(io.BytesIO(file.getvalue()))
        return " ".join(paragraph.text for paragraph in doc.paragraphs)
    st.error("Unsupported file type")
    return None
st.title("Turkish NER Models Testing")
model_list = [
    'girayyagmur/bert-base-turkish-ner-cased',
    'asahi417/tner-xlm-roberta-base-ontonotes5',
]
st.sidebar.header("Select NER Model")
# The radios get real but visually hidden labels: an empty label triggers a
# Streamlit accessibility warning, and the header above is the visible title.
model_checkpoint = st.sidebar.radio(
    "NER model", model_list, label_visibility="collapsed"
)
st.sidebar.write("Only PDF, DOCX, and TXT files are supported.")
# Determine aggregation strategy: "simple" for the XLM-R OntoNotes model,
# "first" for the other model.
aggregation = "simple" if model_checkpoint in ["asahi417/tner-xlm-roberta-base-ontonotes5"] else "first"
st.subheader("Select Text Input Method")
input_method = st.radio(
    "Text input method",
    ('Write or Paste New Text', 'Upload File'),
    label_visibility="collapsed",
)
# Default so the Run handler can always test `input_text`, including when no
# file is uploaded or read_file returns None.
input_text = ""
if input_method == "Write or Paste New Text":
    input_text = st.text_area('Write or Paste Text Below', value="", height=128)
else:
    uploaded_file = st.file_uploader("Choose a file", type=["txt", "pdf", "docx"])
    if uploaded_file is not None:
        input_text = read_file(uploaded_file) or ""
        if input_text:
            st.text_area("Extracted Text", input_text, height=128)
@st.cache_resource
def setModel(model_checkpoint, aggregation):
    """Load a token-classification model and wrap it in an NER pipeline.

    ``st.cache_resource`` keeps one pipeline per (checkpoint, aggregation)
    pair, so the model is not reloaded on every Streamlit rerun.

    Args:
        model_checkpoint: model identifier passed to ``from_pretrained``.
        aggregation: value for the pipeline's ``aggregation_strategy``.

    Returns:
        A ``transformers`` "ner" pipeline.
    """
    classifier = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
    tok = AutoTokenizer.from_pretrained(model_checkpoint)
    return pipeline(
        'ner',
        model=classifier,
        tokenizer=tok,
        aggregation_strategy=aggregation,
    )
def entity_comb(output):
    """Merge consecutive entity spans that touch and share an entity group.

    A span is merged into the previous one when its ``start`` equals the
    previous span's ``end`` and both have the same ``entity_group``. The
    ``word`` texts are concatenated and ``end`` is extended. Every other key
    is kept from the first span of the run.

    Deliberately not wrapped in ``st.cache_resource``. That cache is meant for
    global resources: it would hand the same mutable result dicts to every
    session and keep one entry per distinct input forever. Merging is cheap.

    Args:
        output: list of entity dicts (keys ``start``, ``end``,
            ``entity_group``, ``word``) in text order.

    Returns:
        A new list of entity dicts; the input list and its dicts are not modified.
    """
    combined = []
    for entity in output:
        previous = combined[-1] if combined else None
        if (
            previous is not None
            and entity["start"] == previous["end"]
            and entity["entity_group"] == previous["entity_group"]
        ):
            previous["word"] += entity["word"]
            previous["end"] = entity["end"]
        else:
            # Copy so merging never mutates the caller's dicts.
            combined.append(dict(entity))
    return combined
# Entity groups left visible in the masked output; every other group gets a
# numbered placeholder such as "PER_1".
_UNMASKED_ENTITY_GROUPS = frozenset(
    {"CARDINAL", "EVENT", "PERCENT", "QUANTITY", "DATE", "TITLE", "WORK_OF_ART"}
)


def create_mask_dict(entities, additional_masks=None):
    """Map each distinct entity string to a numbered mask label.

    Labels are ``f"{entity_group}_{n}"``, with ``n`` counting distinct words
    per group in order of first appearance. Groups in
    ``_UNMASKED_ENTITY_GROUPS`` are skipped. Blank or whitespace-only words
    are skipped too: used as a replacement key, an empty string would insert
    the label between every character of the text.

    Args:
        entities: iterable of dicts with ``entity_group`` and ``word`` keys.
        additional_masks: optional ``{text: replacement}`` mapping merged in
            last, overriding any entity label for the same text.

    Returns:
        Dict mapping original text to its replacement.
    """
    mask_dict = {}
    entity_counters = {}
    for entity in entities:
        group = entity["entity_group"]
        word = entity["word"]
        if group in _UNMASKED_ENTITY_GROUPS or not word.strip() or word in mask_dict:
            continue
        entity_counters[group] = entity_counters.get(group, 0) + 1
        mask_dict[word] = f"{group}_{entity_counters[group]}"
    if additional_masks:
        mask_dict.update(additional_masks)
    return mask_dict
def replace_words_in_text(input_text, entities):
    """Replace every detected entity string in ``input_text`` with its mask label.

    Labels come from ``create_mask_dict(entities)``. All replacements happen in
    one regex pass, with longer entity strings tried first. This means:

    * an entity contained in a longer one (e.g. "Ankara" inside
      "Ankara Üniversitesi") cannot pre-empt the longer match, and
    * text already replaced by a label is never scanned again.

    Matching is plain substring matching with no word-boundary check, so
    entities inside longer words are masked too. Blank keys are ignored.

    Args:
        input_text: text to mask.
        entities: entity dicts accepted by ``create_mask_dict``.

    Returns:
        The masked text.
    """
    replace_dict = {
        word: label
        for word, label in create_mask_dict(entities).items()
        if word.strip()
    }
    if not replace_dict:
        return input_text
    # Alternation picks the first alternative that matches at a position, so
    # ordering by length gives longest-match-wins.
    pattern = re.compile(
        "|".join(re.escape(word) for word in sorted(replace_dict, key=len, reverse=True))
    )
    return pattern.sub(lambda match: replace_dict[match.group(0)], input_text)
# Function to mask email and phone-number patterns
_EMAIL_PATTERN = re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# Turkish phone numbers. The number is preceded either by a "+90" country code
# (optionally followed by a separator) or by a word boundary, then uses one of
# the original digit layouts. Without the "+90" branch, "+905321234567" was
# never masked: there is no word boundary between "0" and "5".
_PHONE_PATTERN = re.compile(
    r"(?:\+90[-.\s]?|\b)"
    r"(?:0?5\d{2}[-.\s]?\d{3}[-.\s]?\d{2}[-.\s]?\d{2}"
    r"|5\d{3}[-.\s]?\d{3}[-.\s]?\d{2}[-.\s]?\d{2}"
    r"|\d{3}[-.\s]?\d{3}[-.\s]?\d{2}[-.\s]?\d{2})\b"
)


def mask_patterns(text):
    """Replace email addresses with ``<EMAIL>`` and phone numbers with ``<PHONE>``.

    Each regex match is replaced in place, rather than via ``str.replace`` of
    the found strings. This avoids leaving fragments such as ``0<PHONE>``
    when one found number is a substring of another.

    Args:
        text: input text.

    Returns:
        ``(masked_text, masks)``, where ``masks`` maps each matched original
        string to its placeholder (emails first, then phones).
    """
    masks = {}

    def _mask_with(placeholder):
        # Build a re.sub callback that records the match and returns the placeholder.
        def _replace(match):
            masks[match.group(0)] = placeholder
            return placeholder
        return _replace

    # Emails first, so digits inside an address are never treated as a phone.
    text = _EMAIL_PATTERN.sub(_mask_with("<EMAIL>"), text)
    text = _PHONE_PATTERN.sub(_mask_with("<PHONE>"), text)
    return text, masks
Run_Button = st.button("Run")
if Run_Button and input_text:
    ner_pipeline = setModel(model_checkpoint, aggregation)
    # Chunk the input text
    chunks = chunk_text(input_text)
    # Run NER on each chunk and shift entity offsets into the coordinates of
    # " ".join(chunks): each chunk starts after all earlier chunks plus one
    # separating space per chunk. A running offset keeps this linear instead
    # of re-joining the whole prefix for every chunk.
    all_outputs = []
    offset = 0
    for chunk in chunks:
        output = ner_pipeline(chunk)
        for entity in output:
            entity['start'] += offset
            entity['end'] += offset
        all_outputs.extend(output)
        offset += len(chunk) + 1
    # Combine entities
    output_comb = entity_comb(all_outputs)
    # Mask emails and phone numbers
    masked_text, additional_masks = mask_patterns(input_text)
    # Create masked text and masking dictionary
    masked_text = replace_words_in_text(masked_text, output_comb)
    mask_dict = create_mask_dict(output_comb, additional_masks)
    # Display the masked text and masking dictionary
    st.subheader("Masked Text Preview")
    st.text(masked_text)
    st.subheader("Masking Dictionary")
    st.json(mask_dict)