Commit 73b6f85

update

Signed-off-by: Yiyu Ni <niyiyu@uw.edu>
niyiyu committed Aug 13, 2024
1 parent 8918352 commit 73b6f85
Showing 8 changed files with 433 additions and 314 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -14,6 +14,7 @@ figures/INR/*.pdf
*.out
*.mov
*.key
+SYNC.sh

# docs shall not be uploaded
docs/manuscript/
8 changes: 4 additions & 4 deletions README.md
@@ -8,16 +8,16 @@ This study explores wavefield reconstruction using machine learning methods for
This repository provides independent notebook examples of the model training and inference performed in the manuscript. All code is implemented in PyTorch.

### SHallow REcurrent Decoder
-The notebook of SHRED model training is available at [notebooks/SHRED_KKFLS_training.ipynb](./notebooks/SHRED_KKFLS_training.ipynb). See below for instructions of getting the training data.
+The notebook of SHRED model training on the CI DAS data is available at [notebooks/training_SHRED_KKFLS.ipynb](./notebooks/training_SHRED_KKFLS.ipynb). See below for instructions on getting the training data.
![SHRED](./docs/shred.png)
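For orientation, here is a minimal sketch of the SHRED architecture that notebook trains: an LSTM encodes the recent time history of a few input channels, and a shallow decoder expands the final hidden state to the full channel set. It mirrors the SHRED class visible in the notebook diff further down; the shapes in the usage example follow the notebook (151 input channels, 200 time samples, 1000 output channels), while hidden_size=512 is a placeholder, not the trained configuration.

```python
import torch

class SHRED(torch.nn.Module):
    """LSTM encoder plus shallow decoder (sketch)."""
    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers,
                                  batch_first=True, dropout=0.2)
        self.sdn1 = torch.nn.Linear(hidden_size, output_size // 2)
        self.sdn3 = torch.nn.Linear(output_size // 2, output_size)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # lstm returns (output, (h_n, c_n)); h_n[-1] is the last layer's hidden state
        x = self.lstm(x)[1][0][-1]
        x = self.relu(self.sdn1(x))
        return self.sdn3(x)

# Shapes as in the notebook; hidden_size is a placeholder.
model = SHRED(input_size=151, hidden_size=512, output_size=1000, num_layers=2)
y = model(torch.randn(8, 200, 151))   # (batch, time, channels) -> (8, 1000)
```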

### Implicit Neural Representation
![SIREN_vs_RFFN](./docs/siren_vs_rffn_50_40epoch.gif)
-- Random Fourier Feature Network (RFFN, [Tancik et al., 2020](https://arxiv.org/abs/2006.10739)): [notebooks/RFFN_KKFLS_training.ipynb](./notebooks/RFFN_KKFLS_training.ipynb)
-- Sinusoidal Representation Network (SIREN, [Sitzmann et al., 2020](https://arxiv.org/abs/2006.09661)): [notebooks/SIREN_KKFLS_training.ipynb](./notebooks/SIREN_KKFLS_training.ipynb)
+- Random Fourier Feature Network (RFFN, [Tancik et al., 2020](https://arxiv.org/abs/2006.10739)): [notebooks/training_RFFN_KKFLS.ipynb](./notebooks/training_RFFN_KKFLS.ipynb)
+- Sinusoidal Representation Network (SIREN, [Sitzmann et al., 2020](https://arxiv.org/abs/2006.09661)): [notebooks/training_SIREN_KKFLS.ipynb](./notebooks/training_SIREN_KKFLS.ipynb)
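Both notebooks fit a coordinate network to the recorded wavefield. As a quick sketch of how the two variants differ (the layer sizes and the scale/w0 values below are illustrative defaults from the cited papers, not the notebooks' exact settings):

```python
import math
import torch

class FourierFeatures(torch.nn.Module):
    """Random Fourier feature embedding (Tancik et al., 2020):
    gamma(v) = [cos(2*pi*vB), sin(2*pi*vB)] with a fixed Gaussian matrix B."""
    def __init__(self, in_dim, num_features, scale=10.0):
        super().__init__()
        # B is sampled once and never trained
        self.register_buffer("B", torch.randn(in_dim, num_features) * scale)

    def forward(self, v):
        proj = 2.0 * math.pi * v @ self.B
        return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)

class SineLayer(torch.nn.Module):
    """SIREN layer (Sitzmann et al., 2020): sin(w0 * (Wx + b))."""
    def __init__(self, in_dim, out_dim, w0=30.0):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim)
        self.w0 = w0

    def forward(self, x):
        return torch.sin(self.w0 * self.linear(x))

coords = torch.rand(1024, 2)             # e.g., (channel, time) scaled to [0, 1]
rff = FourierFeatures(2, 128)(coords)    # -> (1024, 256), then fed to a ReLU MLP
siren = SineLayer(2, 256)(coords)        # -> (1024, 256), first layer of a SIREN
```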

## Data
-The earthquake data from the Cook Inlet DAS experiment are available at [https://dasway.ess.washington.edu/gci/index.html](https://dasway.ess.washington.edu/gci/index.html). Earthquakes and daily data reports will be updated daily.
+The earthquake data from the Cook Inlet DAS experiment are available at [https://dasway.ess.washington.edu/gci/index.html](https://dasway.ess.washington.edu/gci/index.html). Earthquakes and daily data reports are updated daily.

Due to the size of the data used in this study (~260 GB per cable), we cannot upload it directly to this repository. However, we provide a Python script that downloads the data from our archival server; see [download.py](./data/download.py) and the event list [event_list.csv](./data/event_list.csv) in the repository.
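For readers adapting the workflow, a minimal sketch of what such a download loop could look like; the server path layout and the use of the CSV's first column are assumptions for illustration, and [download.py](./data/download.py) remains the authoritative implementation:

```python
# Sketch only: the URL pattern and CSV layout below are hypothetical;
# data/download.py in the repository is the authoritative script.
import pandas as pd
import urllib.request

BASE = "https://dasway.ess.washington.edu"   # archive server from this README

events = pd.read_csv("data/event_list.csv")
for _, ev in events.iterrows():
    name = str(ev.iloc[0])                   # assume column 0 identifies the event
    url = f"{BASE}/files/{name}.h5"          # hypothetical path layout
    urllib.request.urlretrieve(url, f"datasets/earthquakes/{name}.h5")
```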


Large diffs are not rendered by default.

@@ -2,76 +2,27 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "bda9187d",
"metadata": {},
"outputs": [],
"source": [
"import os, time, sys, gc\n",
"\n",
"# please update this path accordingly\n",
"sys.path.append(\"../../DAS-reconstruction/scripts/\")\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = \"3\"\n",
"\n",
"import glob\n",
"import h5py\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"from tqdm import tqdm\n",
"import scipy\n",
"import pickle\n",
"import os\n",
"import h5py\n",
"import obspy\n",
"import pandas as pd\n",
"import time\n",
"from obspy.core.utcdatetime import UTCDateTime\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = \"3\"\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"from torch.nn import MSELoss\n",
"from torch.optim import lr_scheduler\n",
"\n",
"from obspy.signal.filter import bandpass, lowpass, highpass\n",
"from obspy.signal.invsim import cosine_taper\n",
"from scipy.signal import butter, filtfilt, detrend"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2da99073",
"metadata": {},
"outputs": [],
"source": [
"class DASDataset(torch.utils.data.Dataset):\n",
" def __init__(self, inputs, outputs):\n",
" 'Initialization'\n",
" if isinstance(inputs, torch.Tensor):\n",
" self.inputs = inputs\n",
" self.outputs = outputs\n",
" else:\n",
" self.inputs = inputs.astype(np.float32)\n",
" self.outputs = outputs.astype(np.float32)\n",
"\n",
" def __len__(self):\n",
" 'Denotes the total number of samples'\n",
" return len(self.outputs)\n",
"\n",
" def __getitem__(self, index):\n",
" 'Generates one sample of data'\n",
" X = self.inputs[index, :]\n",
" y = self.outputs[index, :]\n",
"\n",
" return X, y\n",
" \n",
"class SHRED(torch.nn.Module):\n",
" def __init__(self, input_size, hidden_size, output_size, num_layers):\n",
" super().__init__()\n",
" self.lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.2)\n",
" self.sdn1 = torch.nn.Linear(hidden_size, output_size//2)\n",
" self.sdn3 = torch.nn.Linear(output_size//2, output_size)\n",
" self.relu = torch.nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" x = self.lstm(x)[1][0][-1] # should be -1\n",
" x = self.relu(self.sdn1(x))\n",
" x = self.sdn3(x)\n",
" return x"
"from utils import clean_up, count_weights\n",
"from models import SHRED\n",
"from datasets import DASDataset"
]
},
{
@@ -81,13 +32,14 @@
"metadata": {},
"outputs": [],
"source": [
"# please update this path accordingly\n",
"flist = glob.glob(\"../../datasets/earthquakes/*\")\n",
"\n",
"nsample_train = 300\n",
"nsample_val = 300\n",
"nsample_test = 300\n",
"\n",
"ncha = 201\n",
"ncha = 151\n",
"ntime = 200\n",
"\n",
"ncha_start = 1000\n",
@@ -126,31 +78,14 @@
" f.close()\n",
"\n",
" for _ in range(nsample_train):\n",
" idt = np.random.randint(ntime+1, 3000) # last time index\n",
" idt = np.random.randint(ntime+1, 3000) # last time index\n",
" ic = np.random.randint(0, ncha_end-ncha_start-noutput) # first channel indexes\n",
" X[i, :, :] = data[ic+cidx, idt-(ntime-1):idt+1].T\n",
" Y[i, :] = data[ic:ic+noutput, idt]\n",
" i += 1\n",
"print(f\"training set size: {i}\") "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0208ddd5",
"metadata": {},
"outputs": [],
"source": [
"vmax = 5\n",
"\n",
"plt.figure(figsize=(10, 8), dpi=300)\n",
"plt.imshow(data, aspect='auto', cmap='RdBu', origin='lower', \n",
" vmax = vmax, vmin = -vmax)\n",
"plt.title(\"original\", fontsize=20)\n",
"plt.xticks([]); \n",
"plt.yticks([])"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -183,30 +118,19 @@
"Y_test_ts = torch.Tensor(Y[idx_test, :])\n",
"\n",
"dataset = DASDataset(X_train_ts, Y_train_ts)\n",
"data_loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)\n",
"data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)\n",
"\n",
"val_dataset = DASDataset(X_val_ts, Y_val_ts)\n",
"val_data_loader = DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=0)\n",
"val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=512, shuffle=True, num_workers=0)\n",
"\n",
"test_dataset = DASDataset(X_test_ts, Y_test_ts)\n",
"test_data_loader = DataLoader(test_dataset, batch_size=512, shuffle=True, num_workers=0)\n",
"test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512, shuffle=True, num_workers=0)\n",
"\n",
"print(\"train: \", X_train_ts.shape, Y_train_ts.shape)\n",
"print(\"validate: \", X_val_ts.shape, Y_val_ts.shape)\n",
"print(\"test: \", X_test_ts.shape, Y_test_ts.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27043620",
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -226,10 +150,7 @@
" torch.nn.init.kaiming_normal_(m.weight)\n",
" m.bias.data.fill_(0.01)\n",
" \n",
"n_weights = 0\n",
"for i in model.parameters():\n",
" n_weights += len(i.data.flatten())\n",
"print(f\"have total {n_weights} weights\")"
"print(f\"The model have total {count_weights(model)} weights\")"
]
},
{
@@ -241,8 +162,8 @@
"source": [
"nepoch = 80\n",
"optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)\n",
"loss_fn = MSELoss()\n",
"scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=nepoch)"
"loss_fn = torch.nn.MSELoss()\n",
"scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=nepoch)"
]
},
{
@@ -254,11 +175,9 @@
},
"outputs": [],
"source": [
"t0 = time.time()\n",
"train_loss_log = []\n",
"val_loss_log = []\n",
"test_loss_log = []\n",
" \n",
"train_loss_log = []; val_loss_log = []; test_loss_log = []\n",
"\n",
"t0 = time.time() \n",
"for t in range(nepoch):\n",
" model.train()\n",
" train_loss = []\n",
@@ -294,22 +213,28 @@
" print(\"Epoch %d: Adam lr %.4f -> %.4f\" % (t, before_lr, after_lr))\n",
" print(\"%d, %.4f, %.4f, %.4f\" % (t, np.mean(train_loss), np.mean(test_loss), np.mean(val_loss)))\n",
" \n",
"# torch.save(model.state_dict(), \n",
"# f\"/home/niyiyu/Research/DAS-NIR/gci-summary/results/weights/SHRED_KKFLS_25Hz_201i_1000o_200sp_epo{t}.pt\")\n",
"print(time.time() - t0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4394698d",
"id": "a7ff4216",
"metadata": {},
"outputs": [],
"source": [
"with open(\"../../datasets/loss.pt\", \"wb\") as f:\n",
" pickle.dump({\"train\": train_loss_log,\n",
" \"validate\": val_loss_log,\n",
" \"test\": test_loss_log}, f)"
"clean_up()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d473baef",
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \n",
" f\"../../datasets/weights/SHRED_KKFLS_25Hz_151i_1000o_200sp.pt\")"
]
},
{
@@ -327,8 +252,7 @@
"plt.xlabel(\"Epoch\", fontsize = 15)\n",
"plt.ylabel(\"Loss\", fontsize = 15)\n",
"plt.grid(True)\n",
"plt.savefig(\"../figures/manuscripts/FigS_loss.pdf\", bbox_inches='tight', dpi=300)\n",
"# plt.yscale('log')"
"plt.savefig(\"../figures/manuscripts/FigS_loss.pdf\", bbox_inches='tight', dpi=300)"
]
},
{
@@ -338,10 +262,9 @@
"metadata": {},
"outputs": [],
"source": [
"model.train()\n",
"idx = np.random.randint(0, len(X_val_ts))\n",
"model.eval()\n",
"\n",
"# model.eval()\n",
"idx = np.random.randint(0, len(X_val_ts))\n",
"inputs = X_val_ts[idx, :, :]\n",
"label = Y_val_ts[idx, :]\n",
"predict = model(inputs.to(device)).cpu().detach().numpy()\n",
