{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"[TPU] SplitTransformer Pred X2 - Without Embedding","provenance":[{"file_id":"12Dps5MTL_QQFnmGcCyuePCHlDlsvb6A6","timestamp":1642554601295},{"file_id":"https://github.com/keras-team/keras-io/blob/master/examples/vision/ipynb/perceiver_image_classification.ipynb","timestamp":1621552889682}],"collapsed_sections":[],"machine_shape":"hm"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.0"},"accelerator":"TPU"},"cells":[{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"liJJGzQp4qzO","executionInfo":{"status":"ok","timestamp":1650123887494,"user_tz":240,"elapsed":7227,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"3314cb27-9b58-43c1-824e-499a86683900"},"source":["!pip install tensorflow-addons\n","!pip install pyyaml h5py"],"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Collecting tensorflow-addons\n","  Downloading tensorflow_addons-0.16.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n","\u001b[?25l\r\u001b[K     |▎                               | 10 kB 26.9 MB/s eta 0:00:01\r\u001b[K     |▋                               | 20 kB 10.2 MB/s eta 0:00:01\r\u001b[K     |▉                               | 30 kB 8.9 MB/s eta 0:00:01\r\u001b[K     |█▏                              | 40 kB 8.2 MB/s eta 0:00:01\r\u001b[K     |█▌                              | 51 kB 4.5 MB/s eta 0:00:01\r\u001b[K     |█▊                              | 61 kB 5.4 MB/s eta 0:00:01\r\u001b[K     |██                              | 71 kB 5.5 MB/s eta 0:00:01\r\u001b[K     |██▍                             | 81 kB 5.5 MB/s eta 0:00:01\r\u001b[K     |██▋                             | 
92 kB 6.1 MB/s eta 0:00:01\r\u001b[K     |███                             | 102 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███▏                            | 112 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███▌                            | 122 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███▉                            | 133 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████                            | 143 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████▍                           | 153 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████▊                           | 163 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████                           | 174 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████▎                          | 184 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████▌                          | 194 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████▉                          | 204 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████▏                         | 215 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████▍                         | 225 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████▊                         | 235 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████                         | 245 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████▎                        | 256 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████▋                        | 266 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████▉                        | 276 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████▏                       | 286 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████▌                       | 296 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████▊                       | 307 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████                       | 317 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████▍                      | 327 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████▋                      | 337 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████                      | 348 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████▏                     
| 358 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████▌                     | 368 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████▉                     | 378 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████                     | 389 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████▍                    | 399 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████▊                    | 409 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████                    | 419 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████▎                   | 430 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████▌                   | 440 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████▉                   | 450 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████▏                  | 460 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████▍                  | 471 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████▊                  | 481 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████                  | 491 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████▎                 | 501 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████▋                 | 512 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████▉                 | 522 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████▏                | 532 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████▌                | 542 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 552 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████                | 563 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████▍               | 573 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████▋               | 583 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████               | 593 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████▏              | 604 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████▌              | 614 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████▉            
  | 624 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 634 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████▍             | 645 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████▊             | 655 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 665 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████▎            | 675 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████▌            | 686 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████▉            | 696 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████▏           | 706 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████▍           | 716 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████▊           | 727 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 737 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▎          | 747 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▋          | 757 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████▉          | 768 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▏         | 778 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▌         | 788 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████▊         | 798 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 808 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▍        | 819 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████▋        | 829 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 839 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▏       | 849 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▌       | 860 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████▉       | 870 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████       | 880 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▍  
    | 890 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▊      | 901 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████      | 911 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▎     | 921 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▌     | 931 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▉     | 942 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▏    | 952 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 962 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▊    | 972 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████    | 983 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▎   | 993 kB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▋   | 1.0 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▉   | 1.0 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▏  | 1.0 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▌  | 1.0 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▊  | 1.0 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████  | 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▍ | 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▋ | 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████ | 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▏| 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▉| 1.1 MB 5.3 MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.1 MB 5.3 MB/s \n","\u001b[?25hRequirement already satisfied: typeguard>=2.7 in /usr/local/lib/python3.7/dist-packages (from tensorflow-addons) (2.7.1)\n","Installing collected packages: tensorflow-addons\n","Successfully 
installed tensorflow-addons-0.16.1\n","Requirement already satisfied: pyyaml in /usr/local/lib/python3.7/dist-packages (3.13)\n","Requirement already satisfied: h5py in /usr/local/lib/python3.7/dist-packages (3.1.0)\n","Requirement already satisfied: cached-property in /usr/local/lib/python3.7/dist-packages (from h5py) (1.5.2)\n","Requirement already satisfied: numpy>=1.14.5 in /usr/local/lib/python3.7/dist-packages (from h5py) (1.21.5)\n"]}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ZWzi3Z3L5RO2","executionInfo":{"status":"ok","timestamp":1650123890558,"user_tz":240,"elapsed":3067,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"80cd4314-31fc-4e8e-fbfe-adc57e74a6ca"},"source":["import numpy as np\n","%tensorflow_version 2.x\n","import tensorflow as tf\n","print(\"Tensorflow version \" + tf.__version__)"],"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Tensorflow version 2.8.0\n"]}]},{"cell_type":"code","metadata":{"id":"ZhJ7px_V4u_B","executionInfo":{"status":"ok","timestamp":1650124216536,"user_tz":240,"elapsed":325980,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["from pydrive.auth import GoogleAuth\n","from pydrive.drive import GoogleDrive\n","from google.colab import auth\n","from oauth2client.client import GoogleCredentials\n","\n","auth.authenticate_user()\n","gauth = GoogleAuth()\n","gauth.credentials = GoogleCredentials.get_application_default()\n","drive = GoogleDrive(gauth)\n","# full file\n","\n","downloaded = drive.CreateFile({'id':\"1mhsqTpHAOdo90kReD1Cx2mdyxwmGgdcB\"})   # replace the id with id of file you want to access\n","downloaded.GetContentFile('genotype_full.csv')  \n","\n","downloaded = drive.CreateFile({'id':\"111m8AYEJtc5GKvO4Z1gx8mRxLz9_D8zD\"})   # replace the id with id of file you want to access\n","downloaded.GetContentFile('phenotypes.csv')  
\n","\n"],"execution_count":3,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"RjGOO5PdFPf7"},"source":["## Setup"]},{"cell_type":"code","metadata":{"id":"odmhCqSVFPf8","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650124217239,"user_tz":240,"elapsed":706,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"8fd8c17a-39a5-454c-dc0b-c7e515806c7a"},"source":["import os\n","# os.environ[\"MODIN_CPUS\"] = \"8\"\n","# from distributed import Client\n","# client = Client()\n","import math\n","import re\n","import pandas as pd\n","from sklearn.model_selection import train_test_split\n","from sklearn.utils import resample\n","import tensorflow as tf\n","from tensorflow import keras\n","import tensorflow.keras.backend as K\n","from tensorflow.keras import layers\n","from tensorflow.keras import regularizers\n","from tensorflow.keras.preprocessing.sequence import pad_sequences\n","import tensorflow_addons as tfa\n","from tensorflow.keras.utils import to_categorical\n","from tensorflow.keras import constraints\n","from tensorflow.keras import initializers\n","from tensorflow.keras import regularizers\n","from tensorflow.keras.layers import InputSpec\n","%matplotlib inline   \n","from matplotlib import pyplot as plt\n","import tensorflow_datasets as tfds\n","from sklearn.metrics import mean_squared_error\n","from sklearn.linear_model import LassoCV, ElasticNetCV\n","\n","print(\"Tensorflow version \" + tf.__version__)\n","# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])\n","# tf.config.experimental_connect_to_cluster(resolver)\n","# # This is the TPU initialization code that has to be at the beginning.\n","# tf.tpu.experimental.initialize_tpu_system(resolver)\n","# print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n","# strategy = 
tf.distribute.TPUStrategy(resolver)"],"execution_count":4,"outputs":[{"output_type":"stream","name":"stdout","text":["Tensorflow version 2.8.0\n"]}]},{"cell_type":"code","source":["# Detect hardware and return the appropriate distribution strategy.\n","# A TPU strategy is used when a TPU can be resolved; otherwise fall back to the\n","# default strategy, which covers CPU and single-GPU runtimes.\n","try:\n","    # TPU detection; a ValueError from the resolver is treated as 'no TPU available'.\n","    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()\n","    print('Running on TPU ', TPU.master())\n","except ValueError:\n","    print('Running on CPU/GPU')\n","    TPU = None\n","\n","if TPU is not None:\n","    tf.config.experimental_connect_to_cluster(TPU)\n","    tf.tpu.experimental.initialize_tpu_system(TPU)\n","    # tf.distribute.experimental.TPUStrategy is deprecated in TF 2.x; use the stable symbol.\n","    strategy = tf.distribute.TPUStrategy(TPU)\n","else:\n","    strategy = tf.distribute.get_strategy()  # default strategy: CPU or a single GPU\n","\n","N_REPLICAS = strategy.num_replicas_in_sync\n","# Number of replicas (computing cores) kept in sync, e.g. 8 for a TPU v3-8\n","print(f'N_REPLICAS: {N_REPLICAS}')"],"metadata":{"id":"2g2HHI7AopmH","colab":{"base_uri":"https://localhost:8080/"},"outputId":"eeebd116-e08f-4c3e-cc89-3ff0055a1b88","executionInfo":{"status":"ok","timestamp":1650124273748,"user_tz":240,"elapsed":56512,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"execution_count":5,"outputs":[{"output_type":"stream","name":"stdout","text":["Running on TPU  grpc://10.121.106.26:8470\n","INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:Deallocate tpu buffers before initializing tpu system.\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:Initializing the TPU system: grpc://10.121.106.26:8470\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:Initializing the TPU system: grpc://10.121.106.26:8470\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:Finished initializing TPU 
system.\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:Finished initializing TPU system.\n","WARNING:absl:`tf.distribute.experimental.TPUStrategy` is deprecated, please use  the non experimental symbol `tf.distribute.TPUStrategy` instead.\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:Found TPU system:\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:Found TPU system:\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Num TPU Cores: 8\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Num TPU Cores: 8\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Num TPU Workers: 1\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Num TPU Workers: 1\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 
0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 
0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"]},{"output_type":"stream","name":"stderr","text":["INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"]},{"output_type":"stream","name":"stdout","text":["N_REPLICAS: 8\n"]}]},{"cell_type":"markdown","metadata":{"id":"A77GFE3xFPf8"},"source":["## Prepare the data"]},{"cell_type":"code","metadata":{"id":"j3zy8i_8FPf_","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650124320034,"user_tz":240,"elapsed":46289,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"f7e597b5-a4de-4482-8d0a-de2c2d5b581b"},"source":["phenoIndex = 2\n","\n","genotypes = pd.read_csv('genotype_full.csv', sep='\\t', index_col=0)\n","genotypes[genotypes == -1] = 0\n","multi_pheno = pd.read_csv('phenotypes.csv', sep=',', index_col=0)\n","#use only training data to generate this\n","\n","headers = genotypes.columns[:]\n","phenoName = 
multi_pheno.columns[phenoIndex]\n","print(phenoName)"],"execution_count":6,"outputs":[{"output_type":"stream","name":"stdout","text":["1_Diamide_1\n"]}]},{"cell_type":"code","metadata":{"id":"H2LdnKzXnW0P","colab":{"base_uri":"https://localhost:8080/","height":331},"executionInfo":{"status":"ok","timestamp":1650124320036,"user_tz":240,"elapsed":10,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"aca27025-d98b-4691-f410-2d110104f0ea"},"source":["genotypes.head()"],"execution_count":7,"outputs":[{"output_type":"execute_result","data":{"text/plain":["       33070_chrI_33070_A_T  33147_chrI_33147_G_T  33152_chrI_33152_T_C  \\\n","SAMID                                                                     \n","01_01                     1                     1                     1   \n","01_02                     1                     1                     1   \n","01_03                     0                     0                     0   \n","01_04                     1                     1                     1   \n","01_06                     0                     0                     0   \n","\n","       33200_chrI_33200_C_T  33293_chrI_33293_A_T  33328_chrI_33328_C_A  \\\n","SAMID                                                                     \n","01_01                     1                     1                     1   \n","01_02                     1                     1                     1   \n","01_03                     0                     0                     0   \n","01_04                     1                     1                     1   \n","01_06                     0                     0                     0   \n","\n","       33348_chrI_33348_G_C  33403_chrI_33403_C_T  33502_chrI_33502_A_G  \\\n","SAMID                                                                     \n","01_01                     1                     1                     1   \n","01_02                     1                   
  1                     1   \n","01_03                     0                     0                     0   \n","01_04                     1                     1                     1   \n","01_06                     0                     0                     0   \n","\n","       33548_chrI_33548_A_C  ...  12048853_chrXVI_925593_G_C  \\\n","SAMID                        ...                               \n","01_01                     1  ...                           0   \n","01_02                     1  ...                           0   \n","01_03                     0  ...                           1   \n","01_04                     1  ...                           1   \n","01_06                     0  ...                           0   \n","\n","       12049199_chrXVI_925939_T_C  12049441_chrXVI_926181_C_T  \\\n","SAMID                                                           \n","01_01                           0                           0   \n","01_02                           0                           0   \n","01_03                           1                           1   \n","01_04                           1                           1   \n","01_06                           0                           0   \n","\n","       12050613_chrXVI_927353_T_G  12051167_chrXVI_927907_A_C  \\\n","SAMID                                                           \n","01_01                           0                           0   \n","01_02                           0                           0   \n","01_03                           1                           1   \n","01_04                           1                           1   \n","01_06                           0                           0   \n","\n","       12051240_chrXVI_927980_A_G  12051367_chrXVI_928107_C_T  \\\n","SAMID                                                           \n","01_01                           0                           0   \n","01_02                           0                        
   0   \n","01_03                           1                           1   \n","01_04                           1                           1   \n","01_06                           0                           0   \n","\n","       12052782_chrXVI_929522_C_T  12052988_chrXVI_929728_A_G  \\\n","SAMID                                                           \n","01_01                           0                           0   \n","01_02                           0                           0   \n","01_03                           1                           1   \n","01_04                           1                           1   \n","01_06                           0                           0   \n","\n","       12053130_chrXVI_929870_C_T  \n","SAMID                              \n","01_01                           0  \n","01_02                           0  \n","01_03                           1  \n","01_04                           1  \n","01_06                           0  \n","\n","[5 rows x 28220 columns]"],"text/html":["\n","  <div id=\"df-a2fd737c-35a3-4702-92e4-e83e98f45d3d\">\n","    <div class=\"colab-df-container\">\n","      <div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>33070_chrI_33070_A_T</th>\n","      <th>33147_chrI_33147_G_T</th>\n","      <th>33152_chrI_33152_T_C</th>\n","      <th>33200_chrI_33200_C_T</th>\n","      <th>33293_chrI_33293_A_T</th>\n","      <th>33328_chrI_33328_C_A</th>\n","      <th>33348_chrI_33348_G_C</th>\n","      <th>33403_chrI_33403_C_T</th>\n","      <th>33502_chrI_33502_A_G</th>\n","      <th>33548_chrI_33548_A_C</th>\n","      <th>...</th>\n","      
<th>12048853_chrXVI_925593_G_C</th>\n","      <th>12049199_chrXVI_925939_T_C</th>\n","      <th>12049441_chrXVI_926181_C_T</th>\n","      <th>12050613_chrXVI_927353_T_G</th>\n","      <th>12051167_chrXVI_927907_A_C</th>\n","      <th>12051240_chrXVI_927980_A_G</th>\n","      <th>12051367_chrXVI_928107_C_T</th>\n","      <th>12052782_chrXVI_929522_C_T</th>\n","      <th>12052988_chrXVI_929728_A_G</th>\n","      <th>12053130_chrXVI_929870_C_T</th>\n","    </tr>\n","    <tr>\n","      <th>SAMID</th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","      <th></th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>01_01</th>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>...</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","    </tr>\n","    <tr>\n","      <th>01_02</th>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>...</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","    </tr>\n","    <tr>\n","      <th>01_03</th>\n","   
   <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>...</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>01_04</th>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>...</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","      <td>1</td>\n","    </tr>\n","    <tr>\n","      <th>01_06</th>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>...</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","      <td>0</td>\n","    </tr>\n","  </tbody>\n","</table>\n","<p>5 rows × 28220 columns</p>\n","</div>\n","      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a2fd737c-35a3-4702-92e4-e83e98f45d3d')\"\n","              title=\"Convert this dataframe to an interactive table.\"\n","              style=\"display:none;\">\n","        \n","  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n","       width=\"24px\">\n","    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n","    <path d=\"M18.56 5.44l.94 2.06.94-2.06 
2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n","  </svg>\n","      </button>\n","      \n","  <style>\n","    .colab-df-container {\n","      display:flex;\n","      flex-wrap:wrap;\n","      gap: 12px;\n","    }\n","\n","    .colab-df-convert {\n","      background-color: #E8F0FE;\n","      border: none;\n","      border-radius: 50%;\n","      cursor: pointer;\n","      display: none;\n","      fill: #1967D2;\n","      height: 32px;\n","      padding: 0 0 0 0;\n","      width: 32px;\n","    }\n","\n","    .colab-df-convert:hover {\n","      background-color: #E2EBFA;\n","      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n","      fill: #174EA6;\n","    }\n","\n","    [theme=dark] .colab-df-convert {\n","      background-color: #3B4455;\n","      fill: #D2E3FC;\n","    }\n","\n","    [theme=dark] .colab-df-convert:hover {\n","      background-color: #434B5C;\n","      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n","      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n","      fill: #FFFFFF;\n","    }\n","  </style>\n","\n","      <script>\n","        const buttonEl =\n","          document.querySelector('#df-a2fd737c-35a3-4702-92e4-e83e98f45d3d button.colab-df-convert');\n","        buttonEl.style.display =\n","          google.colab.kernel.accessAllowed ? 
'block' : 'none';\n","\n","        async function convertToInteractive(key) {\n","          const element = document.querySelector('#df-a2fd737c-35a3-4702-92e4-e83e98f45d3d');\n","          const dataTable =\n","            await google.colab.kernel.invokeFunction('convertToInteractive',\n","                                                     [key], {});\n","          if (!dataTable) return;\n","\n","          const docLinkHtml = 'Like what you see? Visit the ' +\n","            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n","            + ' to learn more about interactive tables.';\n","          element.innerHTML = '';\n","          dataTable['output_type'] = 'display_data';\n","          await google.colab.output.renderOutput(dataTable, element);\n","          const docLink = document.createElement('div');\n","          docLink.innerHTML = docLinkHtml;\n","          element.appendChild(docLink);\n","        }\n","      </script>\n","    </div>\n","  </div>\n","  "]},"metadata":{},"execution_count":7}]},{"cell_type":"code","metadata":{"id":"0s9pKob5uO7v","executionInfo":{"status":"ok","timestamp":1650124322707,"user_tz":240,"elapsed":2679,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["y = multi_pheno.iloc[:, phenoIndex]\n","\n","# move the gene loci with NA traits\n","x_reserve = genotypes[y.isna()]\n","x = genotypes[~y.isna()]\n","y = y[~y.isna()]\n","\n","# normlization\n","scaled_Y = (y - y.min()) / (y.max() - y.min())\n","# print(max(scaled_Y), min(scaled_Y))\n","# temp_Y = scaled_Y[~scaled_Y.isna()]\n","# outliers_index = detect_outliers(temp_Y)\n","# # set outliers as NAN\n","# scaled_Y[outliers_index] = np.nan\n","#\n","x = to_categorical(x.to_numpy())\n","x_reserve = to_categorical(x_reserve.to_numpy())\n","# x = x.to_numpy()\n","# x = x.reshape(x.shape[0], x.shape[1], 1, x.shape[2])\n","# y = scaled_Y[~np.isnan(scaled_Y)].to_numpy()\n","y = 
scaled_Y.to_numpy()\n","input_shape = x[0,].shape\n","\n"],"execution_count":8,"outputs":[]},{"cell_type":"code","metadata":{"id":"qndKpCqCTJZ9","executionInfo":{"status":"ok","timestamp":1650124322708,"user_tz":240,"elapsed":4,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["def romanToInt(s):\n","      \"\"\"\n","      :type s: str\n","      :rtype: int\n","      \"\"\"\n","      roman = {'I':1,'V':5,'X':10,'L':50,'C':100,'D':500,'M':1000,'IV':4,'IX':9,'XL':40,'XC':90,'CD':400,'CM':900}\n","      i = 0\n","      num = 0\n","      while i < len(s):\n","         if i+1<len(s) and s[i:i+2] in roman:\n","            num+=roman[s[i:i+2]]\n","            i+=2\n","         else:\n","            #print(i)\n","            num+=roman[s[i]]\n","            i+=1\n","      return num"],"execution_count":9,"outputs":[]},{"cell_type":"code","metadata":{"id":"MGJlpQ1TTPJI","executionInfo":{"status":"ok","timestamp":1650124322912,"user_tz":240,"elapsed":207,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["chromosomes = np.zeros((len(headers)), dtype=np.int32)\n","positions = np.zeros((len(headers)), dtype=np.int32)\n","for i, h in enumerate(headers.to_list()):\n","  split_header = h.split('_')\n","  chromosomes[i] = romanToInt(split_header[1].replace(\"chr\", \"\"))\n","  positions[i] = int(split_header[0])"],"execution_count":10,"outputs":[]},{"cell_type":"code","metadata":{"id":"PHUFPzERT3U4","executionInfo":{"status":"ok","timestamp":1650124322913,"user_tz":240,"elapsed":5,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["chr_and_counts = np.unique(chromosomes, return_counts=True)\n","# First SNP index of each chromosome, ordered like chr_and_counts[0] (sorted ids),\n","# so it stays aligned with the per-chromosome counts even if ids are not exactly 1..K.\n","chromosomeStarts = np.unique(chromosomes, return_index=True)[1].astype(np.int32)\n","# 
chromosomeStarts"],"execution_count":11,"outputs":[]},{"cell_type":"markdown","source":["## Hyperparams"],"metadata":{"id":"cPVt9Q9JxDiu"}},{"cell_type":"code","metadata":{"id":"_pZoO-FvKdr3","executionInfo":{"status":"ok","timestamp":1650124322913,"user_tz":240,"elapsed":4,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["# hyperparameters\n","feature_size = x.shape[1]\n","inChannel = x.shape[2]\n","learning_rate = 0.001\n","weight_decay = 0.0001\n","embed_dim = 128  # Embedding size for each token\n","num_heads = 8  # Number of attention heads\n","ff_dim = 64  # Hidden layer size in feed forward network inside transformer\n","maxlen = x.shape[1]\n","batch_size = 64\n","dropout_rate = 0.25"],"execution_count":12,"outputs":[]},{"cell_type":"code","metadata":{"id":"SlyxYCy96H7V","executionInfo":{"status":"ok","timestamp":1650124322913,"user_tz":240,"elapsed":4,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["def get_dataset(x, y, batch_size, is_training, missing_perc=0.1):\n","  AUTO = tf.data.AUTOTUNE\n","  # indexes = np.arange(x.shape[0])\n","  # x_missing = x[indexes].copy()\n","  # for i in range(x_missing.shape[0]):\n","  #   missing_size = int(missing_perc * x_missing.shape[1])\n","  #   missing_index = np.random.randint(\n","  #       x_missing.shape[1], size=missing_size)\n","  #   # missing loci are encoded as [0, 0, 1]\n","  #   x_missing[i, missing_index, :] = [0, 0, 1]\n","\n","  dataset= tf.data.Dataset.from_tensor_slices((x, y))\n","  # del x_missing, indexes\n","  # Only shuffle and repeat the dataset in training. 
The advantage to have a\n","  # infinite dataset for training is to avoid the potential last partial batch\n","  # in each epoch, so users don't need to think about scaling the gradients\n","  # based on the actual batch size.\n","  if is_training:\n","    dataset = dataset.shuffle(x.shape[0], reshuffle_each_iteration=True)\n","    dataset = dataset.repeat()\n","\n","  # Batch first and prefetch last: prefetching whole batches overlaps input\n","  # preparation with training instead of buffering single examples.\n","  dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_calls=AUTO)\n","  dataset = dataset.prefetch(AUTO)\n","\n","  return dataset"],"execution_count":13,"outputs":[]},{"cell_type":"markdown","source":["## Layers"],"metadata":{"id":"vXjlPeNzxKzJ"}},{"cell_type":"code","metadata":{"id":"QYJr9h1LO1Ev","executionInfo":{"status":"ok","timestamp":1650124323178,"user_tz":240,"elapsed":268,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["class TransformerBlock(layers.Layer):\n","  def __init__(self, embed_dim, num_heads, ff_dim, activation=tf.nn.gelu, rate=0.25):\n","      super(TransformerBlock, self).__init__()\n","      self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate)\n","      self.ffn = tf.keras.Sequential(\n","          [\n","            layers.Dense(ff_dim, activation=activation, \n","                        ),\n","            layers.Dense(embed_dim, \n","                        activation=activation,\n","                        ), ]\n","      )\n","      self.ffn1 = layers.Dense(ff_dim, activation=activation)\n","      self.ffn2 = layers.Dense(embed_dim, activation=activation)\n","      self.dropout1 = layers.Dropout(rate)\n","      self.dropout2 = layers.Dropout(rate)\n","\n","  def call(self, inputs, training):\n","      attn_output = self.att(inputs, inputs)\n","      attn_output = self.dropout1(attn_output, training=training)\n","      out1 = inputs + attn_output\n","      ffn_output = self.ffn(out1)\n","      ffn_output = 
self.dropout2(ffn_output, training=training)\n","      return out1 + ffn_output\n","      # return self.lnorm2(out1 + ffn_output)\n","\n","\n","class NewGenoEmbeddings(layers.Layer):\n","  def __init__(self, embedding_dim, \n","               embeddings_initializer='glorot_uniform',\n","               embeddings_regularizer=None,\n","               activity_regularizer=None,\n","               embeddings_constraint=None):\n","    super(NewGenoEmbeddings, self).__init__()\n","    self.embedding_dim = embedding_dim\n","    self.embeddings_initializer = initializers.get(embeddings_initializer)\n","    self.embeddings_regularizer = regularizers.get(embeddings_regularizer)\n","    self.activity_regularizer = regularizers.get(activity_regularizer)\n","    self.embeddings_constraint = constraints.get(embeddings_constraint)\n","\n","  def build(self, input_shape):\n","    # print(input_shape)\n","    \n","    self.num_of_allels = input_shape[-1]\n","    self.n_snps = input_shape[-2]\n","    self.position_embedding = layers.Embedding(\n","            input_dim=self.n_snps, output_dim=self.embedding_dim\n","        )\n","    # self.projection = layers.Dense(units=self.embedding_dim)\n","    self.embedding = self.add_weight(\n","            shape=(self.num_of_allels, self.embedding_dim),\n","            initializer=self.embeddings_initializer,\n","            trainable=True, name='geno_embeddings',\n","            regularizer=self.embeddings_regularizer,\n","            constraint=self.embeddings_constraint,\n","            experimental_autocast=False\n","        )\n","    self.positions = tf.range(start=0, limit=self.n_snps, delta=1)\n","    # self.matmul_calculator = MyMatmul()\n","    # self.myEinSumLayer = layers.Lambda(lambda x: tf.einsum('ijk,kl->ijl',x[0], x[1]))\n","  def call(self, inputs):\n","    # return self.projection(inputs) + self.position_embedding(positions)\n","    self.immediate_result = tf.einsum('ijk,kl->ijl', inputs, self.embedding)\n","    return 
self.immediate_result + self.position_embedding(self.positions)\n","\n","class WindowBlock(layers.Layer):\n","  def __init__(self, input_length, num_heads, activation=tf.nn.gelu, rate=0.25, downscale_factor=5):\n","    super(WindowBlock, self).__init__()\n","    self.transformer_size = 32#input_length//downscale_factor\n","    # self.embedding = NewGenoEmbeddings(self.transformer_size)\n","    self.conv1 = layers.Conv1D(self.transformer_size, 3, padding='same', activation=activation,\n","                    )\n","    self.dropout1 = layers.Dropout(rate)\n","    # self.mhtransformer = TransformerBlock(64, num_heads, self.transformer_size, activation=activation)\n","    self.mhtransformer = TransformerBlock(self.transformer_size, num_heads, self.transformer_size, activation=activation)\n","    # self.conv2 = layers.Conv1D(self.transformer_size//2, 1, padding='same', activation=activation,\n","    #                 )\n","    self.dropout2 = layers.Dropout(rate)\n","    self.max_pool = layers.MaxPooling1D(downscale_factor)\n","    # self.conv3 = layers.Conv1D(32, 1, padding='same', activation=tf.nn.gelu,\n","    #                 )\n","    self.ffn = tf.keras.Sequential(\n","          [\n","            layers.Dense(self.transformer_size//2, activation=activation, \n","                        ),\n","            layers.Dense(self.transformer_size, \n","                        activation=activation,\n","                        ), ]\n","      )\n","    self.bn1 = layers.BatchNormalization()\n","    self.bn2 = layers.BatchNormalization()\n","    self.bn3 = layers.BatchNormalization()\n","    # self.flatten = layers.Flatten()\n","    # self.dense_discreet = layers.Dense(1, activation=tf.nn.sigmoid)\n","    # self.dense_quantitative = layers.Dense(inChannel, activation=tf.nn.softmax)\n","    # self.concat = layers.Concatenate(axis=1)\n","    # self.reshape = layers.Reshape([-1, inChannel + 1, 1])\n","    pass\n","\n","  def call(self, inputs, training):\n","    # x = 
self.conv1(inputs)\n","    # x = self.dropout1(x, training=training)\n","    # emb = self.embedding(inputs)\n","    emb = self.conv1(inputs)\n","    x = self.mhtransformer(emb, training=training)\n","    x = self.dropout1(x, training=training)\n","    # y = self.conv1(emb)\n","    # x = self.conv2(x + y)\n","    x = self.bn1(x + emb)\n","\n","    x = self.max_pool(x)\n","    \n","    y = self.ffn(x)\n","    y = self.bn2(y)\n","    y = self.dropout2(y, training=training)\n","    \n","    # x = self.conv3(x)\n","    return self.bn3(y+x)\n","    # x = self.flatten(x)\n","    # y = self.dense_discreet(x)\n","    # z = self.dense_quantitative(x)\n","    # x = self.concat([y, z])\n","    # return self.reshape(x)\n","    # return self.dense(x)\n","\n","\n","class ChromosomeBlock(layers.Layer):\n","  def __init__(self, embed_dim, num_heads, chrm_snp_count, activation=tf.nn.gelu, rate=0.25, window_size=500, downscale_factor=5):\n","    super(ChromosomeBlock, self).__init__()\n","    # self.mTransformer_size = window_size//2\n","    self.additional_sliding_length = window_size//downscale_factor\n","    # self.embedding = GenoEmbeddings(embed_dim)\n","    # self.conv0 = layers.Conv1D(self.mTransformer_size, 5, padding='same', activation=activation,\n","    #                 )\n","    # self.do_0 = layers.Dropout(rate)\n","    window_lengthz = []\n","    self.window_bounds = []\n","    for block_start in range(0, chrm_snp_count, window_size):\n","      # NOTE(review): the previous max(block_start, block_start - self.additional_sliding_length)\n","      # always equals block_start, so windows only overlap to the right (via block_end below).\n","      # If a left overlap was intended, max(0, block_start - self.additional_sliding_length)\n","      # would provide it, but that changes window lengths and therefore the model.\n","      block_true_start = block_start\n","      block_end = min(block_start + window_size + self.additional_sliding_length, chrm_snp_count)\n","      shall_extend_to_end = chrm_snp_count - block_end < window_size\n","      block_end = block_end if not shall_extend_to_end else chrm_snp_count\n","      wl = block_end - block_true_start\n","      window_lengthz.append(wl)\n","      self.window_bounds.append((block_true_start, block_end))\n","      if shall_extend_to_end:\n","        break\n","    self.window_blocks = 
[WindowBlock(window_length, num_heads) for ind, window_length in enumerate(window_lengthz)]\n","    self.concatenate = layers.Concatenate(axis=1)\n","    # transformer_size = sum([window_length//10 for window_length in window_lengthz])\n","    transformer_size = 32\n","    # print(transformer_size)\n","    # self.reshaper = layers.Reshape((-1, 1))\n","    self.do_1 = layers.Dropout(rate)\n","    self.do_2 = layers.Dropout(rate)\n","    # self.conv_before = layers.Conv1D(transformer_size, 1, padding='same', activation=tf.nn.gelu,\n","    #                 )\n","    self.mhtransformer = TransformerBlock(transformer_size, num_heads, transformer_size//2, activation=tf.nn.gelu)\n","    self.ffn = tf.keras.Sequential(\n","          [\n","            layers.Dense(embed_dim//4, activation=activation, \n","                        ),\n","            layers.Dense(embed_dim//4, \n","                        activation=activation,\n","                        ), ]\n","      )\n","    self.bn1 = layers.BatchNormalization()\n","    self.bn2 = layers.BatchNormalization()\n","    self.bn3 = layers.BatchNormalization()\n","    self.bn4 = layers.BatchNormalization()\n","    # self.conv_final = layers.Conv1D(transformer_size, 1, padding='same', activation=tf.nn.gelu,\n","    #                 )\n","    # self.dense_final = layers.Dense(transformer_size, activation=tf.nn.sigmoid)\n","    pass\n","\n","  def call(self, inputs, training):\n","    # x = self.embedding(inputs)\n","    # x = self.conv0(inputs)\n","    # x = self.do_0(x, training=training)\n","    x = self.concatenate([self.window_blocks[i](inputs[:, block_start:block_end, :]) for i, (block_start, block_end) in enumerate(self.window_bounds)])\n","    # print(x.shape)\n","    # y = self.conv_before(x)\n","    # x = self.reshaper(x)\n","    # x = self.conv_before(x)\n","    y = self.mhtransformer(x)\n","    y = self.bn1(y)\n","    y = self.do_1(y, training=training)\n","    x = self.bn2(x+y)\n","    y = self.ffn(x)\n","    y = 
self.bn3(y)\n","    y = self.do_2(y, training=training)\n","    return self.bn4(x+y)\n","    # return self.conv_final(x+y)\n","    # return self.dense_final(x)\n","\n"],"execution_count":14,"outputs":[]},{"cell_type":"code","metadata":{"id":"NSthuu25pO--","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650124323347,"user_tz":240,"elapsed":175,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"6e7e311f-4dc8-440c-b8ed-d64ac085e0b7"},"source":["# inputs = tf.constant(to_categorical([[0, 1],[2, 2]], dtype='int32'))\n","inputs = layers.Input(shape=(28220, 2))\n","print(inputs.shape)\n","embedding = NewGenoEmbeddings(embedding_dim=5)\n","print(embedding(inputs).shape)\n","# embedding, tf.argmax(embedding, axis=2)\n","# print(inputs(axis=2))\n","\n","# embedding[inputs.argmax(axis=1)]"],"execution_count":15,"outputs":[{"output_type":"stream","name":"stdout","text":["(None, 28220, 2)\n","(None, 28220, 5)\n"]}]},{"cell_type":"code","metadata":{"id":"tizbPIvmi0hy","executionInfo":{"status":"ok","timestamp":1650124323348,"user_tz":240,"elapsed":6,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["chrsf = []\n","chr_ends = []\n","for i, chr_start in enumerate(chromosomeStarts):\n","  chr_ends.append(chromosomeStarts[i+1] if i < len(chromosomeStarts) - 1 else feature_size)"],"execution_count":16,"outputs":[]},{"cell_type":"code","metadata":{"id":"vgGz9_6u8u-O","executionInfo":{"status":"ok","timestamp":1650124323348,"user_tz":240,"elapsed":5,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["def create_model():\n","  inputt = layers.Input(shape=(feature_size, inChannel))\n","  # xa = NewGenoEmbeddings(embedding_dim=embed_dim)(inputt)\n","  # xa = layers.Dropout(dropout_rate)(xa)\n","\n","  \n","\n","  xa = layers.concatenate([ChromosomeBlock(embed_dim, num_heads, chr_and_counts[1][i])(inputt[:, 
chr_start:chr_ends[i], :]) for i, chr_start in enumerate(chromosomeStarts)],axis=1)\n","  xa = layers.Conv1D(embed_dim//2, 5, padding='same', activation=tf.nn.gelu,\n","                    # kernel_regularizer=regularizers.l1(regularization_coef_l1),\n","                    )(xa)\n","  # xa = layers.AveragePooling1D(2)(xa)\n","  xa = layers.BatchNormalization(epsilon=2e-5, momentum=9e-1)(xa)\n","  xa = layers.DepthwiseConv1D(embed_dim, 1, padding='same')(xa)\n","  # xa = layers.BatchNormalization(epsilon=2e-5, momentum=9e-1)(xa)\n","  xa = layers.Activation(tf.nn.gelu)(xa)\n","  xa = layers.BatchNormalization(epsilon=2e-5, momentum=9e-1)(xa)\n","  xa = layers.DepthwiseConv1D(embed_dim, 1, padding='same')(xa)\n","  # xa = layers.BatchNormalization(epsilon=2e-5, momentum=9e-1)(xa)\n","  xa = layers.Activation(tf.nn.gelu)(xa)\n","  # xb = layers.Dense(xa.shape[2], activation=tf.nn.gelu)(xa)\n","  # xa = xa + xb\n","  xa = layers.Dropout(dropout_rate)(xa)\n","  # xa = layers.Dropout(dropout_rate)(xa)\n","  # xa = layers.Conv1D(embed_dim//4, 5, strides=2, padding='same', activation=tf.nn.gelu,\n","  #                   # kernel_regularizer=regularizers.l1(regularization_coef_l1),\n","  #                   )(xa)\n","  # xa = layers.Dropout(dropout_rate)(xa)\n","  # output1 = layers.Conv1D(inChannel, 5, padding='same', activation=tf.nn.softmax,\n","  #                   # kernel_regularizer=regularizers.l1(regularization_coef_l1),\n","  #                   )(xa)\n","  xa = layers.Flatten()(xa)\n","  xa = layers.Dense(64, activation=tf.nn.gelu)(xa)\n","  xa = layers.Dropout(dropout_rate)(xa)\n","  output1 = layers.Dense(1, activation=\"sigmoid\")(xa)\n","\n","  # output1 = layers.Dense(inChannel, activation='softmax')(xa)\n","\n","  classifier= tf.keras.Model(inputs=inputt, outputs=output1)\n","  # classifier= keras.Model(inputs=inputt, outputs=[output0, output1])\n","  # classifier.summary()\n","\n","  return 
classifier"],"execution_count":17,"outputs":[]},{"cell_type":"code","metadata":{"id":"Mv32H3-y75Fo","executionInfo":{"status":"ok","timestamp":1650124323349,"user_tz":240,"elapsed":5,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["def get_three_sets(x, batch_size, random_state, missing_perc):\n","  np.random.seed(seed=random_state)\n","\n","  x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.20,\n","                                        random_state=random_state,\n","                                        shuffle=True,\n","                                                      )\n","  x_train, x_valid, y_train, y_valid= train_test_split(x_train, y_train, test_size=0.19,\n","                                              random_state=random_state,\n","                                              shuffle=True,\n","                                                      )\n","  steps_per_epoch = x_train.shape[0] // batch_size\n","  validation_steps = x_valid.shape[0] // batch_size\n","  print(f\"x_train percentage: {x_train.shape[0]/x.shape[0]}\")\n","  print(f\"x_valid percentage: {x_valid.shape[0]/x.shape[0]}\")\n","  print(f\"x_test percentage: {x_test.shape[0]/x.shape[0]}\")\n","  return get_dataset(x_train, y_train, batch_size, is_training=True, missing_perc=missing_perc), \\\n","        get_dataset(x_valid, y_valid, batch_size, is_training=False, missing_perc=missing_perc), \\\n","        (x_test, y_test), \\\n","        steps_per_epoch, validation_steps\n"],"execution_count":18,"outputs":[]},{"cell_type":"code","metadata":{"id":"_Prw2xApHKP9","executionInfo":{"status":"ok","timestamp":1650082912882,"user_tz":240,"elapsed":5,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}}},"source":["# tf.keras.utils.plot_model(create_model(), 
show_shapes=True)"],"execution_count":19,"outputs":[]},{"cell_type":"code","metadata":{"id":"6hSe1iJFICcR","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1650082931333,"user_tz":240,"elapsed":18455,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"37aef49b-7f09-4ebb-b2bb-0633c87f10f6"},"source":["create_model().summary()"],"execution_count":20,"outputs":[{"output_type":"stream","name":"stdout","text":["Model: \"model\"\n","__________________________________________________________________________________________________\n"," Layer (type)                   Output Shape         Param #     Connected to                     \n","==================================================================================================\n"," input_2 (InputLayer)           [(None, 28220, 2)]   0           []                               \n","                                                                                                  \n"," tf.__operators__.getitem (Slic  (None, 647, 2)      0           ['input_2[0][0]']                \n"," ingOpLambda)                                                                                     \n","                                                                                                  \n"," tf.__operators__.getitem_1 (Sl  (None, 2419, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_2 (Sl  (None, 511, 2)      0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_3 (Sl  (None, 
2780, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_4 (Sl  (None, 1237, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_5 (Sl  (None, 904, 2)      0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_6 (Sl  (None, 2348, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_7 (Sl  (None, 1096, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_8 (Sl  (None, 1287, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                                   \n","                                                                                                  \n"," tf.__operators__.getitem_9 (Sl  (None, 2015, 2)     0           ['input_2[0][0]']                \n"," icingOpLambda)                                                                    
               \n","                                                                                                  \n"," tf.__operators__.getitem_10 (S  (None, 2370, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," tf.__operators__.getitem_11 (S  (None, 2951, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," tf.__operators__.getitem_12 (S  (None, 2057, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," tf.__operators__.getitem_13 (S  (None, 1391, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," tf.__operators__.getitem_14 (S  (None, 2694, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," tf.__operators__.getitem_15 (S  (None, 1513, 2)     0           ['input_2[0][0]']                \n"," licingOpLambda)                                                                                  \n","                                                                                                  \n"," chromosome_block 
(ChromosomeBl  (None, 129, 32)     74624       ['tf.__operators__.getitem[0][0]'\n"," ock)                                                            ]                                \n","                                                                                                  \n"," chromosome_block_1 (Chromosome  (None, 543, 32)     186704      ['tf.__operators__.getitem_1[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_2 (Chromosome  (None, 102, 32)     74624       ['tf.__operators__.getitem_2[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_3 (Chromosome  (None, 636, 32)     224064      ['tf.__operators__.getitem_3[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_4 (Chromosome  (None, 267, 32)     111984      ['tf.__operators__.getitem_4[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_5 (Chromosome  (None, 180, 32)     74624       ['tf.__operators__.getitem_5[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_6 (Chromosome  (None, 529, 32)     186704      ['tf.__operators__.getitem_6[0][0\n"," Block)                                                      
    ]']                              \n","                                                                                                  \n"," chromosome_block_7 (Chromosome  (None, 219, 32)     74624       ['tf.__operators__.getitem_7[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_8 (Chromosome  (None, 277, 32)     111984      ['tf.__operators__.getitem_8[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_9 (Chromosome  (None, 443, 32)     149344      ['tf.__operators__.getitem_9[0][0\n"," Block)                                                          ]']                              \n","                                                                                                  \n"," chromosome_block_10 (Chromosom  (None, 534, 32)     186704      ['tf.__operators__.getitem_10[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," chromosome_block_11 (Chromosom  (None, 670, 32)     224064      ['tf.__operators__.getitem_11[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," chromosome_block_12 (Chromosom  (None, 451, 32)     149344      ['tf.__operators__.getitem_12[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," 
chromosome_block_13 (Chromosom  (None, 298, 32)     111984      ['tf.__operators__.getitem_13[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," chromosome_block_14 (Chromosom  (None, 618, 32)     224064      ['tf.__operators__.getitem_14[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," chromosome_block_15 (Chromosom  (None, 322, 32)     111984      ['tf.__operators__.getitem_15[0][\n"," eBlock)                                                         0]']                             \n","                                                                                                  \n"," concatenate_16 (Concatenate)   (None, 6218, 32)     0           ['chromosome_block[0][0]',       \n","                                                                  'chromosome_block_1[0][0]',     \n","                                                                  'chromosome_block_2[0][0]',     \n","                                                                  'chromosome_block_3[0][0]',     \n","                                                                  'chromosome_block_4[0][0]',     \n","                                                                  'chromosome_block_5[0][0]',     \n","                                                                  'chromosome_block_6[0][0]',     \n","                                                                  'chromosome_block_7[0][0]',     \n","                                                                  'chromosome_block_8[0][0]',     \n","                                                                  'chromosome_block_9[0][0]',     \n","                                            
                      'chromosome_block_10[0][0]',    \n","                                                                  'chromosome_block_11[0][0]',    \n","                                                                  'chromosome_block_12[0][0]',    \n","                                                                  'chromosome_block_13[0][0]',    \n","                                                                  'chromosome_block_14[0][0]',    \n","                                                                  'chromosome_block_15[0][0]']    \n","                                                                                                  \n"," conv1d_45 (Conv1D)             (None, 6218, 64)     10304       ['concatenate_16[0][0]']         \n","                                                                                                  \n"," batch_normalization_199 (Batch  (None, 6218, 64)    256         ['conv1d_45[0][0]']              \n"," Normalization)                                                                                   \n","                                                                                                  \n"," depthwise_conv1d (DepthwiseCon  (None, 6218, 64)    8256        ['batch_normalization_199[0][0]']\n"," v1D)                                                                                             \n","                                                                                                  \n"," activation (Activation)        (None, 6218, 64)     0           ['depthwise_conv1d[0][0]']       \n","                                                                                                  \n"," batch_normalization_200 (Batch  (None, 6218, 64)    256         ['activation[0][0]']             \n"," Normalization)                                                                                   \n","                                                                                       
           \n"," depthwise_conv1d_1 (DepthwiseC  (None, 6218, 64)    8256        ['batch_normalization_200[0][0]']\n"," onv1D)                                                                                           \n","                                                                                                  \n"," activation_1 (Activation)      (None, 6218, 64)     0           ['depthwise_conv1d_1[0][0]']     \n","                                                                                                  \n"," dropout_244 (Dropout)          (None, 6218, 64)     0           ['activation_1[0][0]']           \n","                                                                                                  \n"," flatten (Flatten)              (None, 397952)       0           ['dropout_244[0][0]']            \n","                                                                                                  \n"," dense_366 (Dense)              (None, 64)           25468992    ['flatten[0][0]']                \n","                                                                                                  \n"," dropout_245 (Dropout)          (None, 64)           0           ['dense_366[0][0]']              \n","                                                                                                  \n"," dense_367 (Dense)              (None, 1)            65          ['dropout_245[0][0]']            \n","                                                                                                  \n","==================================================================================================\n","Total params: 27,773,809\n","Trainable params: 27,760,817\n","Non-trainable params: 12,992\n","__________________________________________________________________________________________________\n"]}]},{"cell_type":"code","source":["def cal_prob(predict_missing_onehot):\n","    # calcaulate the probility of genotype 0, 1, 2\n","    
predict_prob = predict_missing_onehot[:,:,:2] / predict_missing_onehot[:,:,:2].sum(axis=2, keepdims=True)\n","    return predict_prob[0]\n","# cal_prob(x_test[:10])\n","\n","# tf.keras.utils.plot_model(create_model(), show_shapes=True)\n","missing_perc = 0.05\n","\n","accuracies = []\n","for random_state in range(8, 10):\n","  print(f\"Training using seed {random_state}:\")\n","  train_dataset, valid_dataset, test_dataset, steps_per_epoch, validation_steps = get_three_sets(x, batch_size, random_state, missing_perc)\n","\n","  # Create a learning rate scheduler callback.\n","  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n","      monitor=\"val_loss\", factor=0.8, patience=3\n","  )\n","\n","  # Create an early stopping callback.\n","  early_stopping = tf.keras.callbacks.EarlyStopping(\n","      monitor=\"val_loss\", patience=10, restore_best_weights=True\n","  )\n","\n","  num_epochs = 100\n","\n","  K.clear_session()\n","  with strategy.scope():\n","    \n","    model = create_model()\n","    # tf.keras.utils.plot_model(model, show_shapes=True)\n","    optimizer = tfa.optimizers.LAMB(\n","          learning_rate=learning_rate,\n","          # weight_decay_rate=weight_decay,\n","      )\n","\n","    model.compile(optimizer, loss='mse')\n","    # model.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')\n","\n","    history = model.fit(\n","        train_dataset,\n","        steps_per_epoch=steps_per_epoch,\n","        validation_data=valid_dataset,\n","        validation_steps=validation_steps,\n","        epochs=num_epochs,\n","        verbose=2,\n","        callbacks=[early_stopping, reduce_lr]\n","    )\n","    predict_data = model.predict(test_dataset[0])\n","    print(\"Test MSE:\", mean_squared_error(test_dataset[1], predict_data))\n","\n","    
print(\"==========================================================\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"enzaO4cfCeb6","executionInfo":{"status":"ok","timestamp":1650127531387,"user_tz":240,"elapsed":3208043,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"f5f4f39e-b01f-4e0d-8397-05f2a8352fd6"},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["Training using seed 8:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 855s - loss: 0.0181 - val_loss: 0.0206 - lr: 0.0010 - 855s/epoch - 20s/step\n","Epoch 2/100\n","43/43 - 35s - loss: 0.0116 - val_loss: 0.0141 - lr: 0.0010 - 35s/epoch - 814ms/step\n","Epoch 3/100\n","43/43 - 35s - loss: 0.0098 - val_loss: 0.0128 - lr: 0.0010 - 35s/epoch - 817ms/step\n","Epoch 4/100\n","43/43 - 35s - loss: 0.0088 - val_loss: 0.0123 - lr: 0.0010 - 35s/epoch - 825ms/step\n","Epoch 5/100\n","43/43 - 35s - loss: 0.0078 - val_loss: 0.0110 - lr: 0.0010 - 35s/epoch - 812ms/step\n","Epoch 6/100\n","43/43 - 36s - loss: 0.0069 - val_loss: 0.0106 - lr: 0.0010 - 36s/epoch - 830ms/step\n","Epoch 7/100\n","43/43 - 35s - loss: 0.0064 - val_loss: 0.0099 - lr: 0.0010 - 35s/epoch - 814ms/step\n","Epoch 8/100\n","43/43 - 28s - loss: 0.0054 - val_loss: 0.0099 - lr: 0.0010 - 28s/epoch - 660ms/step\n","Epoch 9/100\n","43/43 - 28s - loss: 0.0050 - val_loss: 0.0101 - lr: 0.0010 - 28s/epoch - 660ms/step\n","Epoch 10/100\n","43/43 - 29s - loss: 0.0046 - val_loss: 0.0102 - lr: 0.0010 - 29s/epoch - 664ms/step\n","Epoch 11/100\n","43/43 - 29s - loss: 0.0039 - val_loss: 0.0104 - lr: 8.0000e-04 - 29s/epoch - 663ms/step\n","Epoch 12/100\n","43/43 - 29s - loss: 0.0034 - val_loss: 0.0103 - lr: 8.0000e-04 - 29s/epoch - 665ms/step\n","Epoch 13/100\n","43/43 - 29s - loss: 0.0031 - val_loss: 0.0103 - lr: 8.0000e-04 - 29s/epoch - 664ms/step\n","Epoch 
14/100\n","43/43 - 29s - loss: 0.0027 - val_loss: 0.0107 - lr: 6.4000e-04 - 29s/epoch - 663ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0025 - val_loss: 0.0110 - lr: 6.4000e-04 - 28s/epoch - 660ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0023 - val_loss: 0.0112 - lr: 6.4000e-04 - 29s/epoch - 669ms/step\n","Epoch 17/100\n","43/43 - 58s - loss: 0.0022 - val_loss: 0.0110 - lr: 5.1200e-04 - 58s/epoch - 1s/step\n","Test MSE: 0.009809045596774178\n","==========================================================\n","Training using seed 9:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 890s - loss: 0.0174 - val_loss: 0.0206 - lr: 0.0010 - 890s/epoch - 21s/step\n","Epoch 2/100\n","43/43 - 35s - loss: 0.0110 - val_loss: 0.0149 - lr: 0.0010 - 35s/epoch - 814ms/step\n","Epoch 3/100\n","43/43 - 35s - loss: 0.0098 - val_loss: 0.0131 - lr: 0.0010 - 35s/epoch - 814ms/step\n","Epoch 4/100\n","43/43 - 35s - loss: 0.0087 - val_loss: 0.0119 - lr: 0.0010 - 35s/epoch - 813ms/step\n","Epoch 5/100\n","43/43 - 35s - loss: 0.0077 - val_loss: 0.0119 - lr: 0.0010 - 35s/epoch - 813ms/step\n","Epoch 6/100\n","43/43 - 28s - loss: 0.0068 - val_loss: 0.0125 - lr: 0.0010 - 28s/epoch - 663ms/step\n","Epoch 7/100\n","43/43 - 29s - loss: 0.0062 - val_loss: 0.0122 - lr: 0.0010 - 29s/epoch - 667ms/step\n","Epoch 8/100\n","43/43 - 35s - loss: 0.0052 - val_loss: 0.0114 - lr: 8.0000e-04 - 35s/epoch - 817ms/step\n","Epoch 9/100\n","43/43 - 35s - loss: 0.0046 - val_loss: 0.0109 - lr: 8.0000e-04 - 35s/epoch - 816ms/step\n","Epoch 10/100\n","43/43 - 35s - loss: 0.0043 - val_loss: 0.0103 - lr: 8.0000e-04 - 35s/epoch - 817ms/step\n","Epoch 11/100\n","43/43 - 28s - loss: 0.0038 - val_loss: 0.0117 - lr: 8.0000e-04 - 28s/epoch - 662ms/step\n","Epoch 12/100\n","43/43 - 28s - loss: 0.0035 - val_loss: 0.0116 - lr: 8.0000e-04 - 28s/epoch - 660ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0033 - 
val_loss: 0.0113 - lr: 8.0000e-04 - 28s/epoch - 660ms/step\n","Epoch 14/100\n","43/43 - 28s - loss: 0.0028 - val_loss: 0.0123 - lr: 6.4000e-04 - 28s/epoch - 663ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0025 - val_loss: 0.0123 - lr: 6.4000e-04 - 28s/epoch - 662ms/step\n","Epoch 16/100\n","43/43 - 28s - loss: 0.0024 - val_loss: 0.0121 - lr: 6.4000e-04 - 28s/epoch - 662ms/step\n","Epoch 17/100\n","43/43 - 28s - loss: 0.0023 - val_loss: 0.0128 - lr: 5.1200e-04 - 28s/epoch - 662ms/step\n","Epoch 18/100\n","43/43 - 28s - loss: 0.0020 - val_loss: 0.0127 - lr: 5.1200e-04 - 28s/epoch - 661ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0128 - lr: 5.1200e-04 - 29s/epoch - 663ms/step\n","Epoch 20/100\n","43/43 - 56s - loss: 0.0018 - val_loss: 0.0121 - lr: 4.0960e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.010515812599048383\n","==========================================================\n"]}]},{"cell_type":"code","source":["def cal_prob(predict_missing_onehot):\n","    # calcaulate the probility of genotype 0, 1, 2\n","    predict_prob = predict_missing_onehot[:,:,:2] / predict_missing_onehot[:,:,:2].sum(axis=2, keepdims=True)\n","    return predict_prob[0]\n","# cal_prob(x_test[:10])\n","\n","# tf.keras.utils.plot_model(create_model(), show_shapes=True)\n","missing_perc = 0.05\n","\n","accuracies = []\n","for random_state in range(6, 10):\n","  print(f\"Training using seed {random_state}:\")\n","  train_dataset, valid_dataset, test_dataset, steps_per_epoch, validation_steps = get_three_sets(x, batch_size, random_state, missing_perc)\n","\n","  # Create a learning rate scheduler callback.\n","  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n","      monitor=\"val_loss\", factor=0.8, patience=3\n","  )\n","\n","  # Create an early stopping callback.\n","  early_stopping = tf.keras.callbacks.EarlyStopping(\n","      monitor=\"val_loss\", patience=10, restore_best_weights=True\n","  )\n","\n","  num_epochs = 100\n","\n","  K.clear_session()\n","  
with strategy.scope():\n","    \n","    model = create_model()\n","    # tf.keras.utils.plot_model(model, show_shapes=True)\n","    optimizer = tfa.optimizers.LAMB(\n","          learning_rate=learning_rate,\n","          # weight_decay_rate=weight_decay,\n","      )\n","\n","    model.compile(optimizer, loss='mse')\n","    # model.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')\n","\n","    history = model.fit(\n","        train_dataset,\n","        steps_per_epoch=steps_per_epoch,\n","        validation_data=valid_dataset,\n","        validation_steps=validation_steps,\n","        epochs=num_epochs,\n","        verbose=2,\n","        callbacks=[early_stopping, reduce_lr]\n","    )\n","    predict_data = model.predict(test_dataset[0])\n","    print(\"Test MSE:\", mean_squared_error(test_dataset[1], predict_data))\n","\n","    print(\"==========================================================\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"TfPdHNfpifA1","executionInfo":{"status":"error","timestamp":1650110168564,"user_tz":240,"elapsed":4137333,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"4ce2ca90-26d8-4b50-c6e8-1c29ad6d641a"},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["Training using seed 6:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 857s - loss: 0.0177 - val_loss: 0.0220 - lr: 0.0010 - 857s/epoch - 20s/step\n","Epoch 2/100\n","43/43 - 34s - loss: 0.0117 - val_loss: 0.0159 - lr: 0.0010 - 34s/epoch - 801ms/step\n","Epoch 3/100\n","43/43 - 35s - loss: 0.0095 - val_loss: 0.0142 - lr: 0.0010 - 35s/epoch - 804ms/step\n","Epoch 4/100\n","43/43 - 35s - loss: 0.0088 - val_loss: 0.0131 - lr: 0.0010 - 35s/epoch - 807ms/step\n","Epoch 5/100\n","43/43 - 35s - loss: 0.0082 - val_loss: 0.0120 - lr: 0.0010 - 
35s/epoch - 808ms/step\n","Epoch 6/100\n","43/43 - 35s - loss: 0.0069 - val_loss: 0.0113 - lr: 0.0010 - 35s/epoch - 805ms/step\n","Epoch 7/100\n","43/43 - 35s - loss: 0.0068 - val_loss: 0.0113 - lr: 0.0010 - 35s/epoch - 807ms/step\n","Epoch 8/100\n","43/43 - 35s - loss: 0.0057 - val_loss: 0.0108 - lr: 0.0010 - 35s/epoch - 810ms/step\n","Epoch 9/100\n","43/43 - 35s - loss: 0.0054 - val_loss: 0.0105 - lr: 0.0010 - 35s/epoch - 806ms/step\n","Epoch 10/100\n","43/43 - 35s - loss: 0.0046 - val_loss: 0.0103 - lr: 0.0010 - 35s/epoch - 808ms/step\n","Epoch 11/100\n","43/43 - 35s - loss: 0.0041 - val_loss: 0.0099 - lr: 0.0010 - 35s/epoch - 812ms/step\n","Epoch 12/100\n","43/43 - 28s - loss: 0.0036 - val_loss: 0.0104 - lr: 0.0010 - 28s/epoch - 663ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0034 - val_loss: 0.0102 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 14/100\n","43/43 - 29s - loss: 0.0030 - val_loss: 0.0101 - lr: 0.0010 - 29s/epoch - 664ms/step\n","Epoch 15/100\n","43/43 - 29s - loss: 0.0026 - val_loss: 0.0103 - lr: 8.0000e-04 - 29s/epoch - 664ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0024 - val_loss: 0.0101 - lr: 8.0000e-04 - 29s/epoch - 663ms/step\n","Epoch 17/100\n","43/43 - 29s - loss: 0.0021 - val_loss: 0.0108 - lr: 8.0000e-04 - 29s/epoch - 664ms/step\n","Epoch 18/100\n","43/43 - 29s - loss: 0.0021 - val_loss: 0.0105 - lr: 6.4000e-04 - 29s/epoch - 664ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0018 - val_loss: 0.0102 - lr: 6.4000e-04 - 29s/epoch - 663ms/step\n","Epoch 20/100\n","43/43 - 29s - loss: 0.0018 - val_loss: 0.0102 - lr: 6.4000e-04 - 29s/epoch - 663ms/step\n","Epoch 21/100\n","43/43 - 56s - loss: 0.0016 - val_loss: 0.0103 - lr: 5.1200e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.010125920863930908\n","==========================================================\n","Training using seed 7:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 
1/100\n","43/43 - 875s - loss: 0.0174 - val_loss: 0.0219 - lr: 0.0010 - 875s/epoch - 20s/step\n","Epoch 2/100\n","43/43 - 35s - loss: 0.0111 - val_loss: 0.0150 - lr: 0.0010 - 35s/epoch - 805ms/step\n","Epoch 3/100\n","43/43 - 35s - loss: 0.0099 - val_loss: 0.0139 - lr: 0.0010 - 35s/epoch - 808ms/step\n","Epoch 4/100\n","43/43 - 35s - loss: 0.0084 - val_loss: 0.0133 - lr: 0.0010 - 35s/epoch - 807ms/step\n","Epoch 5/100\n","43/43 - 28s - loss: 0.0078 - val_loss: 0.0147 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 6/100\n","43/43 - 28s - loss: 0.0071 - val_loss: 0.0140 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 7/100\n","43/43 - 35s - loss: 0.0063 - val_loss: 0.0131 - lr: 0.0010 - 35s/epoch - 813ms/step\n","Epoch 8/100\n","43/43 - 28s - loss: 0.0058 - val_loss: 0.0139 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 9/100\n","43/43 - 35s - loss: 0.0051 - val_loss: 0.0110 - lr: 0.0010 - 35s/epoch - 808ms/step\n","Epoch 10/100\n","43/43 - 35s - loss: 0.0049 - val_loss: 0.0105 - lr: 0.0010 - 35s/epoch - 807ms/step\n","Epoch 11/100\n","43/43 - 28s - loss: 0.0041 - val_loss: 0.0119 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 12/100\n","43/43 - 29s - loss: 0.0036 - val_loss: 0.0123 - lr: 0.0010 - 29s/epoch - 663ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0033 - val_loss: 0.0134 - lr: 0.0010 - 28s/epoch - 661ms/step\n","Epoch 14/100\n","43/43 - 29s - loss: 0.0031 - val_loss: 0.0116 - lr: 8.0000e-04 - 29s/epoch - 667ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0027 - val_loss: 0.0122 - lr: 8.0000e-04 - 28s/epoch - 663ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0024 - val_loss: 0.0124 - lr: 8.0000e-04 - 29s/epoch - 663ms/step\n","Epoch 17/100\n","43/43 - 29s - loss: 0.0021 - val_loss: 0.0132 - lr: 6.4000e-04 - 29s/epoch - 663ms/step\n","Epoch 18/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0127 - lr: 6.4000e-04 - 29s/epoch - 663ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0124 - lr: 6.4000e-04 - 29s/epoch - 
665ms/step\n","Epoch 20/100\n","43/43 - 56s - loss: 0.0018 - val_loss: 0.0118 - lr: 5.1200e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.010781640534057555\n","==========================================================\n","Training using seed 8:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n"]},{"output_type":"error","ename":"ResourceExhaustedError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)","\u001b[0;32m<ipython-input-19-ac0d48ead2cf>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m         \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mearly_stopping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_lr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m     )\n\u001b[1;32m     49\u001b[0m     \u001b[0mpredict_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     
65\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m       \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     68\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m       \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1189\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_numpy_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1190\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m       \u001b[0;32mraise\u001b[0m 
\u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1193\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mResourceExhaustedError\u001b[0m: 9 root error(s) found.\n  (0) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_6_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/control_after/_1/_83]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_6_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_3_0/_63]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. 
This isn't available when running in Eager mode.\n\n  (2) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_6_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (3) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_5_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (4) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_2_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (5) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1816221}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{ ... 
[truncated]"]}]},{"cell_type":"code","source":["def cal_prob(predict_missing_onehot):\n","    # calcaulate the probility of genotype 0, 1, 2\n","    predict_prob = predict_missing_onehot[:,:,:2] / predict_missing_onehot[:,:,:2].sum(axis=2, keepdims=True)\n","    return predict_prob[0]\n","# cal_prob(x_test[:10])\n","\n","# tf.keras.utils.plot_model(create_model(), show_shapes=True)\n","missing_perc = 0.05\n","\n","accuracies = []\n","for random_state in range(4, 10):\n","  print(f\"Training using seed {random_state}:\")\n","  train_dataset, valid_dataset, test_dataset, steps_per_epoch, validation_steps = get_three_sets(x, batch_size, random_state, missing_perc)\n","\n","  # Create a learning rate scheduler callback.\n","  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n","      monitor=\"val_loss\", factor=0.8, patience=3\n","  )\n","\n","  # Create an early stopping callback.\n","  early_stopping = tf.keras.callbacks.EarlyStopping(\n","      monitor=\"val_loss\", patience=10, restore_best_weights=True\n","  )\n","\n","  num_epochs = 100\n","\n","  K.clear_session()\n","  with strategy.scope():\n","    \n","    model = create_model()\n","    # tf.keras.utils.plot_model(model, show_shapes=True)\n","    optimizer = tfa.optimizers.LAMB(\n","          learning_rate=learning_rate,\n","          # weight_decay_rate=weight_decay,\n","      )\n","\n","    model.compile(optimizer, loss='mse')\n","    # model.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')\n","\n","    history = model.fit(\n","        train_dataset,\n","        steps_per_epoch=steps_per_epoch,\n","        validation_data=valid_dataset,\n","        validation_steps=validation_steps,\n","        epochs=num_epochs,\n","        verbose=2,\n","        callbacks=[early_stopping, reduce_lr]\n","    )\n","    predict_data = model.predict(test_dataset[0])\n","    print(\"Test MSE:\", mean_squared_error(test_dataset[1], predict_data))\n","\n","    
print(\"==========================================================\")\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"tItUqgD8N5-J","executionInfo":{"status":"error","timestamp":1650097665732,"user_tz":240,"elapsed":4347020,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"1bc10b43-2ea7-4e78-f4f0-019c17ab7e0c"},"execution_count":19,"outputs":[{"output_type":"stream","name":"stdout","text":["Training using seed 4:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 923s - loss: 0.0176 - val_loss: 0.0258 - lr: 0.0010 - 923s/epoch - 21s/step\n","Epoch 2/100\n","43/43 - 36s - loss: 0.0110 - val_loss: 0.0209 - lr: 0.0010 - 36s/epoch - 848ms/step\n","Epoch 3/100\n","43/43 - 36s - loss: 0.0096 - val_loss: 0.0154 - lr: 0.0010 - 36s/epoch - 842ms/step\n","Epoch 4/100\n","43/43 - 29s - loss: 0.0089 - val_loss: 0.0173 - lr: 0.0010 - 29s/epoch - 673ms/step\n","Epoch 5/100\n","43/43 - 36s - loss: 0.0081 - val_loss: 0.0139 - lr: 0.0010 - 36s/epoch - 847ms/step\n","Epoch 6/100\n","43/43 - 36s - loss: 0.0073 - val_loss: 0.0135 - lr: 0.0010 - 36s/epoch - 843ms/step\n","Epoch 7/100\n","43/43 - 36s - loss: 0.0064 - val_loss: 0.0127 - lr: 0.0010 - 36s/epoch - 842ms/step\n","Epoch 8/100\n","43/43 - 29s - loss: 0.0059 - val_loss: 0.0131 - lr: 0.0010 - 29s/epoch - 672ms/step\n","Epoch 9/100\n","43/43 - 29s - loss: 0.0052 - val_loss: 0.0128 - lr: 0.0010 - 29s/epoch - 684ms/step\n","Epoch 10/100\n","43/43 - 36s - loss: 0.0046 - val_loss: 0.0118 - lr: 0.0010 - 36s/epoch - 838ms/step\n","Epoch 11/100\n","43/43 - 36s - loss: 0.0043 - val_loss: 0.0117 - lr: 0.0010 - 36s/epoch - 841ms/step\n","Epoch 12/100\n","43/43 - 29s - loss: 0.0038 - val_loss: 0.0120 - lr: 0.0010 - 29s/epoch - 668ms/step\n","Epoch 13/100\n","43/43 - 29s - loss: 0.0034 - val_loss: 0.0126 - lr: 0.0010 - 29s/epoch - 673ms/step\n","Epoch 
14/100\n","43/43 - 29s - loss: 0.0029 - val_loss: 0.0131 - lr: 8.0000e-04 - 29s/epoch - 673ms/step\n","Epoch 15/100\n","43/43 - 29s - loss: 0.0026 - val_loss: 0.0147 - lr: 8.0000e-04 - 29s/epoch - 669ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0026 - val_loss: 0.0133 - lr: 8.0000e-04 - 29s/epoch - 676ms/step\n","Epoch 17/100\n","43/43 - 29s - loss: 0.0022 - val_loss: 0.0136 - lr: 6.4000e-04 - 29s/epoch - 672ms/step\n","Epoch 18/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0145 - lr: 6.4000e-04 - 29s/epoch - 671ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0133 - lr: 6.4000e-04 - 29s/epoch - 670ms/step\n","Epoch 20/100\n","43/43 - 29s - loss: 0.0017 - val_loss: 0.0129 - lr: 5.1200e-04 - 29s/epoch - 669ms/step\n","Epoch 21/100\n","43/43 - 57s - loss: 0.0016 - val_loss: 0.0130 - lr: 5.1200e-04 - 57s/epoch - 1s/step\n","Test MSE: 0.010075070526222661\n","==========================================================\n","Training using seed 5:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 929s - loss: 0.0178 - val_loss: 0.0237 - lr: 0.0010 - 929s/epoch - 22s/step\n","Epoch 2/100\n","43/43 - 36s - loss: 0.0111 - val_loss: 0.0164 - lr: 0.0010 - 36s/epoch - 846ms/step\n","Epoch 3/100\n","43/43 - 36s - loss: 0.0096 - val_loss: 0.0155 - lr: 0.0010 - 36s/epoch - 845ms/step\n","Epoch 4/100\n","43/43 - 36s - loss: 0.0084 - val_loss: 0.0144 - lr: 0.0010 - 36s/epoch - 846ms/step\n","Epoch 5/100\n","43/43 - 29s - loss: 0.0073 - val_loss: 0.0147 - lr: 0.0010 - 29s/epoch - 671ms/step\n","Epoch 6/100\n","43/43 - 29s - loss: 0.0067 - val_loss: 0.0153 - lr: 0.0010 - 29s/epoch - 673ms/step\n","Epoch 7/100\n","43/43 - 36s - loss: 0.0063 - val_loss: 0.0122 - lr: 0.0010 - 36s/epoch - 843ms/step\n","Epoch 8/100\n","43/43 - 36s - loss: 0.0053 - val_loss: 0.0115 - lr: 0.0010 - 36s/epoch - 841ms/step\n","Epoch 9/100\n","43/43 - 29s - loss: 0.0049 - 
val_loss: 0.0120 - lr: 0.0010 - 29s/epoch - 670ms/step\n","Epoch 10/100\n","43/43 - 36s - loss: 0.0043 - val_loss: 0.0115 - lr: 0.0010 - 36s/epoch - 845ms/step\n","Epoch 11/100\n","43/43 - 29s - loss: 0.0038 - val_loss: 0.0122 - lr: 0.0010 - 29s/epoch - 673ms/step\n","Epoch 12/100\n","43/43 - 29s - loss: 0.0034 - val_loss: 0.0123 - lr: 8.0000e-04 - 29s/epoch - 672ms/step\n","Epoch 13/100\n","43/43 - 29s - loss: 0.0029 - val_loss: 0.0124 - lr: 8.0000e-04 - 29s/epoch - 673ms/step\n","Epoch 14/100\n","43/43 - 29s - loss: 0.0027 - val_loss: 0.0130 - lr: 8.0000e-04 - 29s/epoch - 670ms/step\n","Epoch 15/100\n","43/43 - 29s - loss: 0.0024 - val_loss: 0.0133 - lr: 6.4000e-04 - 29s/epoch - 675ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0023 - val_loss: 0.0130 - lr: 6.4000e-04 - 29s/epoch - 670ms/step\n","Epoch 17/100\n","43/43 - 29s - loss: 0.0020 - val_loss: 0.0125 - lr: 6.4000e-04 - 29s/epoch - 672ms/step\n","Epoch 18/100\n","43/43 - 29s - loss: 0.0018 - val_loss: 0.0122 - lr: 5.1200e-04 - 29s/epoch - 674ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0018 - val_loss: 0.0126 - lr: 5.1200e-04 - 29s/epoch - 670ms/step\n","Epoch 20/100\n","43/43 - 56s - loss: 0.0017 - val_loss: 0.0130 - lr: 5.1200e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.010832743714819\n","==========================================================\n","Training using seed 6:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n"]},{"output_type":"error","ename":"ResourceExhaustedError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)","\u001b[0;32m<ipython-input-19-cf271c68d9c6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m         
\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mearly_stopping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_lr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m     )\n\u001b[1;32m     49\u001b[0m     \u001b[0mpredict_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     65\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m       \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     68\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m       \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1189\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_numpy_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1190\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1193\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mResourceExhaustedError\u001b[0m: 9 root error(s) found.\n  (0) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1788474}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. 
There are 6.15G free, 0B reserved, and 5.73G reservable.\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_6_0/_75]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_7_0/_78]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1788474}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.73G reservable.\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_6_0/_75]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_6_0/_74]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (2) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1788474}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. 
There are 6.15G free, 0B reserved, and 5.73G reservable.\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_6_0/_75]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_5_0/_70]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (3) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1788474}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. Th ... [truncated]"]}]},{"cell_type":"code","source":[
"import gc\n",
"\n",
"\n",
"def cal_prob(predict_missing_onehot):\n",
"    \"\"\"Normalize the first two channels of a predicted one-hot array.\n",
"\n",
"    Divides `predict_missing_onehot[:, :, :2]` by its sum over axis 2 and\n",
"    returns the result for the first sample only (index 0).\n",
"\n",
"    NOTE(review): the original comment speaks of genotypes 0, 1 and 2, but only\n",
"    the first two channels are sliced (`:2`); confirm this matches the encoding.\n",
"    Assumes a 3-D (sample, position, channel) array - TODO confirm.\n",
"    \"\"\"\n",
"    predict_prob = predict_missing_onehot[:,:,:2] / predict_missing_onehot[:,:,:2].sum(axis=2, keepdims=True)\n",
"    return predict_prob[0]\n",
"# cal_prob(x_test[:10])\n",
"\n",
"# tf.keras.utils.plot_model(create_model(), show_shapes=True)\n",
"missing_perc = 0.05\n",
"num_epochs = 100\n",
"\n",
"# Test MSE of each seed (kept under the historical name `accuracies`).\n",
"accuracies = []\n",
"# NOTE(review): starts at seed 2, apparently to resume after an earlier run of\n",
"# this loop stopped with a TPU ResourceExhaustedError after seed 1.\n",
"for random_state in range(2, 10):\n",
"  print(f\"Training using seed {random_state}:\")\n",
"  train_dataset, valid_dataset, test_dataset, steps_per_epoch, validation_steps = get_three_sets(x, batch_size, random_state, missing_perc)\n",
"\n",
"  # Create a learning rate scheduler callback.\n",
"  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
"      monitor=\"val_loss\", factor=0.8, patience=3\n",
"  )\n",
"\n",
"  # Create an early stopping callback.\n",
"  early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"      monitor=\"val_loss\", patience=10, restore_best_weights=True\n",
"  )\n",
"\n",
"  # Drop the previous seed's model/optimizer references so clear_session() and\n",
"  # gc can reclaim them before a new model is built. Stored outputs of this loop\n",
"  # show a TPU ResourceExhaustedError after a couple of consecutive seeds; this\n",
"  # is a best-effort mitigation. NOTE(review): if OOM persists, re-initialize the\n",
"  # TPU system between seeds (tf.tpu.experimental.initialize_tpu_system) with the\n",
"  # cluster resolver from the setup cell.\n",
"  model = optimizer = history = predict_data = None\n",
"  K.clear_session()\n",
"  gc.collect()\n",
"  with strategy.scope():\n",
"    model = create_model()\n",
"    # tf.keras.utils.plot_model(model, show_shapes=True)\n",
"    optimizer = tfa.optimizers.LAMB(\n",
"          learning_rate=learning_rate,\n",
"          # weight_decay_rate=weight_decay,\n",
"      )\n",
"\n",
"    model.compile(optimizer, loss='mse')\n",
"    # model.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')\n",
"\n",
"    history = model.fit(\n",
"        train_dataset,\n",
"        steps_per_epoch=steps_per_epoch,\n",
"        validation_data=valid_dataset,\n",
"        validation_steps=validation_steps,\n",
"        epochs=num_epochs,\n",
"        verbose=2,\n",
"        callbacks=[early_stopping, reduce_lr]\n",
"    )\n",
"    predict_data = model.predict(test_dataset[0])\n",
"    test_mse = mean_squared_error(test_dataset[1], predict_data)\n",
"    accuracies.append(test_mse)\n",
"    print(\"Test MSE:\", test_mse)\n",
"\n",
"    print(\"==========================================================\")\n",
"\n",
"# Summary over all seeds trained by this cell.\n",
"print(\"Test MSE per seed:\", accuracies)\n",
"if accuracies:\n",
"  print(\"Mean test MSE:\", sum(accuracies) / len(accuracies))\n"
],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"id":"iexrafIy4ul3","executionInfo":{"status":"error","timestamp":1650093147861,"user_tz":240,"elapsed":4102166,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"0120be02-30d1-4f4c-ffe7-68c393c27612"},"execution_count":21,"outputs":[{"output_type":"stream","name":"stdout","text":["Training using seed 2:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 838s - loss: 0.0179 - val_loss: 0.0235 - lr: 0.0010 - 838s/epoch - 19s/step\n","Epoch 2/100\n","43/43 - 34s - loss: 0.0119 - val_loss: 0.0158 - lr: 0.0010 - 34s/epoch - 797ms/step\n","Epoch 3/100\n","43/43 - 34s - loss: 0.0101 - val_loss: 0.0142 - lr: 0.0010 - 34s/epoch - 796ms/step\n","Epoch 4/100\n","43/43 - 35s - loss: 0.0090 - val_loss: 0.0137 - lr: 0.0010 - 
35s/epoch - 804ms/step\n","Epoch 5/100\n","43/43 - 35s - loss: 0.0082 - val_loss: 0.0114 - lr: 0.0010 - 35s/epoch - 808ms/step\n","Epoch 6/100\n","43/43 - 34s - loss: 0.0072 - val_loss: 0.0107 - lr: 0.0010 - 34s/epoch - 798ms/step\n","Epoch 7/100\n","43/43 - 34s - loss: 0.0067 - val_loss: 0.0104 - lr: 0.0010 - 34s/epoch - 796ms/step\n","Epoch 8/100\n","43/43 - 34s - loss: 0.0059 - val_loss: 0.0097 - lr: 0.0010 - 34s/epoch - 791ms/step\n","Epoch 9/100\n","43/43 - 28s - loss: 0.0055 - val_loss: 0.0101 - lr: 0.0010 - 28s/epoch - 656ms/step\n","Epoch 10/100\n","43/43 - 28s - loss: 0.0046 - val_loss: 0.0107 - lr: 0.0010 - 28s/epoch - 659ms/step\n","Epoch 11/100\n","43/43 - 28s - loss: 0.0042 - val_loss: 0.0099 - lr: 0.0010 - 28s/epoch - 661ms/step\n","Epoch 12/100\n","43/43 - 28s - loss: 0.0038 - val_loss: 0.0102 - lr: 8.0000e-04 - 28s/epoch - 663ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0031 - val_loss: 0.0105 - lr: 8.0000e-04 - 28s/epoch - 660ms/step\n","Epoch 14/100\n","43/43 - 28s - loss: 0.0029 - val_loss: 0.0113 - lr: 8.0000e-04 - 28s/epoch - 662ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0026 - val_loss: 0.0119 - lr: 6.4000e-04 - 28s/epoch - 657ms/step\n","Epoch 16/100\n","43/43 - 28s - loss: 0.0024 - val_loss: 0.0122 - lr: 6.4000e-04 - 28s/epoch - 662ms/step\n","Epoch 17/100\n","43/43 - 28s - loss: 0.0022 - val_loss: 0.0129 - lr: 6.4000e-04 - 28s/epoch - 659ms/step\n","Epoch 18/100\n","43/43 - 57s - loss: 0.0021 - val_loss: 0.0117 - lr: 5.1200e-04 - 57s/epoch - 1s/step\n","Test MSE: 0.009707802511928017\n","==========================================================\n","Training using seed 3:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 878s - loss: 0.0177 - val_loss: 0.0274 - lr: 0.0010 - 878s/epoch - 20s/step\n","Epoch 2/100\n","43/43 - 36s - loss: 0.0109 - val_loss: 0.0212 - lr: 0.0010 - 36s/epoch - 839ms/step\n","Epoch 
3/100\n","43/43 - 36s - loss: 0.0094 - val_loss: 0.0203 - lr: 0.0010 - 36s/epoch - 840ms/step\n","Epoch 4/100\n","43/43 - 36s - loss: 0.0084 - val_loss: 0.0150 - lr: 0.0010 - 36s/epoch - 841ms/step\n","Epoch 5/100\n","43/43 - 36s - loss: 0.0078 - val_loss: 0.0131 - lr: 0.0010 - 36s/epoch - 840ms/step\n","Epoch 6/100\n","43/43 - 29s - loss: 0.0073 - val_loss: 0.0136 - lr: 0.0010 - 29s/epoch - 670ms/step\n","Epoch 7/100\n","43/43 - 36s - loss: 0.0065 - val_loss: 0.0124 - lr: 0.0010 - 36s/epoch - 843ms/step\n","Epoch 8/100\n","43/43 - 29s - loss: 0.0055 - val_loss: 0.0132 - lr: 0.0010 - 29s/epoch - 668ms/step\n","Epoch 9/100\n","43/43 - 36s - loss: 0.0050 - val_loss: 0.0117 - lr: 0.0010 - 36s/epoch - 839ms/step\n","Epoch 10/100\n","43/43 - 36s - loss: 0.0045 - val_loss: 0.0107 - lr: 0.0010 - 36s/epoch - 840ms/step\n","Epoch 11/100\n","43/43 - 36s - loss: 0.0039 - val_loss: 0.0101 - lr: 0.0010 - 36s/epoch - 841ms/step\n","Epoch 12/100\n","43/43 - 29s - loss: 0.0036 - val_loss: 0.0103 - lr: 0.0010 - 29s/epoch - 668ms/step\n","Epoch 13/100\n","43/43 - 29s - loss: 0.0033 - val_loss: 0.0106 - lr: 0.0010 - 29s/epoch - 669ms/step\n","Epoch 14/100\n","43/43 - 29s - loss: 0.0028 - val_loss: 0.0110 - lr: 0.0010 - 29s/epoch - 669ms/step\n","Epoch 15/100\n","43/43 - 29s - loss: 0.0026 - val_loss: 0.0109 - lr: 8.0000e-04 - 29s/epoch - 670ms/step\n","Epoch 16/100\n","43/43 - 29s - loss: 0.0024 - val_loss: 0.0107 - lr: 8.0000e-04 - 29s/epoch - 672ms/step\n","Epoch 17/100\n","43/43 - 29s - loss: 0.0021 - val_loss: 0.0109 - lr: 8.0000e-04 - 29s/epoch - 673ms/step\n","Epoch 18/100\n","43/43 - 29s - loss: 0.0019 - val_loss: 0.0110 - lr: 6.4000e-04 - 29s/epoch - 671ms/step\n","Epoch 19/100\n","43/43 - 29s - loss: 0.0018 - val_loss: 0.0113 - lr: 6.4000e-04 - 29s/epoch - 668ms/step\n","Epoch 20/100\n","43/43 - 29s - loss: 0.0017 - val_loss: 0.0113 - lr: 6.4000e-04 - 29s/epoch - 673ms/step\n","Epoch 21/100\n","43/43 - 57s - loss: 0.0016 - val_loss: 0.0113 - lr: 5.1200e-04 - 57s/epoch - 
1s/step\n","Test MSE: 0.010120811596367043\n","==========================================================\n","Training using seed 4:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n"]},{"output_type":"error","ename":"ResourceExhaustedError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)","\u001b[0;32m<ipython-input-21-228f593b2203>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m         \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mearly_stopping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_lr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m     )\n\u001b[1;32m     49\u001b[0m     \u001b[0mpredict_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     65\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m 
\u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m       \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     68\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m       \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1189\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_numpy_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1190\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m 
\u001b[0;32mNone\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1193\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mResourceExhaustedError\u001b[0m: 9 root error(s) found.\n  (0) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_7_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_6_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (2) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_5_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (3) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. 
There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_4_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (4) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_3_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (5) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execute_2_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (6) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1804713}} Attempting to reserve 5.85G at the bottom of memory. That was not possible. There are 6.15G free, 0B reserved, and 5.72G reservable.\n\t [[{{node cluster_train_function/_execut ... 
[truncated]"]}]},{"cell_type":"code","metadata":{"id":"p9rOiOQu6ZlP","colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"status":"error","timestamp":1650086965761,"user_tz":240,"elapsed":4034436,"user":{"displayName":"Mohammad Erfan Mowlaei","userId":"01586088812525175948"}},"outputId":"d7c70f87-356c-41e8-bed4-2958d58043d2"},"source":[
"import gc\n",
"\n",
"\n",
"def cal_prob(predict_missing_onehot):\n",
"    \"\"\"Normalize the first two channels of a predicted one-hot array.\n",
"\n",
"    Divides `predict_missing_onehot[:, :, :2]` by its sum over axis 2 and\n",
"    returns the result for the first sample only (index 0).\n",
"\n",
"    NOTE(review): the original comment speaks of genotypes 0, 1 and 2, but only\n",
"    the first two channels are sliced (`:2`); confirm this matches the encoding.\n",
"    Assumes a 3-D (sample, position, channel) array - TODO confirm.\n",
"    \"\"\"\n",
"    predict_prob = predict_missing_onehot[:,:,:2] / predict_missing_onehot[:,:,:2].sum(axis=2, keepdims=True)\n",
"    return predict_prob[0]\n",
"# cal_prob(x_test[:10])\n",
"\n",
"# tf.keras.utils.plot_model(create_model(), show_shapes=True)\n",
"missing_perc = 0.05\n",
"num_epochs = 100\n",
"\n",
"# Test MSE of each seed (kept under the historical name `accuracies`).\n",
"accuracies = []\n",
"for random_state in range(10):\n",
"  print(f\"Training using seed {random_state}:\")\n",
"  train_dataset, valid_dataset, test_dataset, steps_per_epoch, validation_steps = get_three_sets(x, batch_size, random_state, missing_perc)\n",
"\n",
"  # Create a learning rate scheduler callback.\n",
"  reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
"      monitor=\"val_loss\", factor=0.8, patience=3\n",
"  )\n",
"\n",
"  # Create an early stopping callback.\n",
"  early_stopping = tf.keras.callbacks.EarlyStopping(\n",
"      monitor=\"val_loss\", patience=10, restore_best_weights=True\n",
"  )\n",
"\n",
"  # Drop the previous seed's model/optimizer references so clear_session() and\n",
"  # gc can reclaim them before a new model is built. Stored outputs of this loop\n",
"  # show a TPU ResourceExhaustedError after a couple of consecutive seeds; this\n",
"  # is a best-effort mitigation. NOTE(review): if OOM persists, re-initialize the\n",
"  # TPU system between seeds (tf.tpu.experimental.initialize_tpu_system) with the\n",
"  # cluster resolver from the setup cell.\n",
"  model = optimizer = history = predict_data = None\n",
"  K.clear_session()\n",
"  gc.collect()\n",
"  with strategy.scope():\n",
"    model = create_model()\n",
"    # tf.keras.utils.plot_model(model, show_shapes=True)\n",
"    optimizer = tfa.optimizers.LAMB(\n",
"          learning_rate=learning_rate,\n",
"          # weight_decay_rate=weight_decay,\n",
"      )\n",
"\n",
"    model.compile(optimizer, loss='mse')\n",
"    # model.compile(optimizer, loss=tf.keras.losses.CategoricalCrossentropy(), metrics='accuracy')\n",
"\n",
"    history = model.fit(\n",
"        train_dataset,\n",
"        steps_per_epoch=steps_per_epoch,\n",
"        validation_data=valid_dataset,\n",
"        validation_steps=validation_steps,\n",
"        epochs=num_epochs,\n",
"        verbose=2,\n",
"        callbacks=[early_stopping, reduce_lr]\n",
"    )\n",
"    predict_data = model.predict(test_dataset[0])\n",
"    test_mse = mean_squared_error(test_dataset[1], predict_data)\n",
"    accuracies.append(test_mse)\n",
"    print(\"Test MSE:\", test_mse)\n",
"\n",
"    print(\"==========================================================\")\n",
"\n",
"# Summary over all seeds trained by this cell.\n",
"print(\"Test MSE per seed:\", accuracies)\n",
"if accuracies:\n",
"  print(\"Mean test MSE:\", sum(accuracies) / len(accuracies))\n"
],"execution_count":21,"outputs":[{"output_type":"stream","name":"stdout","text":["Training using seed 0:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 834s - loss: 0.0183 - val_loss: 0.0254 - lr: 0.0010 - 834s/epoch - 19s/step\n","Epoch 2/100\n","43/43 - 34s - loss: 0.0116 - val_loss: 0.0155 - lr: 0.0010 - 34s/epoch - 792ms/step\n","Epoch 3/100\n","43/43 - 34s - loss: 0.0096 - val_loss: 0.0142 - lr: 0.0010 - 34s/epoch - 792ms/step\n","Epoch 4/100\n","43/43 - 34s - loss: 0.0088 - val_loss: 0.0128 - lr: 0.0010 - 34s/epoch - 796ms/step\n","Epoch 5/100\n","43/43 - 34s - loss: 0.0075 - val_loss: 0.0117 - lr: 0.0010 - 34s/epoch - 797ms/step\n","Epoch 6/100\n","43/43 - 34s - loss: 0.0068 - val_loss: 0.0111 - lr: 0.0010 - 34s/epoch - 802ms/step\n","Epoch 7/100\n","43/43 - 35s - loss: 0.0062 - val_loss: 0.0108 - lr: 0.0010 - 35s/epoch - 807ms/step\n","Epoch 8/100\n","43/43 - 34s - loss: 0.0055 - val_loss: 0.0103 - lr: 0.0010 - 34s/epoch - 794ms/step\n","Epoch 9/100\n","43/43 - 34s - loss: 0.0050 - val_loss: 0.0100 - lr: 0.0010 - 34s/epoch - 798ms/step\n","Epoch 10/100\n","43/43 - 34s - loss: 0.0044 - val_loss: 0.0097 - lr: 0.0010 - 34s/epoch - 789ms/step\n","Epoch 11/100\n","43/43 - 28s - loss: 0.0038 - val_loss: 0.0099 - lr: 0.0010 - 28s/epoch - 657ms/step\n","Epoch 12/100\n","43/43 - 28s - loss: 0.0035 - val_loss: 0.0101 - lr: 0.0010 - 28s/epoch - 658ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0030 - val_loss: 0.0104 - lr: 0.0010 - 28s/epoch - 658ms/step\n","Epoch 14/100\n","43/43 - 28s - 
loss: 0.0028 - val_loss: 0.0103 - lr: 8.0000e-04 - 28s/epoch - 657ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0025 - val_loss: 0.0106 - lr: 8.0000e-04 - 28s/epoch - 658ms/step\n","Epoch 16/100\n","43/43 - 28s - loss: 0.0022 - val_loss: 0.0110 - lr: 8.0000e-04 - 28s/epoch - 660ms/step\n","Epoch 17/100\n","43/43 - 28s - loss: 0.0021 - val_loss: 0.0107 - lr: 6.4000e-04 - 28s/epoch - 658ms/step\n","Epoch 18/100\n","43/43 - 28s - loss: 0.0018 - val_loss: 0.0108 - lr: 6.4000e-04 - 28s/epoch - 663ms/step\n","Epoch 19/100\n","43/43 - 28s - loss: 0.0018 - val_loss: 0.0113 - lr: 6.4000e-04 - 28s/epoch - 660ms/step\n","Epoch 20/100\n","43/43 - 56s - loss: 0.0017 - val_loss: 0.0119 - lr: 5.1200e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.009677360983830787\n","==========================================================\n","Training using seed 1:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n","43/43 - 840s - loss: 0.0184 - val_loss: 0.0250 - lr: 0.0010 - 840s/epoch - 20s/step\n","Epoch 2/100\n","43/43 - 35s - loss: 0.0110 - val_loss: 0.0162 - lr: 0.0010 - 35s/epoch - 802ms/step\n","Epoch 3/100\n","43/43 - 34s - loss: 0.0096 - val_loss: 0.0149 - lr: 0.0010 - 34s/epoch - 795ms/step\n","Epoch 4/100\n","43/43 - 34s - loss: 0.0087 - val_loss: 0.0135 - lr: 0.0010 - 34s/epoch - 795ms/step\n","Epoch 5/100\n","43/43 - 35s - loss: 0.0077 - val_loss: 0.0125 - lr: 0.0010 - 35s/epoch - 803ms/step\n","Epoch 6/100\n","43/43 - 34s - loss: 0.0069 - val_loss: 0.0119 - lr: 0.0010 - 34s/epoch - 797ms/step\n","Epoch 7/100\n","43/43 - 35s - loss: 0.0062 - val_loss: 0.0115 - lr: 0.0010 - 35s/epoch - 804ms/step\n","Epoch 8/100\n","43/43 - 35s - loss: 0.0054 - val_loss: 0.0111 - lr: 0.0010 - 35s/epoch - 802ms/step\n","Epoch 9/100\n","43/43 - 34s - loss: 0.0048 - val_loss: 0.0109 - lr: 0.0010 - 34s/epoch - 802ms/step\n","Epoch 10/100\n","43/43 - 35s - loss: 0.0043 - val_loss: 0.0106 - lr: 0.0010 - 
35s/epoch - 805ms/step\n","Epoch 11/100\n","43/43 - 35s - loss: 0.0037 - val_loss: 0.0104 - lr: 0.0010 - 35s/epoch - 803ms/step\n","Epoch 12/100\n","43/43 - 28s - loss: 0.0033 - val_loss: 0.0114 - lr: 0.0010 - 28s/epoch - 660ms/step\n","Epoch 13/100\n","43/43 - 28s - loss: 0.0028 - val_loss: 0.0113 - lr: 0.0010 - 28s/epoch - 662ms/step\n","Epoch 14/100\n","43/43 - 28s - loss: 0.0027 - val_loss: 0.0115 - lr: 0.0010 - 28s/epoch - 661ms/step\n","Epoch 15/100\n","43/43 - 28s - loss: 0.0024 - val_loss: 0.0121 - lr: 8.0000e-04 - 28s/epoch - 662ms/step\n","Epoch 16/100\n","43/43 - 28s - loss: 0.0022 - val_loss: 0.0116 - lr: 8.0000e-04 - 28s/epoch - 661ms/step\n","Epoch 17/100\n","43/43 - 28s - loss: 0.0020 - val_loss: 0.0118 - lr: 8.0000e-04 - 28s/epoch - 662ms/step\n","Epoch 18/100\n","43/43 - 28s - loss: 0.0018 - val_loss: 0.0120 - lr: 6.4000e-04 - 28s/epoch - 661ms/step\n","Epoch 19/100\n","43/43 - 28s - loss: 0.0018 - val_loss: 0.0119 - lr: 6.4000e-04 - 28s/epoch - 662ms/step\n","Epoch 20/100\n","43/43 - 28s - loss: 0.0016 - val_loss: 0.0119 - lr: 6.4000e-04 - 28s/epoch - 662ms/step\n","Epoch 21/100\n","43/43 - 56s - loss: 0.0016 - val_loss: 0.0118 - lr: 5.1200e-04 - 56s/epoch - 1s/step\n","Test MSE: 0.01063299384640534\n","==========================================================\n","Training using seed 2:\n","x_train percentage: 0.6479461592016709\n","x_valid percentage: 0.15200742631701092\n","x_test percentage: 0.20004641448131816\n","Epoch 1/100\n"]},{"output_type":"error","ename":"ResourceExhaustedError","evalue":"ignored","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mResourceExhaustedError\u001b[0m                    Traceback (most recent call last)","\u001b[0;32m<ipython-input-21-93be959c9471>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     45\u001b[0m         
\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m         \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mearly_stopping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_lr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     48\u001b[0m     )\n\u001b[1;32m     49\u001b[0m     \u001b[0mpredict_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py\u001b[0m in \u001b[0;36merror_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     65\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=broad-except\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m       \u001b[0mfiltered_tb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_process_traceback_frames\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__traceback__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwith_traceback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfiltered_tb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     68\u001b[0m     \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m       \u001b[0;32mdel\u001b[0m \u001b[0mfiltered_tb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;32m/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py\u001b[0m in \u001b[0;36m_numpy\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1189\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_numpy_internal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1190\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_NotOkStatusException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1191\u001b[0;31m       \u001b[0;32mraise\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_status_to_exception\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;32mNone\u001b[0m  \u001b[0;31m# pylint: disable=protected-access\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1192\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1193\u001b[0m   \u001b[0;34m@\u001b[0m\u001b[0mproperty\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n","\u001b[0;31mResourceExhaustedError\u001b[0m: 9 root error(s) found.\n  (0) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1880788}} Attempting to allocate 202.87M. That was not possible. There are 769.22M free. 
Due to fragmentation, the largest contiguous region of free memory is 149.66M.; (0x0x0_HBM0)\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_3_0/_63]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_7_0/_78]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (1) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1880788}} Attempting to allocate 202.87M. That was not possible. There are 769.22M free. Due to fragmentation, the largest contiguous region of free memory is 149.66M.; (0x0x0_HBM0)\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_3_0/_63]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_6_0/_74]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (2) RESOURCE_EXHAUSTED: {{function_node __inference_train_function_1880788}} Attempting to allocate 202.87M. 
That was not possible. There are 769.22M free. Due to fragmentation, the largest contiguous region of free memory is 149.66M.; (0x0x0_HBM0)\n\t [[{{node cluster_train_function/_execute_0_0}}]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_3_0/_63]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n\t [[cluster_train_function/_execute_5_0/_70]]\nHint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.\n\n  (3) RESOURCE_EXHAUSTED: {{function_nod ... [truncated]"]}]},{"cell_type":"code","metadata":{"id":"A_0sblsCQ62m"},"source":[""],"execution_count":null,"outputs":[]}]}