NiniCat committed on
Commit
0bc3d74
·
verified ·
1 Parent(s): 7da2d16

qingyang's new cas9

Browse files
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  Human_genes_HUGO_02242024_annotation.txt filter=lfs diff=lfs merge=lfs -text
37
  SRR25934512.filter.snps.indels.vcf filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  Human_genes_HUGO_02242024_annotation.txt filter=lfs diff=lfs merge=lfs -text
37
  SRR25934512.filter.snps.indels.vcf filter=lfs diff=lfs merge=lfs -text
38
+ cas9_model/Cas9_MultiHeadAttention_weights.keras filter=lfs diff=lfs merge=lfs -text
cas9_model/Cas9_MultiHeadAttention_weights.keras ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0544e98cb16fe99e64c76fd6f8296ba1e7fd785b2d2aa7c049535461e061d546
3
+ size 16304829
cas9_model/cas9_on_target.ipynb ADDED
@@ -0,0 +1 @@
 
 
1
# ---
# jupyter:
#   accelerator: GPU
#   kernelspec:
#     display_name: Python 3
#     name: python3
# ---

# %% [markdown]
# # Cas9 on-target gRNA design
#
# Scans the exons of a gene (fetched from the Ensembl REST API) for SpCas9
# target sites, scores each 23-nt target (20-nt protospacer + PAM) with a
# CNN / BiLSTM / multi-head-attention model, and optionally repeats the scan on
# exon sequences carrying the variants found in a sample VCF.

# %%
from google.colab import drive
drive.mount('/content/drive')
import sys
sys.path.append('/content/drive/MyDrive/Colab Notebooks/Cas9/On target')

# %%
# Colab-only dependencies. Kept in their own cell, before the imports that need
# them (IPython magics stay commented in the script form of the notebook).
# %pip install -q cyvcf2 parasail

# %%
import os
import re
import random
from functools import reduce
from operator import add

import numpy as np
import pandas as pd
import requests
import tabulate
import tensorflow as tf

from keras import Model
from keras import regularizers
from keras.optimizers import Adam
from keras.layers import Conv2D, BatchNormalization, ReLU, Input, Flatten, Softmax
from keras.layers import Concatenate, Activation, Dense, GlobalAveragePooling2D, Dropout
from keras.layers import AveragePooling1D, Bidirectional, LSTM, GlobalAveragePooling1D, MaxPool1D, Reshape
from keras.layers import LayerNormalization, Conv1D, MultiHeadAttention, Layer
from keras.models import load_model
from keras.callbacks import EarlyStopping, ReduceLROnPlateau

import cyvcf2
import parasail

# %%
# Configuration: everything a reader might want to change lives here.
PROJECT_DIR = '/content/drive/MyDrive/Colab Notebooks/Cas9/On target'
OUTPUT_DIR = f'{PROJECT_DIR}/design_results'
MODEL_PATH = f'{PROJECT_DIR}/saved_model/Cas9_MultiHeadAttention_weights.keras'
VCF_PATH = '/content/drive/MyDrive/Colab Notebooks/CRISPR_data/SRR25934512.filter.snps.indels.vcf.gz'
GENES = ['TROAP', 'SPC24', 'RAD54L', 'MCM2', 'COPB2', 'CKAP5']
ENSEMBL_TIMEOUT = 30  # seconds; requests has no default timeout and would hang forever

# %% [markdown]
# ### Data Encoding

# %%
ntmap = {'A': (1, 0, 0, 0),
         'C': (0, 1, 0, 0),
         'G': (0, 0, 1, 0),
         'T': (0, 0, 0, 1)
         }


def get_seqcode(seq):
    """One-hot encode a DNA string.

    Args:
        seq: DNA string made of A/C/G/T (case-insensitive).

    Returns:
        Integer array of shape (1, len(seq), 4), columns ordered A, C, G, T.

    Raises:
        KeyError: if `seq` contains a character other than A/C/G/T.
    """
    return np.array([ntmap[c] for c in seq.upper()]).reshape((1, len(seq), -1))


# %% [markdown]
# ### Attention model

