Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
from scipy import signal | |
from scipy.signal import butter, lfilter, detrend | |
# Make bandpass filter
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Both cut-off frequencies are divided by the Nyquist frequency
    (``0.5 * fs``) before being handed to ``scipy.signal.butter``.

    Args:
        lowcut: Lower cut-off frequency, in the same units as ``fs``.
        highcut: Upper cut-off frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order forwarded to ``butter``.

    Returns:
        The ``(b, a)`` numerator / denominator coefficient arrays.
    """
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype="band")
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.lfilter`` (a single forward pass along ``lfilter``'s
    default axis, the last one).

    Args:
        data: Array-like signal to filter.
        lowcut: Lower cut-off frequency, in the same units as ``fs``.
        highcut: Upper cut-off frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order.

    Returns:
        The filtered signal as returned by ``lfilter``.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(*coefficients, data)
def rotate_waveform(waveform, angle):
    """Apply a constant phase rotation to ``waveform`` in the frequency domain.

    The waveform is transformed with ``np.fft.fft``, every coefficient is
    multiplied by ``exp(1j * angle)``, and the product is transformed back
    with ``np.fft.ifft``.

    Args:
        waveform: Array-like signal; the transforms use NumPy's default
            axis (the last one).
        angle: Rotation angle. It is used directly as the argument of the
            complex exponential, i.e. it is interpreted in radians.

    Returns:
        The rotated waveform as a complex array; callers that need a real
        signal must take ``.real`` themselves.
    """
    phase_factor = np.exp(1j * angle)
    spectrum = np.fft.fft(waveform)
    return np.fft.ifft(spectrum * phase_factor)
def augment(sample):
    """Crop, preprocess and (for training samples) augment one waveform.

    A ``crop_length``-sample window is placed around a randomly selected
    P or S arrival.  The crop is detrended, band-pass filtered
    (``lowcut=0.2``, ``highcut=40``, ``fs=100``), tapered with a Tukey
    window, detrended again and divided by its maximum absolute value.
    When ``meta["split"]`` is ``"train"`` the window offset is random and
    the result is additionally phase-rotated / has one of channels 1-2
    overwritten with ``1e-6``, each with probability 1/2.  For any other
    split the window sits at a fixed offset from the selected arrival and
    no augmentation is applied.

    Args:
        sample: Mapping with two entries:
            ``"waveform.npy"`` - array indexed as ``[channel, sample]``
            (assumed to be a 2-D NumPy array; TODO confirm against the
            loader), and ``"meta.json"`` - mapping providing ``"split"``,
            ``"trace_p_arrival_sample"`` and ``"trace_s_arrival_sample"``
            (either arrival may be ``None``).

    Returns:
        Tuple ``(waveform, target_P, target_S)``.  ``waveform`` is the
        processed crop (expanded to 3 rows when fewer were supplied).
        Each target is the arrival position inside the crop divided by
        ``crop_length``; it is ``0`` when the arrival is missing, NaN, or
        falls outside the open interval ``(0, 1)``, and also when the crop
        is flat or contained non-finite values.
    """
    # SET PARAMETERS:
    crop_length = 6000
    padding = 120

    waveform = sample["waveform.npy"]
    meta = sample["meta.json"]
    test = meta["split"] != "train"

    target_sample_P = meta["trace_p_arrival_sample"]
    target_sample_S = meta["trace_s_arrival_sample"]
    if target_sample_P is None:
        target_sample_P = 0
    if target_sample_S is None:
        target_sample_S = 0

    # Traces shorter than the crop are zero-padded at the end (the same
    # policy prepare_waveform uses).  Without this the end-of-trace
    # correction below produces a negative start index, and the slice then
    # returns a short, misaligned crop.
    if waveform.shape[-1] < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - waveform.shape[-1])),
            mode="constant",
            constant_values=0,
        )

    # Randomly select a phase to anchor the crop on.
    current_phases = [x for x in (target_sample_P, target_sample_S) if x > 0]
    if current_phases:
        first_phase = current_phases[np.random.randint(0, len(current_phases))]
    else:
        # No usable arrival: np.random.randint(0, 0) would raise ValueError.
        # A zero phase falls through to the default window start below.
        first_phase = 0

    # Choose the window start: random for training, fixed otherwise.
    if first_phase - (crop_length - padding) > padding:
        if test:
            start_indx = int(first_phase - 2 * padding)
        else:
            offset = torch.randint(
                low=padding, high=(crop_length - padding), size=(1,)
            ).item()
            start_indx = int(first_phase - offset)
    elif int(first_phase - padding) > 0:
        if test:
            start_indx = int(first_phase - padding)
        else:
            offset = torch.randint(
                low=0, high=int(first_phase - padding), size=(1,)
            ).item()
            start_indx = int(first_phase - offset)
    else:
        start_indx = padding

    # Shift the window back if it would run past the end of the trace.
    end_indx = start_indx + crop_length
    overshoot = end_indx - waveform.shape[-1]
    if overshoot > 0:
        start_indx -= overshoot
        end_indx -= overshoot

    # Update target
    new_target_P = target_sample_P - start_indx
    new_target_S = target_sample_S - start_indx

    # Cut
    waveform_cropped = waveform[:, start_indx:end_indx]

    # Preprocess
    waveform_cropped = detrend(waveform_cropped)
    waveform_cropped = butter_bandpass_filter(
        waveform_cropped, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform_cropped.shape[-1], alpha=0.1)
    waveform_cropped = waveform_cropped * window
    waveform_cropped = detrend(waveform_cropped)

    # Non-finite values (NaN or inf) invalidate the whole crop.
    if not np.isfinite(waveform_cropped).all():
        waveform_cropped = np.zeros(shape=waveform_cropped.shape)

    # Normalize data.  A flat crop (including the zeroed one above) has
    # max_val == 0; dividing would turn it into NaNs, so it is passed
    # through unchanged and its targets are cleared.
    max_val = np.max(np.abs(waveform_cropped))
    if max_val > 0:
        waveform_cropped_norm = waveform_cropped / max_val
    else:
        waveform_cropped_norm = waveform_cropped
        new_target_P = 0
        new_target_S = 0

    # Added Z component only: place it in row 0 of a 3-row array.
    if len(waveform_cropped_norm) < 3:
        zeros = np.zeros((3, waveform_cropped_norm.shape[-1]))
        zeros[0] = waveform_cropped_norm
        waveform_cropped_norm = zeros

    if not test:
        ##### Rotate waveform #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # NOTE(review): the angle is drawn from a degree-like range, but
        # rotate_waveform passes it straight to np.exp(1j * angle), i.e.
        # treats it as radians - confirm which unit is intended.
        angle = torch.FloatTensor(size=(1,)).uniform_(0.01, 359.9).item()
        if probability == 1:
            waveform_cropped_norm = rotate_waveform(waveform_cropped_norm, angle).real

        #### Channel DropOUT #####
        probability = torch.randint(0, 2, size=(1,)).item()
        # randint(1, 3) yields 1 or 2, so row 0 is never dropped.
        channel = torch.randint(1, 3, size=(1,)).item()
        if probability == 1:
            waveform_cropped_norm[channel, :] = 1e-6

    # Normalize target
    new_target_P = new_target_P / crop_length
    new_target_S = new_target_S / crop_length

    # Arrivals outside the crop (or undefined) are reported as 0.
    if (new_target_P <= 0) or (new_target_P >= 1) or np.isnan(new_target_P):
        new_target_P = 0
    if (new_target_S <= 0) or (new_target_S >= 1) or np.isnan(new_target_S):
        new_target_S = 0

    return waveform_cropped_norm, new_target_P, new_target_S
def collation_fn(sample):
    """Collate per-sample tuples into three batched float tensors.

    Args:
        sample: Sequence of items whose elements 0, 1 and 2 are the
            waveform, the P target and the S target respectively.

    Returns:
        Tuple ``(waveforms, targets_P, targets_S)`` of ``torch.float``
        tensors, each built by stacking the corresponding element of every
        item along a new leading axis.
    """
    return tuple(
        torch.tensor(np.stack([item[column] for item in sample]), dtype=torch.float)
        for column in range(3)
    )
def my_split_by_node(urls):
    """Return the subset of ``urls`` assigned to the current process.

    Every ``world_size``-th entry is taken, starting at this process's
    rank, so the ranks partition the list between them.

    When ``torch.distributed`` is unavailable or no process group has been
    initialised (e.g. a single-process run), the process is treated as
    rank 0 of 1 and all URLs are returned instead of raising.

    Args:
        urls: Iterable of shard URLs.

    Returns:
        A list with this process's share of ``urls``.
    """
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        node_id, node_count = dist.get_rank(), dist.get_world_size()
    else:
        # Querying the rank without an initialised process group raises.
        node_id, node_count = 0, 1
    return list(urls)[node_id::node_count]
def prepare_waveform(waveform):
    """Preprocess a single waveform and replicate it into a batch of 128.

    The trace is zero-padded or truncated to ``crop_length`` samples,
    detrended, band-pass filtered (``lowcut=0.2``, ``highcut=40``,
    ``fs=100``), tapered with a Tukey window, detrended again and divided
    by its maximum absolute value.  Inputs with fewer than 3 rows are
    placed in row 0 of a zero-filled 3-row array.

    Args:
        waveform: Array indexed as ``[channel, sample]`` with at most 3
            channels (assumed to be a 2-D NumPy array; TODO confirm
            against the caller).

    Returns:
        ``torch.float`` tensor holding 128 identical copies of the
        processed waveform along a new leading axis.  (Presumably 128
        matches the batch size the downstream model expects - confirm.)

    Raises:
        AssertionError: If there are more than 3 channels, if the processed
            waveform contains non-finite values, or if it is identically
            zero.  The type matches the ``assert`` statements this replaces,
            but the checks are raised explicitly so they still run under
            ``python -O``.
    """
    # SET PARAMETERS:
    crop_length = 6000
    batch_copies = 128

    if waveform.shape[0] > 3:
        raise AssertionError("Waveform has more than 3 channels")

    # Bring the trace to exactly crop_length samples.
    n_samples = waveform.shape[-1]
    if n_samples < crop_length:
        waveform = np.pad(
            waveform,
            ((0, 0), (0, crop_length - n_samples)),
            mode="constant",
            constant_values=0,
        )
    elif n_samples > crop_length:
        waveform = waveform[:, :crop_length]

    # Preprocess
    waveform = detrend(waveform)
    waveform = butter_bandpass_filter(
        waveform, lowcut=0.2, highcut=40, fs=100, order=5
    )
    window = signal.windows.tukey(waveform.shape[-1], alpha=0.1)
    waveform = waveform * window
    waveform = detrend(waveform)

    if not np.isfinite(waveform).all():
        raise AssertionError("Nan in waveform")

    # Normalize data.  Checking the peak (not the sum) is what actually
    # guards the division below: a non-flat trace can sum to zero, while a
    # flat one always has a zero peak.
    max_val = np.max(np.abs(waveform))
    if max_val == 0:
        raise AssertionError("Sum of waveform sample is zero")
    waveform = waveform / max_val

    # Added Z component only: place it in row 0 of a 3-row array.
    if len(waveform) < 3:
        zeros = np.zeros((3, waveform.shape[-1]))
        zeros[0] = waveform
        waveform = zeros

    # Convert once and replicate on the torch side; building a tensor from
    # a Python list of 128 ndarrays is very slow.
    single = torch.tensor(waveform, dtype=torch.float)
    return single.repeat(batch_copies, 1, 1)