# %%
class PositionalEncoding(Layer):
    """Adds a fixed sinusoidal position table to a (batch, sequence_len, embedding_dim) input."""

    def __init__(self, sequence_len=None, embedding_dim=None, **kwargs):
        # Forward name/dtype/trainable to the base layer; they were silently dropped before,
        # which also broke round-tripping through get_config().
        super().__init__(**kwargs)
        self.sequence_len = sequence_len
        self.embedding_dim = embedding_dim

    def call(self, x):
        # NOTE: the exponent uses 2*i/d for every column i, not 2*(i//2)/d as in the usual
        # Transformer formulation. It is left untouched on purpose: the shipped weights were
        # presumably trained with exactly this table, so changing it would change predictions.
        position_embedding = np.array([
            [pos / np.power(10000, 2. * i / self.embedding_dim) for i in range(self.embedding_dim)]
            for pos in range(self.sequence_len)])

        position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])  # even columns
        position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])  # odd columns
        position_embedding = tf.cast(position_embedding, dtype=tf.float32)

        return position_embedding + x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'sequence_len': self.sequence_len,
            'embedding_dim': self.embedding_dim,
        })
        return config


def MultiHeadAttention_model(input_shape):
    """Build the on-target efficiency regressor.

    Architecture: two Conv1D/AveragePooling1D/Dropout stages, a bidirectional LSTM,
    positional encoding, 2-head self-attention, three dense layers and a linear output.
    Layer order is kept exactly as in the original so saved weights still load.

    Args:
        input_shape: (sequence_length, 4); the notebook uses (23, 4).

    Returns:
        An uncompiled keras Model mapping one-hot targets to a single score.
    """
    lstm_units = 128
    inputs = Input(shape=input_shape)

    conv1 = Conv1D(256, 3, activation="relu")(inputs)
    pool1 = AveragePooling1D(2)(conv1)
    drop1 = Dropout(0.4)(pool1)

    conv2 = Conv1D(256, 3, activation="relu")(drop1)
    pool2 = AveragePooling1D(2)(conv2)
    drop2 = Dropout(0.4)(pool2)

    lstm = Bidirectional(LSTM(lstm_units,
                              dropout=0.5,
                              activation='tanh',
                              return_sequences=True,
                              kernel_regularizer=regularizers.l2(0.01)))(drop2)

    # Each unpadded kernel-3 convolution removes 2 positions and each pool floors a halving;
    # for a 23-nt input this is 4, the value that used to be hard-coded.
    encoded_len = ((input_shape[0] - 2) // 2 - 2) // 2
    pos_embedding = PositionalEncoding(sequence_len=encoded_len,
                                       embedding_dim=2 * lstm_units)(lstm)
    atten = MultiHeadAttention(num_heads=2,
                               key_dim=64,
                               dropout=0.2,
                               kernel_regularizer=regularizers.l2(0.01))(pos_embedding, pos_embedding)

    flat = Flatten()(atten)

    dense1 = Dense(512,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(flat)
    drop3 = Dropout(0.1)(dense1)

    dense2 = Dense(128,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop3)
    drop4 = Dropout(0.1)(dense2)

    dense3 = Dense(256,
                   kernel_regularizer=regularizers.l2(1e-4),
                   bias_regularizer=regularizers.l2(1e-4),
                   activation="relu")(drop4)
    drop5 = Dropout(0.1)(dense3)

    output = Dense(1, activation="linear")(drop5)

    model = Model(inputs=[inputs], outputs=[output])
    return model


# %% [markdown]
# ### Predict gRNA in one specific gene

# %%
_SEQUENCE_CACHE = {}  # Ensembl id -> sequence; exons are shared by many transcripts


def _ensembl_get_json(url, what):
    """GET `url` and return the decoded JSON, or None (after printing why) on any failure."""
    try:
        response = requests.get(url, timeout=ENSEMBL_TIMEOUT)
    except requests.RequestException as exc:
        print(f"Error fetching {what} data from Ensembl: {exc}")
        return None
    if response.status_code != 200:
        print(f"Error fetching {what} data from Ensembl: {response.text}")
        return None
    return response.json()


def fetch_ensembl_transcripts(gene_symbol):
    """Return the list of transcript records (with exons) for a human gene symbol, or None."""
    url = f"https://rest.ensembl.org/lookup/symbol/homo_sapiens/{gene_symbol}?expand=1;content-type=application/json"
    gene_data = _ensembl_get_json(url, "gene")
    if gene_data is None:
        return None
    if 'Transcript' in gene_data:
        return gene_data['Transcript']
    print("No transcripts found for gene:", gene_symbol)
    return None


def fetch_ensembl_sequence(transcript_id):
    """Return the sequence of an Ensembl feature, or None.

    Despite the parameter name, this notebook passes exon ids. Successful lookups are
    cached so an exon shared by several transcripts is downloaded only once.
    """
    if transcript_id in _SEQUENCE_CACHE:
        return _SEQUENCE_CACHE[transcript_id]
    url = f"https://rest.ensembl.org/sequence/id/{transcript_id}?content-type=application/json"
    sequence_data = _ensembl_get_json(url, "sequence")
    if sequence_data is None:
        return None
    if 'seq' in sequence_data:
        _SEQUENCE_CACHE[transcript_id] = sequence_data['seq']
        return sequence_data['seq']
    print("No sequence found for transcript:", transcript_id)
    return None


# %%
_DNA_TO_RNA = {'A': 'A', 'T': 'U', 'C': 'C', 'G': 'G'}
_DNA_BASES = frozenset('ACGT')


def _pam_matches(site, pam):
    """True if `site` matches `pam`, where 'N' in `pam` matches any base."""
    return len(site) == len(pam) and all(p == 'N' or p == s for p, s in zip(pam, site))


def _iter_target_sites(sequence, pam, target_length):
    """Yield (pam_index, target_seq, gRNA) for every PAM preceded by a full protospacer.

    `target_seq` is protospacer + PAM as DNA; `gRNA` is the protospacer as RNA.
    Sites containing anything but A/C/G/T are skipped because they cannot be encoded.
    """
    pam = pam.upper()
    pam_len = len(pam)
    for i in range(target_length, len(sequence) - pam_len + 1):
        if not _pam_matches(sequence[i:i + pam_len], pam):
            continue
        target_seq = sequence[i - target_length:i + pam_len]
        if not _DNA_BASES.issuperset(target_seq):
            continue
        gRNA = ''.join(_DNA_TO_RNA[base] for base in target_seq[:target_length])
        yield i, target_seq, gRNA


def find_crispr_targets(sequence, chr, start, end, strand, transcript_id, exon_id, pam="NGG", target_length=20):
    """Find all protospacer+PAM sites on the given strand of an exon.

    Args:
        sequence: exon sequence in the orientation returned by Ensembl.
        chr, start, end, strand: genomic location of the exon (1-based, inclusive);
            for strand == -1 the sequence is read from `end` downwards.
        transcript_id, exon_id: identifiers copied into every result row.
        pam: PAM pattern; 'N' is a wildcard (any length is supported).
        target_length: protospacer length.

    Returns:
        List of [target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id]
        with start/end/strand as strings and start <= end in genomic coordinates.
    """
    sequence = sequence.upper()
    pam_len = len(pam)
    targets = []

    for i, target_seq, gRNA in _iter_target_sites(sequence, pam, target_length):
        if strand == -1:
            tar_start = end - (i + pam_len - 1)
            tar_end = end - (i - target_length)
        else:
            tar_start = start + i - target_length
            tar_end = start + i + pam_len - 1
        targets.append([target_seq, gRNA, chr, str(tar_start), str(tar_end), str(strand), transcript_id, exon_id])

    return targets


# %%
_MODEL_CACHE = {}  # model_path -> loaded model, so weights are read from disk once


def _get_model(model_path):
    """Build the network and load its weights once per `model_path`."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = MultiHeadAttention_model(input_shape=(23, 4))
        model.load_weights(model_path)
        _MODEL_CACHE[model_path] = model
    return model


def _predict_scores(targets, model_path):
    """Score every target (target[0] is the 23-nt sequence) in one batch; scores are capped at 100."""
    if not targets:
        return []
    model = _get_model(model_path)
    encoded = np.concatenate([get_seqcode(target[0]) for target in targets], axis=0)
    raw_scores = model.predict(encoded, verbose=0).reshape(-1)
    return [min(float(score), 100) for score in raw_scores]


# Function to predict on-target efficiency and format output
def format_prediction_output(targets, model_path):
    """Attach a predicted efficiency to each row produced by find_crispr_targets.

    Returns:
        List of [chr, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction].
    """
    formatted_data = []
    for target, prediction in zip(targets, _predict_scores(targets, model_path)):
        target_seq, gRNA, chr, start, end, strand, transcript_id, exon_id = target[:8]
        formatted_data.append([chr, start, end, strand, transcript_id, exon_id, target_seq, gRNA, prediction])

    return formatted_data


# %%
def _write_design_csv(rows, header, filename):
    """Write result rows to OUTPUT_DIR/filename, creating the directory if needed."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    pd.DataFrame(data=rows, columns=header).to_csv(f'{OUTPUT_DIR}/{filename}')


def gRNADesign(gene_symbol, model_path, write_to_csv=False):
    """Design and score gRNAs for every exon of every transcript of a gene.

    Args:
        gene_symbol: HGNC symbol looked up in Ensembl (homo_sapiens).
        model_path: path to the saved model weights.
        write_to_csv: if True, write `Cas9_{gene_symbol}.csv` to OUTPUT_DIR and return None;
            otherwise return the rows.

    Returns:
        Rows sorted by predicted score (descending) when `write_to_csv` is False.
    """
    results = []
    for transcript in fetch_ensembl_transcripts(gene_symbol) or []:
        transcript_id = transcript['id']
        for exon in transcript.get('Exon', []):
            exon_id = exon['id']
            gene_sequence = fetch_ensembl_sequence(exon_id)
            if not gene_sequence:
                continue
            targets = find_crispr_targets(gene_sequence, exon['seq_region_name'], exon['start'],
                                          exon['end'], exon['strand'], transcript_id, exon_id)
            results.extend(format_prediction_output(targets, model_path))

    header = ['Chrom','Start','End','Strand','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score']
    sort_output = sorted(results, key=lambda row: row[8], reverse=True)

    if write_to_csv:
        _write_design_csv(sort_output, header, f'Cas9_{gene_symbol}.csv')
    else:
        return sort_output


# %%
# design
genes = GENES
model_path = MODEL_PATH

for gene in genes:
    gRNADesign(gene, model_path, write_to_csv=True)

# %% [markdown]
# ### Combine with VCF information

# %% [markdown]
# ##### Predict cell type-specific gRNA
#
# `fetch_ensembl_transcripts` and `fetch_ensembl_sequence` from the section above are reused
# here; they were previously re-defined, byte for byte, in this section.

# %%
def apply_mutation(ref_sequence, offset, ref, alt):
    """
    Apply a single mutation to the sequence.

    The `len(ref)` bases starting at `offset` are replaced by `alt`; an alt of "*" is
    treated as a deletion of those bases. Using `len(ref)` for every case fixes insertions
    whose REF is longer than one base, which previously kept (duplicated) the REF tail.
    """
    replacement = "" if alt == "*" else alt
    return ref_sequence[:offset] + replacement + ref_sequence[offset + len(ref):]


# %%
def construct_combinations(sequence, mutations):
    """
    Construct all combinations of mutations.
    mutations is a list of tuples (position, ref, [alts])

    Every returned sequence carries one alt for each mutation. Mutations are applied from
    the highest offset to the lowest so that an indel never shifts the (reference-based)
    offset of a mutation that is still to be applied. A mutation with no alts is ignored.
    Overlapping mutations are not reconciled.
    """
    if not mutations:
        return [sequence]

    sequences = [sequence]
    for offset, ref, alts in sorted(mutations, key=lambda mutation: mutation[0], reverse=True):
        alts = list(alts)
        if not alts:
            continue
        sequences = [apply_mutation(seq, offset, ref, alt) for seq in sequences for alt in alts]

    return sequences


# %%
def needleman_wunsch_alignment(query_seq, ref_seq):
    """
    Use Needleman-Wunsch alignment to estimate where query_seq starts in ref_seq.

    The longest exact-match run of the global alignment is used as the anchor; the returned
    value is the anchor's reference position minus its query position, i.e. the 0-based
    reference offset of the first query base (never negative). Returns 0 when the alignment
    has no exact-match run.
    """
    # DNA scoring matrix; BLOSUM62 is an amino-acid matrix and scores e.g. A/C mismatches as 0.
    alignment = parasail.nw_trace(query_seq, ref_seq, 10, 1, parasail.dnafull)

    raw_cigar = alignment.cigar.decode
    cigar_string = raw_cigar.decode("utf-8") if isinstance(raw_cigar, bytes) else raw_cigar

    best_len = 0
    best_ref_pos = 0
    best_query_pos = 0
    ref_pos = 0
    query_pos = 0
    for num_str, op in re.findall(r'(\d+)([MIDNSHP=X])', cigar_string):
        num = int(num_str)
        if op == '=' and num > best_len:
            best_len, best_ref_pos, best_query_pos = num, ref_pos, query_pos
        # SAM semantics (assumed for parasail's CIGAR): M/=/X consume both sequences,
        # D/N only the reference, I/S only the query.
        if op in 'M=XDN':
            ref_pos += num
        if op in 'M=XIS':
            query_pos += num

    return max(best_ref_pos - best_query_pos, 0)


# %%
_COMPLEMENT = str.maketrans('ACGTN', 'TGCAN')
_SEQUENCE_ALLELE = re.compile(r'[ACGTN]+')


def _reverse_complement(allele):
    """Reverse-complement a DNA allele; the '*' placeholder is returned unchanged."""
    return allele.translate(_COMPLEMENT)[::-1]


def _resolve_contig(vcf_reader, exon_chr):
    """Best effort: map an Ensembl contig name ('17') to the VCF's naming ('chr17') if needed."""
    name = str(exon_chr)
    try:
        seqnames = set(vcf_reader.seqnames)
    except (AttributeError, TypeError):
        return name  # reader exposes no contig list; query with the name as given
    if name in seqnames:
        return name
    alternative = name[3:] if name.startswith('chr') else f'chr{name}'
    return alternative if alternative in seqnames else name


def _collect_exon_mutations(ref_sequence, exon_chr, start, end, strand, vcf_reader):
    """Return [(offset, ref, [alts])] for the VCF records inside an exon, in exon orientation.

    For strand == -1 the offset is mirrored and REF/ALT are reverse-complemented, because the
    exon sequence is then read from `end` downwards. Symbolic alleles such as <NON_REF> are
    dropped. Records whose REF does not match the exon sequence at the computed offset
    (partly outside the exon, or a different assembly) are skipped and counted.
    """
    mutation_list = []
    skipped = 0
    region = f"{_resolve_contig(vcf_reader, exon_chr)}:{start}-{end}"
    for mutation in vcf_reader(region):
        ref = mutation.REF.upper()
        alts = [alt.upper() for alt in mutation.ALT
                if alt == '*' or _SEQUENCE_ALLELE.fullmatch(alt.upper())]
        if not alts:
            continue
        if strand == -1:
            offset = end - (mutation.POS + len(ref) - 1)
            ref = _reverse_complement(ref)
            alts = [_reverse_complement(alt) for alt in alts]
        else:
            offset = mutation.POS - start
        if offset < 0 or ref_sequence[offset:offset + len(ref)] != ref:
            skipped += 1
            continue
        mutation_list.append((offset, ref, alts))

    if skipped:
        print(f"Skipped {skipped} variant(s) in {region}: REF does not match the exon sequence")
    return mutation_list


def find_gRNA_with_mutation(ref_sequence, exon_chr, start, end, strand, transcript_id,
                            exon_id, gene_symbol, vcf_reader, pam="NGG", target_length=20):
    """Find target sites in an exon after applying the sample's variants.

    Args:
        ref_sequence: reference exon sequence in the orientation returned by Ensembl.
        exon_chr, start, end, strand: genomic location of the exon (1-based, inclusive).
        transcript_id, exon_id, gene_symbol: identifiers copied into every result row.
        vcf_reader: callable taking a "chrom:start-end" region and yielding records with
            POS, REF and ALT (a cyvcf2.VCF in this notebook).
        pam: PAM pattern; 'N' is a wildcard.
        target_length: protospacer length.

    Returns:
        De-duplicated list (first-seen order) of
        [target_seq, gRNA, exon_chr, strand, start, transcript_id, exon_id, gene_symbol, is_mut].
        `start` is the lowest genomic coordinate of the target; for targets not present in
        the reference it is an alignment-based estimate suffixed with '*'.
    """
    ref_sequence = ref_sequence.upper()
    pam_len = len(pam)

    # replace reference sequence of mutation
    mutation_list = _collect_exon_mutations(ref_sequence, exon_chr, start, end, strand, vcf_reader)
    mutated_sequences = construct_combinations(ref_sequence, mutation_list)

    # find gRNA in ref_sequence or all mutated_sequences
    targets = []
    for seq in mutated_sequences:
        for _, target_seq, gRNA in _iter_target_sites(seq, pam, target_length):
            pos = ref_sequence.find(target_seq)
            is_mut = pos == -1
            if is_mut:
                pos = needleman_wunsch_alignment(target_seq, ref_sequence)
            if strand == -1:
                tar_start = end - pos - target_length - pam_len + 1
            else:
                tar_start = start + pos
            tar_start = f"{tar_start}*" if is_mut else str(tar_start)
            targets.append((target_seq, gRNA, exon_chr, str(strand), tar_start,
                            transcript_id, exon_id, gene_symbol, is_mut))

    # filter duplicated targets (dict keeps first-seen order, unlike set)
    return [list(target) for target in dict.fromkeys(targets)]


# %%
def format_prediction_output_with_mutation(targets, model_path):
    """Attach a predicted efficiency to each row produced by find_gRNA_with_mutation.

    Returns:
        List of [gene_symbol, exon_chr, strand, start, transcript_id, exon_id,
        target_seq, gRNA, prediction, is_mut].
    """
    formatted_data = []
    for target, prediction in zip(targets, _predict_scores(targets, model_path)):
        (target_seq, gRNA, exon_chr, strand, tar_start,
         transcript_id, exon_id, gene_symbol, is_mut) = target[:9]
        formatted_data.append([gene_symbol, exon_chr, strand, tar_start, transcript_id,
                               exon_id, target_seq, gRNA, prediction, is_mut])

    return formatted_data


# %%
def gRNADesign_mutation(gene_symbol, vcf_reader, model_path, write_to_csv=False):
    """Design and score gRNAs for a gene using exon sequences that carry the VCF's variants.

    Args:
        gene_symbol: HGNC symbol looked up in Ensembl (homo_sapiens).
        vcf_reader: region-queryable variant reader (see find_gRNA_with_mutation).
        model_path: path to the saved model weights.
        write_to_csv: if True, write `Cas9_{gene_symbol}_mut.csv` to OUTPUT_DIR and return
            None; otherwise return the rows.

    Returns:
        Rows sorted by predicted score (descending) when `write_to_csv` is False.
    """
    results = []
    for transcript in fetch_ensembl_transcripts(gene_symbol) or []:
        transcript_id = transcript['id']
        for exon in transcript.get('Exon', []):
            exon_id = exon['id']
            gene_sequence = fetch_ensembl_sequence(exon_id)  # reference exon sequence
            if not gene_sequence:
                continue
            targets = find_gRNA_with_mutation(gene_sequence, exon['seq_region_name'], exon['start'],
                                              exon['end'], exon['strand'], transcript_id, exon_id,
                                              gene_symbol, vcf_reader)
            # Predict on-target efficiency for each gRNA site
            results.extend(format_prediction_output_with_mutation(targets, model_path))

    header = ['Gene','Chrom','Strand','Start','Transcript','Exon','Target sequence (5\' to 3\')','gRNA','pred_Score','Is_mutation']
    sort_output = sorted(results, key=lambda row: row[8], reverse=True)

    if write_to_csv:
        _write_design_csv(sort_output, header, f'Cas9_{gene_symbol}_mut.csv')
    else:
        return sort_output


# %%
# read VCF file (region queries require the bgzipped file to be indexed)
vcf_reader = cyvcf2.VCF(VCF_PATH)

# %%
# design
for gene in genes:
    gRNADesign_mutation(gene, vcf_reader, model_path, write_to_csv=True)