diff --git a/src/colab/skai_assessment_notebook_custom_vm.ipynb b/src/colab/skai_assessment_notebook_custom_vm.ipynb deleted file mode 100644 index 27ebd217..00000000 --- a/src/colab/skai_assessment_notebook_custom_vm.ipynb +++ /dev/null @@ -1,1364 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "34c59af2", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Configure Assessment Parameters\n", - "\n", - "# pylint:disable=missing-module-docstring\n", - "# pylint:disable=g-bad-import-order\n", - "# pylint:disable=g-wrong-blank-lines\n", - "# pylint:disable=g-import-not-at-top\n", - "\n", - "# @markdown You must re-run this cell every time you make a change.\n", - "import os\n", - "import textwrap\n", - "import ee\n", - "\n", - "GCP_PROJECT = '' # @param {type:\"string\"}\n", - "GCP_LOCATION = '' # @param {type:\"string\"}\n", - "GCP_BUCKET = '' # @param {type:\"string\"}\n", - "GCP_SERVICE_ACCOUNT = '' # @param {type:\"string\"}\n", - "SERVICE_ACCOUNT_KEY = '' # @param {type:\"string\"}\n", - "# @markdown This is only needed if BUILDINGS_METHOD is set to \"run_model\":\n", - "BUILDING_SEGMENTATION_MODEL_PATH = '' # @param {type:\"string\"}\n", - "\n", - "# @markdown ---\n", - "ASSESSMENT_NAME = '' # @param {type:\"string\"}\n", - "EVENT_DATE = '' # @param {type:\"date\"}\n", - "OUTPUT_DIR = '' # @param {type:\"string\"}\n", - "\n", - "# @markdown ---\n", - "BEFORE_IMAGE_0 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_1 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_2 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_3 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_4 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_5 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_6 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_7 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_8 = '' # @param {type:\"string\"}\n", - "BEFORE_IMAGE_9 = '' # @param {type:\"string\"}\n", - "# @markdown ---\n", - "AFTER_IMAGE_0 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_1 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_2 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_3 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_4 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_5 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_6 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_7 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_8 = '' # @param {type:\"string\"}\n", - "AFTER_IMAGE_9 = '' # @param {type:\"string\"}\n", - "\n", - "# Constants\n", - "SKAI_REPO = 'https://github.com/google-research/skai.git'\n", - "OPEN_BUILDINGS_FEATURE_COLLECTION = 'GOOGLE/Research/open-buildings/v3/polygons'\n", - "OSM_OVERPASS_URL = 'https://lz4.overpass-api.de/api/interpreter'\n", - "\n", - "# Derived variables\n", - "SKAI_CODE_DIR = '/content/skai_src'\n", - "AOI_PATH = os.path.join(OUTPUT_DIR, 'aoi.geojson')\n", - "BUILDINGS_FILE_LOG = os.path.join(OUTPUT_DIR, 'buildings_file_log.txt')\n", - "EXAMPLE_GENERATION_CONFIG_PATH = os.path.join(\n", - " OUTPUT_DIR, 'example_generation_config.json'\n", - ")\n", - "UNLABELED_TFRECORD_PATTERN = os.path.join(\n", - " OUTPUT_DIR, 'examples', 'unlabeled-large', 'unlabeled-*-of-*.tfrecord'\n", - ")\n", - "ZERO_SHOT_DIR = os.path.join(OUTPUT_DIR, 'zero_shot_model')\n", - "ZERO_SHOT_SCORES = os.path.join(ZERO_SHOT_DIR, 'dataset_0_output.csv')\n", - "LABELING_IMAGES_DIR = os.path.join(OUTPUT_DIR, 'labeling_images')\n", - "LABELING_EXAMPLES_TFRECORD_PATTERN = os.path.join(\n", - " LABELING_IMAGES_DIR, '*', 'labeling_examples.tfrecord'\n", - ")\n", - "LABELS_CSV = os.path.join(OUTPUT_DIR, 'labels.csv')\n", - "TRAIN_TFRECORD = os.path.join(OUTPUT_DIR, 'labeled_examples_train.tfrecord')\n", - "TEST_TFRECORD = os.path.join(OUTPUT_DIR, 'labeled_examples_test.tfrecord')\n", - "MODEL_DIR = os.path.join(OUTPUT_DIR, 'models')\n", - "INFERENCE_CSV = os.path.join(OUTPUT_DIR, 'inference_scores.csv')\n", - "\n", - "\n", - "def process_image_entries(entries: list[str]) -> list[str]:\n", - " image_ids = []\n", - " for entry in entries:\n", - " entry = entry.strip()\n", - " if entry:\n", - " image_ids.append(entry)\n", - " return image_ids\n", - "\n", - "\n", - "BEFORE_IMAGES = process_image_entries([\n", - " BEFORE_IMAGE_0,\n", - " BEFORE_IMAGE_1,\n", - " BEFORE_IMAGE_2,\n", - " BEFORE_IMAGE_3,\n", - " BEFORE_IMAGE_4,\n", - " BEFORE_IMAGE_5,\n", - " BEFORE_IMAGE_6,\n", - " BEFORE_IMAGE_7,\n", - " BEFORE_IMAGE_8,\n", - " BEFORE_IMAGE_9,\n", - "])\n", - "\n", - "AFTER_IMAGES = process_image_entries([\n", - " AFTER_IMAGE_0,\n", - " AFTER_IMAGE_1,\n", - " AFTER_IMAGE_2,\n", - " AFTER_IMAGE_3,\n", - " AFTER_IMAGE_4,\n", - " AFTER_IMAGE_5,\n", - " AFTER_IMAGE_6,\n", - " AFTER_IMAGE_7,\n", - " AFTER_IMAGE_8,\n", - " AFTER_IMAGE_9,\n", - "])\n", - "\n", - "if os.path.exists(SERVICE_ACCOUNT_KEY):\n", - " os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY\n", - "else:\n", - " print(f'Service account key not found: \"{SERVICE_ACCOUNT_KEY}\"')" - ] - }, - { - "cell_type": "markdown", - "id": "5c960eea", - "metadata": {}, - "source": [ - "#Initialization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2ac3dbe1", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Imports and Function Defs\n", - "%load_ext tensorboard\n", - "\n", - "import collections\n", - "import io\n", - "import json\n", - "import math\n", - "import shutil\n", - "import subprocess\n", - "import time\n", - "import warnings\n", - "\n", - "import folium\n", - "import folium.plugins\n", - "import geopandas as gpd\n", - "from google.colab import data_table\n", - "from google.colab import files\n", - "from IPython.display import display\n", - "import ipywidgets as widgets\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import shapely.wkt\n", - "from skai import earth_engine as skai_ee\n", - "from skai import labeling\n", - "from skai import open_street_map\n", - "import tensorflow as tf\n", - "import tqdm.notebook\n", - "\n", - "data_table.enable_dataframe_formatter()\n", - "\n", - "\n", - "def convert_wgs_to_utm(lon: float, lat: float):\n", - " \"\"\"Based on lat and lng, return best utm epsg-code.\"\"\"\n", - " utm_band = str((math.floor((lon + 180) / 6) % 60) + 1)\n", - " if len(utm_band) == 1:\n", - " utm_band = '0' + utm_band\n", - " if lat >= 0:\n", - " epsg_code = '326' + utm_band\n", - " else:\n", - " epsg_code = '327' + utm_band\n", - " return f'EPSG:{epsg_code}'\n", - "\n", - "\n", - "def get_aoi_area_km2(aoi_path: str):\n", - " with tf.io.gfile.GFile(aoi_path) as f:\n", - " aoi = gpd.read_file(f)\n", - "\n", - " centroid = aoi.geometry.unary_union.centroid\n", - " utm_crs = convert_wgs_to_utm(centroid.x, centroid.y)\n", - " utm_aoi = aoi.to_crs(utm_crs)\n", - " area_meters_squared = utm_aoi.geometry.unary_union.area\n", - " area_km_squared = area_meters_squared / 1000000\n", - " return area_km_squared\n", - "\n", - "\n", - "def show_inference_stats(\n", - " aoi_path: str,\n", - " inference_csv_path: str,\n", - " threshold: float):\n", - " \"\"\"Prints out statistics on inference result.\"\"\"\n", - " with tf.io.gfile.GFile(inference_csv_path) as f:\n", - " df = pd.read_csv(f)\n", - " building_count = len(df)\n", - " if 'damage_score' in df.columns:\n", - " scores = df['damage_score']\n", - " elif 'score' in df.columns:\n", - " scores = df['score']\n", - " else:\n", - " raise ValueError(f'{inference_csv_path} does not contain a score column.')\n", - "\n", - " damaged = df.loc[scores > threshold]\n", - " damaged_count = len(damaged)\n", - " damaged_pct = 100 * damaged_count / building_count\n", - " print('Area KM^2:', get_aoi_area_km2(aoi_path))\n", - " print('Buildings assessed:', building_count)\n", - " print('Damaged buildings:', damaged_count)\n", - " print(f'Percentage damaged: {damaged_pct:0.3g}%')\n", - "\n", - "\n", - "def _open_file(path: str, mode: str):\n", - " f = tf.io.gfile.GFile(path, mode)\n", - " f.closed = False\n", - " return f\n", - "\n", - "\n", - "def _file_exists(path: str) -> bool:\n", - " return bool(tf.io.gfile.glob(path))\n", - "\n", - "\n", - "def _read_text_file(path: str) -> str:\n", - " with tf.io.gfile.GFile(path, 'r') as f:\n", - " return f.read()\n", - "\n", - "\n", - "def _make_map(longitude: float, latitude: float, zoom: float):\n", - " \"\"\"Creates a Folium map with common base layers.\n", - "\n", - " Args:\n", - " longitude: Longitude of initial view.\n", - " latitude: Latitude of initial view.\n", - " zoom: Zoom level of initial view.\n", - "\n", - " Returns:\n", - " Folium map.\n", - " \"\"\"\n", - " base_maps = [\n", - " folium.TileLayer(\n", - " tiles='https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}',\n", - " attr='Google',\n", - " name='Google Maps',\n", - " overlay=False,\n", - " control=True,\n", - " ),\n", - " ]\n", - "\n", - " m = folium.Map(\n", - " location=(latitude, longitude),\n", - " max_zoom=24,\n", - " zoom_start=zoom,\n", - " tiles=None)\n", - " for base_map in base_maps:\n", - " base_map.add_to(m)\n", - " return m\n", - "\n", - "\n", - "def show_assessment_heatmap(\n", - " aoi_path: str,\n", - " scores_path: str,\n", - " threshold: float,\n", - " is_zero_shot: bool):\n", - " \"\"\"Creates a Folium heatmap from inference scores.\"\"\"\n", - " with _open_file(scores_path, 'rb') as f:\n", - " df = pd.read_csv(f)\n", - " if is_zero_shot:\n", - " damaged = df.loc[~df['is_cloudy'] & (df['damage_score'] >= threshold)]\n", - " else:\n", - " damaged = df.loc[df['score'] >= threshold]\n", - " points = zip(damaged['latitude'].values, damaged['longitude'].values)\n", - " centroid_x = np.mean(damaged['longitude'].values)\n", - " centroid_y = np.mean(damaged['latitude'].values)\n", - " folium_map = _make_map(centroid_x, centroid_y, 12)\n", - " with _open_file(aoi_path, 'rb') as f:\n", - " aoi_gdf = gpd.read_file(f)\n", - " folium.GeoJson(\n", - " aoi_gdf.to_json(),\n", - " name='AOI',\n", - " style_function=lambda _: {'fillOpacity': 0},\n", - " ).add_to(folium_map)\n", - " heatmap = folium.plugins.HeatMap(points)\n", - " heatmap.add_to(folium_map)\n", - " display(folium_map)\n", - "\n", - "\n", - "def make_download_button(path: str, file_name: str, caption: str):\n", - " \"\"\"Displays a button for downloading a file in the colab kernel.\"\"\"\n", - " def download(_):\n", - " temp_path = f'/tmp/{file_name}'\n", - " with _open_file(path, 'rb') as src:\n", - " with open(temp_path, 'wb') as dst:\n", - " shutil.copyfileobj(src, dst)\n", - " files.download(temp_path)\n", - "\n", - " button = widgets.Button(\n", - " description=caption,\n", - " )\n", - " button.on_click(download)\n", - " display(button)\n", - "\n", - "\n", - "def find_model_dirs(model_root: str):\n", - " # Find all checkpoints dirs first. We only want model dirs that have at least\n", - " # one checkpoint.\n", - " checkpoint_dirs = tf.io.gfile.glob(\n", - " os.path.join(model_root, '*/*/model/epoch-*-aucpr-*'))\n", - " model_dirs = set(os.path.dirname(os.path.dirname(p)) for p in checkpoint_dirs)\n", - " return model_dirs\n", - "\n", - "\n", - "def find_labeling_image_metadata_files(labeling_images_dir: str):\n", - " return tf.io.gfile.glob(os.path.join(\n", - " labeling_images_dir, '*', 'image_metadata.csv'))\n", - "\n", - "\n", - "def yes_no_text(value: bool) -> str:\n", - " return '\\x1b[32mYES\\x1b[0m' if value else '\\x1b[31mNO\\x1b[0m'\n", - "\n", - "\n", - "def visualize_images(images: list[tuple[np.ndarray, np.ndarray]]):\n", - " \"\"\"Displays before and after images side-by-side.\"\"\"\n", - " num_rows = len(images)\n", - " size_factor = 3\n", - " fig_size = (2 * size_factor, num_rows * size_factor)\n", - " fig, axes = plt.subplots(num_rows, 2, figsize=fig_size)\n", - " for row, (pre_image, post_image) in enumerate(images):\n", - " ax1 = axes[row, 0]\n", - " ax2 = axes[row, 1]\n", - " ax1.axis('off')\n", - " ax2.axis('off')\n", - " ax1.imshow(pre_image)\n", - " ax2.imshow(post_image)\n", - " plt.show(fig)\n", - "\n", - "\n", - "def get_eeda_bearer_token(service_account: str) -> str:\n", - " return subprocess.check_output(\n", - " 'gcloud auth print-access-token'\n", - " f' --impersonate-service-account=\"{service_account}\"',\n", - " shell=True,\n", - " ).decode()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0aee715f", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Authenticate with Earth Engine\n", - "\n", - "def auth():\n", - " credentials = ee.ServiceAccountCredentials(\n", - " GCP_SERVICE_ACCOUNT, SERVICE_ACCOUNT_KEY)\n", - " ee.Initialize(credentials)\n", - "\n", - "auth()" - ] - }, - { - "cell_type": "markdown", - "id": "935eb444", - "metadata": {}, - "source": [ - "# Check Assessment Status\n", - "\n", - "Run the following cell to check which steps of the assessment have already\n", - "been completed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9add96fb", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Check assessment status\n", - "def check_assessment_status():\n", - " \"\"\"Shows which steps of the assessment have been completed.\"\"\"\n", - " print('AOI uploaded:', yes_no_text(_file_exists(AOI_PATH)))\n", - "\n", - " if _file_exists(BUILDINGS_FILE_LOG):\n", - " buildings_file = _read_text_file(BUILDINGS_FILE_LOG).strip()\n", - " print('Building footprints generated:', yes_no_text(True))\n", - " print(f' Building footprints file: {buildings_file}')\n", - " else:\n", - " print('Building footprints generated:', yes_no_text(False))\n", - "\n", - " print(\n", - " 'Example generation config file exists:',\n", - " yes_no_text(_file_exists(EXAMPLE_GENERATION_CONFIG_PATH)),\n", - " )\n", - " print(\n", - " 'Unlabeled examples generated:',\n", - " yes_no_text(_file_exists(UNLABELED_TFRECORD_PATTERN)),\n", - " )\n", - " print(\n", - " 'Zero-shot assessment generated:',\n", - " yes_no_text(_file_exists(ZERO_SHOT_SCORES)),\n", - " )\n", - " labeling_metadata_files = find_labeling_image_metadata_files(\n", - " LABELING_IMAGES_DIR\n", - " )\n", - " print(\n", - " 'Labeling images generated:', yes_no_text(bool(labeling_metadata_files))\n", - " )\n", - " for p in labeling_metadata_files:\n", - " print(f' {p}')\n", - " print('Label CSV uploaded:', yes_no_text(_file_exists(LABELS_CSV)))\n", - " print(\n", - " 'Labeled examples generated:',\n", - " yes_no_text(_file_exists(TRAIN_TFRECORD) and _file_exists(TEST_TFRECORD)),\n", - " )\n", - " trained_models = find_model_dirs(MODEL_DIR)\n", - " print('Fine-tuned model trained:', yes_no_text(bool(trained_models)))\n", - " for model_dir in trained_models:\n", - " print(f' {model_dir}')\n", - "\n", - " print(\n", - " 'Fine-tuned inference generated:',\n", - " yes_no_text(_file_exists(INFERENCE_CSV)),\n", - " )\n", - "\n", - "check_assessment_status()" - ] - }, - { - "cell_type": "markdown", - "id": "d1668f88", - "metadata": {}, - "source": [ - "# Example Generation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96cc153b", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Upload AOI file\n", - "def upload_aoi():\n", - " \"\"\"Shows button for user to upload AOI to the assessment directory.\"\"\"\n", - " if _file_exists(AOI_PATH):\n", - " print(f'AOI file {AOI_PATH} already exists.')\n", - " answer = input('Do you want to overwrite (y/n)? ')\n", - " if answer.lower() not in ['y', 'yes']:\n", - " print('AOI file not uploaded.')\n", - " return\n", - "\n", - " uploaded = files.upload()\n", - "\n", - " file_names = list(uploaded.keys())\n", - " if len(file_names) != 1:\n", - " print('You must choose exactly one GeoJSON file to upload.')\n", - " print('Upload NOT successful.')\n", - " return\n", - "\n", - " if not file_names[0].endswith('.geojson'):\n", - " print('AOI file must be in GeoJSON format and have extension \".geojson\".')\n", - " print('Upload NOT successful.')\n", - " return\n", - "\n", - " with _open_file(AOI_PATH, 'wb') as f:\n", - " f.write(uploaded[file_names[0]])\n", - "\n", - "upload_aoi()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1def2011", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Get building footprints\n", - "\n", - "# pylint:disable=line-too-long\n", - "BUILDINGS_METHOD = 'open_buildings' # @param [\"open_buildings\",\"open_street_map\",\"run_model\",\"file\"]\n", - "# pylint:enable=line-too-long\n", - "# @markdown This is only needed if BUILDINGS_METHOD is set to \"file\":\n", - "USER_BUILDINGS_FILE = '' # @param {type:\"string\"}\n", - "\n", - "\n", - "def download_open_buildings(aoi_path: str, output_dir: str) -> str:\n", - " path = os.path.join(output_dir, 'open_buildings.parquet')\n", - " with _open_file(aoi_path, 'r') as f:\n", - " gdf = gpd.read_file(f)\n", - " aoi = gdf.unary_union\n", - " skai_ee.get_open_buildings(\n", - " [aoi], OPEN_BUILDINGS_FEATURE_COLLECTION, 0.5, False, path)\n", - " return path\n", - "\n", - "\n", - "def download_open_street_map(aoi_path: str, output_dir: str) -> str:\n", - " path = os.path.join(output_dir, 'open_street_map_buildings.parquet')\n", - " with _open_file(aoi_path, 'r') as f:\n", - " gdf = gpd.read_file(f)\n", - " aoi = gdf.unary_union\n", - " open_street_map.get_building_centroids_in_regions(\n", - " [aoi], OSM_OVERPASS_URL, path\n", - " )\n", - " return path\n", - "\n", - "\n", - "def run_building_detection_model(\n", - " aoi_path: str,\n", - " output_dir: str):\n", - " \"\"\"Runs building detection model.\"\"\"\n", - " image_paths = ','.join(BEFORE_IMAGES)\n", - " child_dir = os.path.join(output_dir, 'buildings')\n", - " if any('EEDAI:' in image for image in BEFORE_IMAGES):\n", - " token = get_eeda_bearer_token(GCP_SERVICE_ACCOUNT)\n", - " eeda_bearer_env = f'export EEDA_BEARER=\"{token}\"'\n", - " else:\n", - " eeda_bearer_env = ''\n", - "\n", - " script = textwrap.dedent(f'''\n", - " export PYTHONPATH={SKAI_CODE_DIR}/src:$PYTHONPATH\n", - " export GOOGLE_CLOUD_PROJECT={GCP_PROJECT}\n", - " {eeda_bearer_env}\n", - " cd {SKAI_CODE_DIR}/src\n", - " python detect_buildings_main.py \\\n", - " --cloud_project='{GCP_PROJECT}' \\\n", - " --cloud_region='{GCP_LOCATION}' \\\n", - " --worker_service_account='{GCP_SERVICE_ACCOUNT}' \\\n", - " --use_dataflow \\\n", - " --output_dir='{output_dir}' \\\n", - " --image_paths='{image_paths}' \\\n", - " --aoi_path='{aoi_path}' \\\n", - " --model_path='{BUILDING_SEGMENTATION_MODEL_PATH}'\n", - " ''')\n", - " script_path = '/content/run_building_detection.sh'\n", - " with open(script_path, 'w') as f:\n", - " f.write(script)\n", - " !bash {script_path}\n", - "\n", - " buildings_file = os.path.join(child_dir, 'dedup_buildings.parquet')\n", - " return buildings_file\n", - "\n", - "\n", - "def _display_building_footprints(buildings_gdf: gpd.GeoDataFrame):\n", - " \"\"\"Visualizes building footprints in a folium map.\"\"\"\n", - " centroid = buildings_gdf.centroid.unary_union.centroid\n", - " folium_map = _make_map(centroid.x, centroid.y, 13)\n", - " folium.GeoJson(\n", - " buildings_gdf.to_json(),\n", - " name='buildings',\n", - " marker=folium.CircleMarker(\n", - " radius=3, weight=0, fill_color='#FF0000', fill_opacity=1\n", - " ),\n", - " ).add_to(folium_map)\n", - " display(folium_map)\n", - "\n", - "\n", - "def download_buildings(aoi_path: str, output_dir: str) -> None:\n", - " \"\"\"Downloads buildings to assessment directory.\"\"\"\n", - " if BUILDINGS_METHOD == 'open_buildings':\n", - " path = download_open_buildings(aoi_path, output_dir)\n", - " elif BUILDINGS_METHOD == 'open_street_map':\n", - " path = download_open_street_map(aoi_path, output_dir)\n", - " elif BUILDINGS_METHOD == 'run_model':\n", - " path = run_building_detection_model(aoi_path, output_dir)\n", - " elif BUILDINGS_METHOD == 'file':\n", - " path = USER_BUILDINGS_FILE\n", - " else:\n", - " raise ValueError(f'Unknown BUILDINGS_METHOD {BUILDINGS_METHOD}')\n", - "\n", - " with _open_file(BUILDINGS_FILE_LOG, 'w') as f:\n", - " f.write(f'{path}\\n')\n", - "\n", - " with _open_file(path, 'rb') as f:\n", - " if path.endswith('.csv'):\n", - " df = pd.read_csv(f)\n", - " df['geometry'] = df['wkt'].apply(shapely.wkt.loads)\n", - " gdf = gpd.GeoDataFrame(df.drop(columns=['wkt']), crs='EPSG:4326')\n", - " elif path.endswith('.parquet'):\n", - " gdf = gpd.read_parquet(f)\n", - " else:\n", - " gdf = gpd.read_file(f)\n", - " print(f'Found {len(gdf)} buildings.')\n", - " print(f'Saved buildings to {path}')\n", - " if len(gdf) < 500000:\n", - " _display_building_footprints(gdf)\n", - " else:\n", - " print('Too many buildings to visualize. Use QGIS instead.')\n", - "\n", - "with warnings.catch_warnings():\n", - " warnings.simplefilter('ignore')\n", - " download_buildings(AOI_PATH, OUTPUT_DIR)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "91539f74", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Write Example Generation Config File\n", - "def write_example_generation_config(path: str) -> None:\n", - " \"\"\"Writes example generation config file to assessment directory.\"\"\"\n", - " dataset_name = ASSESSMENT_NAME.lower().replace('_', '-')\n", - " with _open_file(BUILDINGS_FILE_LOG, 'r') as f:\n", - " buildings_file = f.read().strip()\n", - "\n", - " config_dict = {\n", - " 'dataset_name': dataset_name,\n", - " 'aoi_path': AOI_PATH,\n", - " 'output_dir': OUTPUT_DIR,\n", - " 'buildings_method': 'file',\n", - " 'buildings_file': buildings_file,\n", - " 'resolution': 0.5,\n", - " 'use_dataflow': True,\n", - " 'cloud_project': GCP_PROJECT,\n", - " 'cloud_region': GCP_LOCATION,\n", - " 'worker_service_account': GCP_SERVICE_ACCOUNT,\n", - " 'max_dataflow_workers': 100,\n", - " 'output_shards': 100,\n", - " 'output_metadata_file': True,\n", - " 'before_image_patterns': BEFORE_IMAGES,\n", - " 'after_image_patterns': AFTER_IMAGES,\n", - " }\n", - "\n", - " valid_config = True\n", - " for key, value in config_dict.items():\n", - " if not value:\n", - " if key == 'buildings_file' and config_dict['buildings_method'] != 'file':\n", - " continue\n", - " print(f'Field {key} cannot be empty')\n", - " valid_config = False\n", - " if not valid_config:\n", - " return\n", - "\n", - " config_string = json.dumps(config_dict, indent=2)\n", - " print(f'Example Generation configuration written to {path}:')\n", - " print()\n", - " print(config_string)\n", - " with tf.io.gfile.GFile(path, 'w') as f:\n", - " f.write(config_string)\n", - "\n", - "write_example_generation_config(EXAMPLE_GENERATION_CONFIG_PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "70523af4", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Run Example Generation Job\n", - "def run_example_generation(config_file_path: str):\n", - " \"\"\"Runs example generation pipeline.\"\"\"\n", - "\n", - " script = textwrap.dedent(f'''\n", - " cd {SKAI_CODE_DIR}/src\n", - " export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY}\n", - " python generate_examples_main.py \\\n", - " --configuration_path={config_file_path} \\\n", - " --output_metadata_file\n", - " ''')\n", - "\n", - " script_path = '/content/example_generation.sh'\n", - " with open(script_path, 'w') as f:\n", - " f.write(script)\n", - " !bash {script_path}\n", - "\n", - "run_example_generation(EXAMPLE_GENERATION_CONFIG_PATH)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "76c536fd", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Visualize Generated Examples\n", - "def visualize_generated_examples(pattern: str, num: int):\n", - " images = []\n", - " paths = tf.io.gfile.glob(pattern)\n", - " for record in tf.data.TFRecordDataset([paths[0]]).take(num):\n", - " example = tf.train.Example()\n", - " example.ParseFromString(record.numpy())\n", - " pre_image = plt.imread(io.BytesIO(\n", - " example.features.feature['pre_image_png_large'].bytes_list.value[0]))\n", - " post_image = plt.imread(io.BytesIO(\n", - " example.features.feature['post_image_png_large'].bytes_list.value[0]))\n", - " images.append((pre_image, post_image))\n", - " visualize_images(images)\n", - "\n", - "visualize_generated_examples(UNLABELED_TFRECORD_PATTERN, 3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ca5ec6e1", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Run Zero Shot Model\n", - "def run_zero_shot_model():\n", - " \"\"\"Runs zero-shot model inference.\"\"\"\n", - " script = textwrap.dedent(f'''\n", - " export PYTHONPATH={SKAI_CODE_DIR}/src:$PYTHONPATH\n", - " export GOOGLE_CLOUD_PROJECT={GCP_PROJECT}\n", - " export GOOGLE_CLOUD_BUCKET_NAME={GCP_BUCKET}\n", - " export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY}\n", - " cd {SKAI_CODE_DIR}/src\n", - "\n", - " xmanager launch skai/model/xm_vlm_zero_shot_vertex.py -- \\\n", - " --example_patterns={UNLABELED_TFRECORD_PATTERN} \\\n", - " --output_dir={ZERO_SHOT_DIR}\n", - " ''')\n", - "\n", - " print(\n", - " 'Starting zero shot model inference. Scores will be written to'\n", - " f' {ZERO_SHOT_SCORES}'\n", - " )\n", - " script_path = '/content/zero_shot_model.sh'\n", - " with open(script_path, 'w') as f:\n", - " f.write(script)\n", - " !bash {script_path}\n", - "\n", - "run_zero_shot_model()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e1835156", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title View Zero Shot Assessment\n", - "DAMAGE_SCORE_THRESHOLD = 0.5 # @param {type:\"number\"}\n", - "\n", - "make_download_button(\n", - " ZERO_SHOT_SCORES,\n", - " f'{ASSESSMENT_NAME}_zero_shot_assessment.csv',\n", - " 'Download CSV')\n", - "show_inference_stats(AOI_PATH, ZERO_SHOT_SCORES, DAMAGE_SCORE_THRESHOLD)\n", - "show_assessment_heatmap(\n", - " AOI_PATH, ZERO_SHOT_SCORES, DAMAGE_SCORE_THRESHOLD, True\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c349992c", - "metadata": {}, - "source": [ - "# Labeling" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "121e2ccb", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Create Labeling Images\n", - "MAX_LABELING_IMAGES = 1000 # @param {\"type\":\"integer\"}\n", - "\n", - "\n", - "def visualize_labeling_images(images_dir: str, num: int):\n", - " \"\"\"Displays a small sample of labeling images.\"\"\"\n", - " pre_image_paths = sorted(\n", - " tf.io.gfile.glob(os.path.join(images_dir, '*_pre.png'))\n", - " )\n", - " post_image_paths = sorted(\n", - " tf.io.gfile.glob(os.path.join(images_dir, '*_post.png'))\n", - " )\n", - " assert len(pre_image_paths) == len(post_image_paths), (\n", - " f'Number of pre images ({len(pre_image_paths)}) does not match number of'\n", - " f' post images ({len(post_image_paths)}).'\n", - " )\n", - " images = []\n", - " for pre_image_path, post_image_path in list(\n", - " zip(pre_image_paths, post_image_paths)\n", - " )[:num]:\n", - " with _open_file(pre_image_path, 'rb') as f:\n", - " pre_image = plt.imread(f)\n", - " with _open_file(post_image_path, 'rb') as f:\n", - " post_image = plt.imread(f)\n", - " images.append((pre_image, post_image))\n", - " visualize_images(images)\n", - "\n", - "\n", - "def create_labeling_images(\n", - " examples_pattern: str,\n", - " scores_file: str,\n", - " output_dir: str,\n", - " max_images: int,\n", - "):\n", - " \"\"\"Creates labeling images.\"\"\"\n", - " if not tf.io.gfile.glob(examples_pattern):\n", - " print(\n", - " f'No files match \"{examples_pattern}\". Please run example generation'\n", - " ' first.'\n", - " )\n", - " return\n", - "\n", - " existing_metadata_files = find_labeling_image_metadata_files(output_dir)\n", - " if existing_metadata_files:\n", - " print(\n", - " 'The following labeling image metadata files have already been'\n", - " ' generated:'\n", - " )\n", - " print('\\n'.join(f' {p}' for p in existing_metadata_files))\n", - " response = input(\n", - " 'Do you want to generate a new set of labeling images (y/n)? '\n", - " )\n", - " if response.lower() not in ['y', 'yes']:\n", - " return\n", - "\n", - " timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n", - " images_dir = os.path.join(output_dir, timestamp)\n", - " metadata_csv = os.path.join(images_dir, 'image_metadata.csv')\n", - "\n", - " num_images = labeling.create_labeling_images(\n", - " examples_pattern,\n", - " max_images,\n", - " set(),\n", - " set(),\n", - " images_dir,\n", - " True,\n", - " None,\n", - " 4,\n", - " 70.0,\n", - " {\n", - " (0, 0.25): 0.25,\n", - " (0.25, 0.5): 0.25,\n", - " (0.5, 0.75): 0.25,\n", - " (0.75, 1.0): 0.25,\n", - " },\n", - " scores_path=scores_file,\n", - " filter_by_column='is_cloudy',\n", - " )\n", - " print('Number of labeling images:', num_images)\n", - " print(\n", - " 'Please create a new project in the SKAI labeling tool with the following'\n", - " ' metadata CSV:'\n", - " )\n", - " print(metadata_csv)\n", - " visualize_labeling_images(images_dir, 3)\n", - "\n", - "create_labeling_images(\n", - " UNLABELED_TFRECORD_PATTERN,\n", - " ZERO_SHOT_SCORES,\n", - " LABELING_IMAGES_DIR,\n", - " MAX_LABELING_IMAGES,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0614bff4", - "metadata": {}, - "source": [ - "When the labeling project is complete, download the CSV from the labeling tool\n", - "and upload it to your assessment directory using the following cell.\n", - "\n", - "You may upload multiple CSV files at once, in case you wish to combine labels\n", - "from multiple rounds of labeling." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "88cbf415", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Upload Label CSV\n", - "def upload_label_csvs(output_path: str):\n", - " \"\"\"Lets the user upload the labeling CSV file from their computer.\"\"\"\n", - " uploaded = files.upload()\n", - " dfs = []\n", - " for filename in uploaded.keys():\n", - " f = io.BytesIO(uploaded[filename])\n", - " df = pd.read_csv(f)\n", - " if 'example_id' not in df.columns:\n", - " print('\"example_id\" column not found in {filename}')\n", - " return\n", - " if 'string_label' not in df.columns:\n", - " print('\"string_label\" column not found in {filename}')\n", - " return\n", - " dfs.append(df)\n", - " print(f'Read {len(df)} rows from {filename}')\n", - "\n", - " combined = pd.concat(dfs, ignore_index=True)\n", - "\n", - " with tf.io.gfile.GFile(output_path, 'wb') as f:\n", - " f.closed = False\n", - " combined.to_csv(f, index=False)\n", - "\n", - "upload_label_csvs(LABELS_CSV)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60f9c004", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Create Labeled Examples\n", - "TEST_FRACTION = 0.2 # @param {\"type\":\"number\"}\n", - "MINOR_IS_0 = False # @param {\"type\":\"boolean\"}\n", - "\n", - "\n", - "def create_labeled_examples(\n", - " examples_pattern: str,\n", - " labels_csv: str,\n", - " test_fraction: float,\n", - " train_path: str,\n", - " test_path: str,\n", - " minor_is_0: bool):\n", - " \"\"\"Creates labeled train and test TFRecords files.\"\"\"\n", - "\n", - " minor_damage_float_label = (0 if minor_is_0 else 1)\n", - " label_mapping = [\n", - " 'bad_example=0',\n", - " 'no_damage=0',\n", - " f'minor_damage={minor_damage_float_label}',\n", - " 'major_damage=1',\n", - " 'destroyed=1',\n", - " ]\n", - "\n", - " labeling.create_labeled_examples(\n", - " label_file_paths=[labels_csv],\n", - " string_to_numeric_labels=label_mapping,\n", - " example_patterns=[examples_pattern],\n", - " test_fraction=test_fraction,\n", - " train_output_path=train_path,\n", - " test_output_path=test_path,\n", - " connecting_distance_meters=70.0,\n", - " use_multiprocessing=False,\n", - " multiprocessing_context=None,\n", - " max_processes=1,\n", - " )\n", - " print(f'Train TFRecord: {train_path}')\n", - " print(f'Test TFRecord: {test_path}')\n", - "\n", - "create_labeled_examples(\n", - " LABELING_EXAMPLES_TFRECORD_PATTERN,\n", - " LABELS_CSV,\n", - " TEST_FRACTION,\n", - " TRAIN_TFRECORD,\n", - " TEST_TFRECORD,\n", - " MINOR_IS_0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7fb9619d", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Show Label Stats\n", - "def _load_examples_into_df(\n", - " train_tfrecords: str,\n", - " test_tfrecords: str,\n", - ") -> pd.DataFrame:\n", - " \"\"\"Loads examples from TFRecords into a DataFrame.\n", - " \"\"\"\n", - " feature_config = {\n", - " 'example_id': tf.io.FixedLenFeature([], tf.string),\n", - " 'coordinates': tf.io.FixedLenFeature([2], tf.float32),\n", - " 'string_label': tf.io.FixedLenFeature([], tf.string, 'unlabeled'),\n", - " 'label': tf.io.FixedLenFeature([], tf.float32),\n", - " }\n", - "\n", - " def _parse_examples(record_bytes):\n", - " return tf.io.parse_single_example(record_bytes, feature_config)\n", - "\n", - " columns = collections.defaultdict(list)\n", - " longitudes = []\n", - " latitudes = []\n", - " for path in [train_tfrecords, test_tfrecords]:\n", - " for features in tqdm.notebook.tqdm(\n", - " tf.data.TFRecordDataset([path])\n", - " .map(_parse_examples, num_parallel_calls=tf.data.AUTOTUNE)\n", - " .prefetch(tf.data.AUTOTUNE)\n", - " .as_numpy_iterator(),\n", - " desc=path,\n", - " ):\n", - " longitudes.append(features['coordinates'][0])\n", - " latitudes.append(features['coordinates'][1])\n", - " columns['example_id'].append(features['example_id'].decode())\n", - " columns['string_label'].append(features['string_label'].decode())\n", - " columns['label'].append(features['label'])\n", - " columns['source_path'].append(path)\n", - "\n", - " return pd.DataFrame(columns)\n", - "\n", - "def _format_counts_table(df: pd.DataFrame):\n", - " for column in df.columns:\n", - " if column != 'All':\n", - " df[column] = [\n", - " f'{x} ({x/t * 100:0.2f}%)' for x, t in zip(df[column], df['All'])\n", - " ]\n", - "\n", - "def show_label_stats(train_tfrecord: str, test_tfrecord: str):\n", - " \"\"\"Displays tables showing label count stats.\"\"\"\n", - " df = _load_examples_into_df(train_tfrecord, test_tfrecord)\n", - " counts = df.pivot_table(\n", - " index='source_path',\n", - " columns='string_label',\n", - " aggfunc='count',\n", - " values='example_id',\n", - " margins=True,\n", - " fill_value=0)\n", - " _format_counts_table(counts)\n", - "\n", - " print('String Label Counts')\n", - " display(data_table.DataTable(counts))\n", - "\n", - " float_counts = df.pivot_table(\n", - " index='source_path',\n", - " columns='label',\n", - " aggfunc='count',\n", - " values='example_id',\n", - " margins=True,\n", - " fill_value=0.0)\n", - " _format_counts_table(float_counts)\n", - " print('Float Label Counts')\n", - " display(data_table.DataTable(float_counts))\n", - "\n", - "\n", - "show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD)" - ] - }, - { - "cell_type": "markdown", - "id": "a0b9ba79", - "metadata": {}, - "source": [ - "# Fine Tuning" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43b61092", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Train model\n", - "\n", - "NUM_EPOCHS = 20 # @param {type:\"integer\"}\n", - "\n", - "\n", - "def run_training(\n", - " experiment_name: str,\n", - " train_path: str,\n", - " test_path: str,\n", - " output_dir: str,\n", - " num_epochs: int):\n", - " \"\"\"Runs training job.\"\"\"\n", - " if not tf.io.gfile.exists(train_path):\n", - " raise ValueError(\n", - " f'Train TFRecord {train_path} does not exist. Did you run the \"Create'\n", - " ' Labeled Examples\" cell?'\n", - " )\n", - " if not tf.io.gfile.exists(test_path):\n", - " raise ValueError(\n", - " f'Test TFRecord {test_path} does not exist. Did you run the \"Create'\n", - " ' Labeled Examples\" cell?'\n", - " )\n", - "\n", - " print(f'Train data: {train_path}')\n", - " print(f'Test data: {test_path}')\n", - " print(f'Model dir: {output_dir}')\n", - " job_args = {\n", - " 'config': 'src/skai/model/configs/skai_two_tower_config.py',\n", - " 'config.data.tfds_dataset_name': 'skai_dataset',\n", - " 'config.data.adhoc_config_name': 'adhoc_dataset',\n", - " 'config.data.labeled_train_pattern': train_path,\n", - " 'config.data.validation_pattern': test_path,\n", - " 'config.output_dir': output_dir,\n", - " 'config.training.num_epochs': num_epochs,\n", - " 'accelerator': 'V100',\n", - " 'experiment_name': experiment_name,\n", - " }\n", - " job_arg_str = ' '.join(f'--{f}={v}' for f, v in job_args.items())\n", - " sh = textwrap.dedent(f'''\n", - " export GOOGLE_CLOUD_PROJECT={GCP_PROJECT}\n", - " export GOOGLE_CLOUD_BUCKET_NAME={GCP_BUCKET}\n", - " export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY}\n", - " export LOCATION={GCP_LOCATION}\n", - "\n", - " cd {SKAI_CODE_DIR}\n", - "\n", - " xmanager launch src/skai/model/xm_launch_single_model_vertex.py -- \\\n", - " --xm_wrap_late_bindings \\\n", - " --xm_upgrade_db=True \\\n", - " --cloud_location=$LOCATION \\\n", - " --accelerator_count=1 {job_arg_str}''')\n", - "\n", - " with open('script.sh', 'w') as file:\n", - " file.write(sh)\n", - "\n", - " !bash script.sh\n", - "\n", - "run_training(\n", - " ASSESSMENT_NAME,\n", - " TRAIN_TFRECORD,\n", - " TEST_TFRECORD,\n", - " MODEL_DIR,\n", - " NUM_EPOCHS)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12febfce", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title View Tensorboard\n", - "def start_tensorboard(model_root: str):\n", - " \"\"\"Shows Tensorboard visualization.\"\"\"\n", - " tensorboard_dirs = tf.io.gfile.glob(\n", - " os.path.join(model_root, '*/*/tensorboard')\n", - " )\n", - " if not tensorboard_dirs:\n", - " print(\n", - " 'No Tensorboard directories found. Either you have not trained a model'\n", - " ' yet or a running job has not written any tensorboard log events yet.'\n", - " )\n", - " return\n", - "\n", - " dir_selection_widget = widgets.Dropdown(\n", - " options=tensorboard_dirs,\n", - " description='Choose a tensorboard dir:',\n", - " layout={'width': 'initial'},\n", - " )\n", - " dir_selection_widget.style.description_width = 'initial'\n", - "\n", - " def run_tensorboard(_):\n", - " # pylint:disable=unused-variable\n", - " tensorboard_dir = dir_selection_widget.value\n", - " %tensorboard --load_fast=false --logdir $tensorboard_dir\n", - " # pylint:enable=unused-variable\n", - "\n", - " start_button = widgets.Button(\n", - " description='Start',\n", - " )\n", - " start_button.on_click(run_tensorboard)\n", - "\n", - " display(dir_selection_widget)\n", - " display(start_button)\n", - "\n", - "start_tensorboard(MODEL_DIR)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "477bed14", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Run inference\n", - "def get_best_checkpoint(model_dir: str):\n", - " checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*'))\n", - " best_checkpoint = None\n", - " best_aucpr = 0\n", - " for checkpoint in checkpoint_dirs:\n", - " aucpr = float(checkpoint.split('-')[-1])\n", - " if aucpr > best_aucpr:\n", - " best_checkpoint = checkpoint\n", - " best_aucpr = aucpr\n", - " return best_checkpoint\n", - "\n", - "\n", - "def run_inference(\n", - " examples_pattern: str,\n", - " model_dir: str,\n", - " output_dir: str,\n", - " output_path: str,\n", - " cloud_project: str,\n", - " cloud_region: str,\n", - " service_account: str) -> None:\n", - " \"\"\"Starts model inference job.\"\"\"\n", - " temp_dir = os.path.join(output_dir, 'inference_temp')\n", - " print(\n", - " f'Running inference with model checkpoint \"{model_dir}\" on examples'\n", - " f' matching \"{examples_pattern}\"'\n", - " )\n", - " print(f'Output will be written to {output_path}')\n", - "\n", - " # accelerator_flags = ' '.join([\n", - " # '--worker_machine_type=n1-highmem-8',\n", - " # '--accelerator=nvidia-tesla-t4',\n", - " # '--accelerator_count=1'])\n", - "\n", - " # Currently, Colab only supports Python 3.10. However, the docker images we\n", - " # need for GPU acceleration are based on Tensorflow 2.14.0 images, which are\n", - " # based on Python 3.11. If we try to launch an inference job with GPU\n", - " # acceleration, Dataflow will complain about a Python version mismatch.\n", - " # Therefore, we can only use CPU inference until Colab upgrades to Python 3.11\n", - " # (which should be sometime within 2024).\n", - " accelerator_flags = ''\n", - "\n", - " script = textwrap.dedent(f'''\n", - " cd {SKAI_CODE_DIR}/src\n", - " export GOOGLE_CLOUD_PROJECT={cloud_project}\n", - " export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY}\n", - " python skai/model/inference.py \\\n", - " --examples_pattern='{examples_pattern}' \\\n", - " --image_model_dir='{model_dir}' \\\n", - " --output_path='{output_path}' \\\n", - " --use_dataflow \\\n", - " --cloud_project='{cloud_project}' \\\n", - " --cloud_region='{cloud_region}' \\\n", - " --dataflow_temp_dir='{temp_dir}' \\\n", - " --worker_service_account='{service_account}' \\\n", - " --threshold=0.5 \\\n", - " --high_precision_threshold=0.75 \\\n", - " --high_recall_threshold=0.4 \\\n", - " --max_dataflow_workers=4 {accelerator_flags}\n", - " ''')\n", - "\n", - " script_path = '/content/inference_script.sh'\n", - " with open(script_path, 'w') as f:\n", - " f.write(script)\n", - " !bash {script_path}\n", - "\n", - "\n", - "def do_inference(model_root: str):\n", - " \"\"\"Runs model inference.\"\"\"\n", - " model_dirs = find_model_dirs(model_root)\n", - " if not model_dirs:\n", - " print(\n", - " f'No models found in directory {model_root}. Please train a model'\n", - " ' first.'\n", - " )\n", - " return\n", - "\n", - " model_selection_widget = widgets.Dropdown(\n", - " options=model_dirs,\n", - " description='Choose a model:',\n", - " layout={'width': 'initial'},\n", - " )\n", - " model_selection_widget.style.description_width = 'initial'\n", - "\n", - " def start_clicked(_):\n", - " model_dir = os.path.join(model_selection_widget.value, 'model')\n", - " checkpoint = get_best_checkpoint(model_dir)\n", - " if not checkpoint:\n", - " print('Model directory does not contain a valid checkpoint directory.')\n", - " return\n", - " run_inference(\n", - " UNLABELED_TFRECORD_PATTERN,\n", - " checkpoint,\n", - " OUTPUT_DIR,\n", - " INFERENCE_CSV,\n", - " GCP_PROJECT,\n", - " GCP_LOCATION,\n", - " GCP_SERVICE_ACCOUNT,\n", - " )\n", - "\n", - " start_button = widgets.Button(\n", - " description='Start',\n", - " )\n", - " start_button.on_click(start_clicked)\n", - "\n", - " display(model_selection_widget)\n", - " display(start_button)\n", - "\n", - "do_inference(MODEL_DIR)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e38b956e", - "metadata": { - "cellView": "form" - }, - "outputs": [], - "source": [ - "# @title Get assessment stats\n", - "DAMAGE_SCORE_THRESHOLD = 0.5 # @param {type:\"number\"}\n", - "\n", - "make_download_button(\n", - " INFERENCE_CSV,\n", - " f'{ASSESSMENT_NAME}_assessment.csv',\n", - " 'Download CSV')\n", - "show_inference_stats(AOI_PATH, INFERENCE_CSV, DAMAGE_SCORE_THRESHOLD)\n", - "show_assessment_heatmap(AOI_PATH, INFERENCE_CSV, DAMAGE_SCORE_THRESHOLD, False)" - ] - } - ], - "metadata": { - "jupytext": { - "cell_metadata_filter": "cellView,-all", - "main_language": "python" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/colab/skai_assessment_notebook_custom_vm.py b/src/colab/skai_assessment_notebook_custom_vm.py deleted file mode 100644 index e02763ab..00000000 --- a/src/colab/skai_assessment_notebook_custom_vm.py +++ /dev/null @@ -1,1171 +0,0 @@ -# --- -# jupyter: -# jupytext: -# cell_metadata_filter: -all,cellView -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.16.4 -# kernelspec: -# display_name: Python 3 -# name: python3 -# --- - -# %% cellView="form" -# @title Configure Assessment Parameters - -# pylint:disable=missing-module-docstring -# pylint:disable=g-bad-import-order -# pylint:disable=g-wrong-blank-lines -# pylint:disable=g-import-not-at-top - -# @markdown You must re-run this cell every time you make a change. -import os -import textwrap -import ee - -GCP_PROJECT = '' # @param {type:"string"} -GCP_LOCATION = '' # @param {type:"string"} -GCP_BUCKET = '' # @param {type:"string"} -GCP_SERVICE_ACCOUNT = '' # @param {type:"string"} -SERVICE_ACCOUNT_KEY = '' # @param {type:"string"} -# @markdown This is only needed if BUILDINGS_METHOD is set to "run_model": -BUILDING_SEGMENTATION_MODEL_PATH = '' # @param {type:"string"} - -# @markdown --- -ASSESSMENT_NAME = '' # @param {type:"string"} -EVENT_DATE = '' # @param {type:"date"} -OUTPUT_DIR = '' # @param {type:"string"} - -# @markdown --- -BEFORE_IMAGE_0 = '' # @param {type:"string"} -BEFORE_IMAGE_1 = '' # @param {type:"string"} -BEFORE_IMAGE_2 = '' # @param {type:"string"} -BEFORE_IMAGE_3 = '' # @param {type:"string"} -BEFORE_IMAGE_4 = '' # @param {type:"string"} -BEFORE_IMAGE_5 = '' # @param {type:"string"} -BEFORE_IMAGE_6 = '' # @param {type:"string"} -BEFORE_IMAGE_7 = '' # @param {type:"string"} -BEFORE_IMAGE_8 = '' # @param {type:"string"} -BEFORE_IMAGE_9 = '' # @param {type:"string"} -# @markdown --- -AFTER_IMAGE_0 = '' # @param {type:"string"} -AFTER_IMAGE_1 = '' # @param {type:"string"} -AFTER_IMAGE_2 = '' # @param {type:"string"} -AFTER_IMAGE_3 = '' # @param {type:"string"} -AFTER_IMAGE_4 = '' # @param {type:"string"} -AFTER_IMAGE_5 = '' # @param {type:"string"} -AFTER_IMAGE_6 = '' # @param {type:"string"} -AFTER_IMAGE_7 = '' # @param {type:"string"} -AFTER_IMAGE_8 = '' # @param {type:"string"} -AFTER_IMAGE_9 = '' # @param {type:"string"} - -# Constants -SKAI_REPO = 'https://github.com/google-research/skai.git' -OPEN_BUILDINGS_FEATURE_COLLECTION = 'GOOGLE/Research/open-buildings/v3/polygons' -OSM_OVERPASS_URL = 'https://lz4.overpass-api.de/api/interpreter' - -# Derived variables -SKAI_CODE_DIR = '/content/skai_src' -AOI_PATH = os.path.join(OUTPUT_DIR, 'aoi.geojson') -BUILDINGS_FILE_LOG = os.path.join(OUTPUT_DIR, 'buildings_file_log.txt') -EXAMPLE_GENERATION_CONFIG_PATH = os.path.join( - OUTPUT_DIR, 'example_generation_config.json' -) -UNLABELED_TFRECORD_PATTERN = os.path.join( - OUTPUT_DIR, 'examples', 'unlabeled-large', 'unlabeled-*-of-*.tfrecord' -) -ZERO_SHOT_DIR = os.path.join(OUTPUT_DIR, 'zero_shot_model') -ZERO_SHOT_SCORES = os.path.join(ZERO_SHOT_DIR, 'dataset_0_output.csv') -LABELING_IMAGES_DIR = os.path.join(OUTPUT_DIR, 'labeling_images') -LABELING_EXAMPLES_TFRECORD_PATTERN = os.path.join( - LABELING_IMAGES_DIR, '*', 'labeling_examples.tfrecord' -) -LABELS_CSV = os.path.join(OUTPUT_DIR, 'labels.csv') -TRAIN_TFRECORD = os.path.join(OUTPUT_DIR, 'labeled_examples_train.tfrecord') -TEST_TFRECORD = os.path.join(OUTPUT_DIR, 'labeled_examples_test.tfrecord') -MODEL_DIR = os.path.join(OUTPUT_DIR, 'models') -INFERENCE_CSV = os.path.join(OUTPUT_DIR, 'inference_scores.csv') - - -def process_image_entries(entries: list[str]) -> list[str]: - image_ids = [] - for entry in entries: - entry = entry.strip() - if entry: - image_ids.append(entry) - return image_ids - - -BEFORE_IMAGES = process_image_entries([ - BEFORE_IMAGE_0, - BEFORE_IMAGE_1, - BEFORE_IMAGE_2, - BEFORE_IMAGE_3, - BEFORE_IMAGE_4, - BEFORE_IMAGE_5, - BEFORE_IMAGE_6, - BEFORE_IMAGE_7, - BEFORE_IMAGE_8, - BEFORE_IMAGE_9, -]) - -AFTER_IMAGES = process_image_entries([ - AFTER_IMAGE_0, - AFTER_IMAGE_1, - AFTER_IMAGE_2, - AFTER_IMAGE_3, - AFTER_IMAGE_4, - AFTER_IMAGE_5, - AFTER_IMAGE_6, - AFTER_IMAGE_7, - AFTER_IMAGE_8, - AFTER_IMAGE_9, -]) - -if os.path.exists(SERVICE_ACCOUNT_KEY): - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = SERVICE_ACCOUNT_KEY -else: - print(f'Service account key not found: "{SERVICE_ACCOUNT_KEY}"') - -# %% [markdown] -# #Initialization - -# %% cellView="form" -# @title Imports and Function Defs -# %load_ext tensorboard - -import collections -import io -import json -import math -import shutil -import subprocess -import time -import warnings - -import folium -import folium.plugins -import geopandas as gpd -from google.colab import data_table -from google.colab import files -from IPython.display import display -import ipywidgets as widgets -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import shapely.wkt -from skai import earth_engine as skai_ee -from skai import labeling -from skai import open_street_map -import tensorflow as tf -import tqdm.notebook - -data_table.enable_dataframe_formatter() - - -def convert_wgs_to_utm(lon: float, lat: float): - """Based on lat and lng, return best utm epsg-code.""" - utm_band = str((math.floor((lon + 180) / 6) % 60) + 1) - if len(utm_band) == 1: - utm_band = '0' + utm_band - if lat >= 0: - epsg_code = '326' + utm_band - else: - epsg_code = '327' + utm_band - return f'EPSG:{epsg_code}' - - -def get_aoi_area_km2(aoi_path: str): - with tf.io.gfile.GFile(aoi_path) as f: - aoi = gpd.read_file(f) - - centroid = aoi.geometry.unary_union.centroid - utm_crs = convert_wgs_to_utm(centroid.x, centroid.y) - utm_aoi = aoi.to_crs(utm_crs) - area_meters_squared = utm_aoi.geometry.unary_union.area - area_km_squared = area_meters_squared / 1000000 - return area_km_squared - - -def show_inference_stats( - aoi_path: str, - inference_csv_path: str, - threshold: float): - """Prints out statistics on inference result.""" - with tf.io.gfile.GFile(inference_csv_path) as f: - df = pd.read_csv(f) - building_count = len(df) - if 'damage_score' in df.columns: - scores = df['damage_score'] - elif 'score' in df.columns: - scores = df['score'] - else: - raise ValueError(f'{inference_csv_path} does not contain a score column.') - - damaged = df.loc[scores > threshold] - damaged_count = len(damaged) - damaged_pct = 100 * damaged_count / building_count - print('Area KM^2:', get_aoi_area_km2(aoi_path)) - print('Buildings assessed:', building_count) - print('Damaged buildings:', damaged_count) - print(f'Percentage damaged: {damaged_pct:0.3g}%') - - -def _open_file(path: str, mode: str): - f = tf.io.gfile.GFile(path, mode) - f.closed = False - return f - - -def _file_exists(path: str) -> bool: - return bool(tf.io.gfile.glob(path)) - - -def _read_text_file(path: str) -> str: - with tf.io.gfile.GFile(path, 'r') as f: - return f.read() - - -def _make_map(longitude: float, latitude: float, zoom: float): - """Creates a Folium map with common base layers. - - Args: - longitude: Longitude of initial view. - latitude: Latitude of initial view. - zoom: Zoom level of initial view. - - Returns: - Folium map. - """ - base_maps = [ - folium.TileLayer( - tiles='https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}', - attr='Google', - name='Google Maps', - overlay=False, - control=True, - ), - ] - - m = folium.Map( - location=(latitude, longitude), - max_zoom=24, - zoom_start=zoom, - tiles=None) - for base_map in base_maps: - base_map.add_to(m) - return m - - -def show_assessment_heatmap( - aoi_path: str, - scores_path: str, - threshold: float, - is_zero_shot: bool): - """Creates a Folium heatmap from inference scores.""" - with _open_file(scores_path, 'rb') as f: - df = pd.read_csv(f) - if is_zero_shot: - damaged = df.loc[~df['is_cloudy'] & (df['damage_score'] >= threshold)] - else: - damaged = df.loc[df['score'] >= threshold] - points = zip(damaged['latitude'].values, damaged['longitude'].values) - centroid_x = np.mean(damaged['longitude'].values) - centroid_y = np.mean(damaged['latitude'].values) - folium_map = _make_map(centroid_x, centroid_y, 12) - with _open_file(aoi_path, 'rb') as f: - aoi_gdf = gpd.read_file(f) - folium.GeoJson( - aoi_gdf.to_json(), - name='AOI', - style_function=lambda _: {'fillOpacity': 0}, - ).add_to(folium_map) - heatmap = folium.plugins.HeatMap(points) - heatmap.add_to(folium_map) - display(folium_map) - - -def make_download_button(path: str, file_name: str, caption: str): - """Displays a button for downloading a file in the colab kernel.""" - def download(_): - temp_path = f'/tmp/{file_name}' - with _open_file(path, 'rb') as src: - with open(temp_path, 'wb') as dst: - shutil.copyfileobj(src, dst) - files.download(temp_path) - - button = widgets.Button( - description=caption, - ) - button.on_click(download) - display(button) - - -def find_model_dirs(model_root: str): - # Find all checkpoints dirs first. We only want model dirs that have at least - # one checkpoint. - checkpoint_dirs = tf.io.gfile.glob( - os.path.join(model_root, '*/*/model/epoch-*-aucpr-*')) - model_dirs = set(os.path.dirname(os.path.dirname(p)) for p in checkpoint_dirs) - return model_dirs - - -def find_labeling_image_metadata_files(labeling_images_dir: str): - return tf.io.gfile.glob(os.path.join( - labeling_images_dir, '*', 'image_metadata.csv')) - - -def yes_no_text(value: bool) -> str: - return '\x1b[32mYES\x1b[0m' if value else '\x1b[31mNO\x1b[0m' - - -def visualize_images(images: list[tuple[np.ndarray, np.ndarray]]): - """Displays before and after images side-by-side.""" - num_rows = len(images) - size_factor = 3 - fig_size = (2 * size_factor, num_rows * size_factor) - fig, axes = plt.subplots(num_rows, 2, figsize=fig_size) - for row, (pre_image, post_image) in enumerate(images): - ax1 = axes[row, 0] - ax2 = axes[row, 1] - ax1.axis('off') - ax2.axis('off') - ax1.imshow(pre_image) - ax2.imshow(post_image) - plt.show(fig) - - -def get_eeda_bearer_token(service_account: str) -> str: - return subprocess.check_output( - 'gcloud auth print-access-token' - f' --impersonate-service-account="{service_account}"', - shell=True, - ).decode() - - -# %% cellView="form" -# @title Authenticate with Earth Engine - -def auth(): - credentials = ee.ServiceAccountCredentials( - GCP_SERVICE_ACCOUNT, SERVICE_ACCOUNT_KEY) - ee.Initialize(credentials) - -auth() - - -# %% [markdown] -# # Check Assessment Status -# -# Run the following cell to check which steps of the assessment have already -# been completed. - -# %% cellView="form" -# @title Check assessment status -def check_assessment_status(): - """Shows which steps of the assessment have been completed.""" - print('AOI uploaded:', yes_no_text(_file_exists(AOI_PATH))) - - if _file_exists(BUILDINGS_FILE_LOG): - buildings_file = _read_text_file(BUILDINGS_FILE_LOG).strip() - print('Building footprints generated:', yes_no_text(True)) - print(f' Building footprints file: {buildings_file}') - else: - print('Building footprints generated:', yes_no_text(False)) - - print( - 'Example generation config file exists:', - yes_no_text(_file_exists(EXAMPLE_GENERATION_CONFIG_PATH)), - ) - print( - 'Unlabeled examples generated:', - yes_no_text(_file_exists(UNLABELED_TFRECORD_PATTERN)), - ) - print( - 'Zero-shot assessment generated:', - yes_no_text(_file_exists(ZERO_SHOT_SCORES)), - ) - labeling_metadata_files = find_labeling_image_metadata_files( - LABELING_IMAGES_DIR - ) - print( - 'Labeling images generated:', yes_no_text(bool(labeling_metadata_files)) - ) - for p in labeling_metadata_files: - print(f' {p}') - print('Label CSV uploaded:', yes_no_text(_file_exists(LABELS_CSV))) - print( - 'Labeled examples generated:', - yes_no_text(_file_exists(TRAIN_TFRECORD) and _file_exists(TEST_TFRECORD)), - ) - trained_models = find_model_dirs(MODEL_DIR) - print('Fine-tuned model trained:', yes_no_text(bool(trained_models))) - for model_dir in trained_models: - print(f' {model_dir}') - - print( - 'Fine-tuned inference generated:', - yes_no_text(_file_exists(INFERENCE_CSV)), - ) - -check_assessment_status() - - -# %% [markdown] -# # Example Generation - -# %% cellView="form" -# @title Upload AOI file -def upload_aoi(): - """Shows button for user to upload AOI to the assessment directory.""" - if _file_exists(AOI_PATH): - print(f'AOI file {AOI_PATH} already exists.') - answer = input('Do you want to overwrite (y/n)? ') - if answer.lower() not in ['y', 'yes']: - print('AOI file not uploaded.') - return - - uploaded = files.upload() - - file_names = list(uploaded.keys()) - if len(file_names) != 1: - print('You must choose exactly one GeoJSON file to upload.') - print('Upload NOT successful.') - return - - if not file_names[0].endswith('.geojson'): - print('AOI file must be in GeoJSON format and have extension ".geojson".') - print('Upload NOT successful.') - return - - with _open_file(AOI_PATH, 'wb') as f: - f.write(uploaded[file_names[0]]) - -upload_aoi() - -# %% cellView="form" -# @title Get building footprints - -# pylint:disable=line-too-long -BUILDINGS_METHOD = 'open_buildings' # @param ["open_buildings","open_street_map","run_model","file"] -# pylint:enable=line-too-long -# @markdown This is only needed if BUILDINGS_METHOD is set to "file": -USER_BUILDINGS_FILE = '' # @param {type:"string"} - - -def download_open_buildings(aoi_path: str, output_dir: str) -> str: - path = os.path.join(output_dir, 'open_buildings.parquet') - with _open_file(aoi_path, 'r') as f: - gdf = gpd.read_file(f) - aoi = gdf.unary_union - skai_ee.get_open_buildings( - [aoi], OPEN_BUILDINGS_FEATURE_COLLECTION, 0.5, False, path) - return path - - -def download_open_street_map(aoi_path: str, output_dir: str) -> str: - path = os.path.join(output_dir, 'open_street_map_buildings.parquet') - with _open_file(aoi_path, 'r') as f: - gdf = gpd.read_file(f) - aoi = gdf.unary_union - open_street_map.get_building_centroids_in_regions( - [aoi], OSM_OVERPASS_URL, path - ) - return path - - -def run_building_detection_model( - aoi_path: str, - output_dir: str): - """Runs building detection model.""" - image_paths = ','.join(BEFORE_IMAGES) - child_dir = os.path.join(output_dir, 'buildings') - if any('EEDAI:' in image for image in BEFORE_IMAGES): - token = get_eeda_bearer_token(GCP_SERVICE_ACCOUNT) - eeda_bearer_env = f'export EEDA_BEARER="{token}"' - else: - eeda_bearer_env = '' - - script = textwrap.dedent(f''' - export PYTHONPATH={SKAI_CODE_DIR}/src:$PYTHONPATH - export GOOGLE_CLOUD_PROJECT={GCP_PROJECT} - {eeda_bearer_env} - cd {SKAI_CODE_DIR}/src - python detect_buildings_main.py \ - --cloud_project='{GCP_PROJECT}' \ - --cloud_region='{GCP_LOCATION}' \ - --worker_service_account='{GCP_SERVICE_ACCOUNT}' \ - --use_dataflow \ - --output_dir='{output_dir}' \ - --image_paths='{image_paths}' \ - --aoi_path='{aoi_path}' \ - --model_path='{BUILDING_SEGMENTATION_MODEL_PATH}' - ''') - script_path = '/content/run_building_detection.sh' - with open(script_path, 'w') as f: - f.write(script) - # !bash {script_path} - - buildings_file = os.path.join(child_dir, 'dedup_buildings.parquet') - return buildings_file - - -def _display_building_footprints(buildings_gdf: gpd.GeoDataFrame): - """Visualizes building footprints in a folium map.""" - centroid = buildings_gdf.centroid.unary_union.centroid - folium_map = _make_map(centroid.x, centroid.y, 13) - folium.GeoJson( - buildings_gdf.to_json(), - name='buildings', - marker=folium.CircleMarker( - radius=3, weight=0, fill_color='#FF0000', fill_opacity=1 - ), - ).add_to(folium_map) - display(folium_map) - - -def download_buildings(aoi_path: str, output_dir: str) -> None: - """Downloads buildings to assessment directory.""" - if BUILDINGS_METHOD == 'open_buildings': - path = download_open_buildings(aoi_path, output_dir) - elif BUILDINGS_METHOD == 'open_street_map': - path = download_open_street_map(aoi_path, output_dir) - elif BUILDINGS_METHOD == 'run_model': - path = run_building_detection_model(aoi_path, output_dir) - elif BUILDINGS_METHOD == 'file': - path = USER_BUILDINGS_FILE - else: - raise ValueError(f'Unknown BUILDINGS_METHOD {BUILDINGS_METHOD}') - - with _open_file(BUILDINGS_FILE_LOG, 'w') as f: - f.write(f'{path}\n') - - with _open_file(path, 'rb') as f: - if path.endswith('.csv'): - df = pd.read_csv(f) - df['geometry'] = df['wkt'].apply(shapely.wkt.loads) - gdf = gpd.GeoDataFrame(df.drop(columns=['wkt']), crs='EPSG:4326') - elif path.endswith('.parquet'): - gdf = gpd.read_parquet(f) - else: - gdf = gpd.read_file(f) - print(f'Found {len(gdf)} buildings.') - print(f'Saved buildings to {path}') - if len(gdf) < 500000: - _display_building_footprints(gdf) - else: - print('Too many buildings to visualize. Use QGIS instead.') - -with warnings.catch_warnings(): - warnings.simplefilter('ignore') - download_buildings(AOI_PATH, OUTPUT_DIR) - - -# %% cellView="form" -# @title Write Example Generation Config File -def write_example_generation_config(path: str) -> None: - """Writes example generation config file to assessment directory.""" - dataset_name = ASSESSMENT_NAME.lower().replace('_', '-') - with _open_file(BUILDINGS_FILE_LOG, 'r') as f: - buildings_file = f.read().strip() - - config_dict = { - 'dataset_name': dataset_name, - 'aoi_path': AOI_PATH, - 'output_dir': OUTPUT_DIR, - 'buildings_method': 'file', - 'buildings_file': buildings_file, - 'resolution': 0.5, - 'use_dataflow': True, - 'cloud_project': GCP_PROJECT, - 'cloud_region': GCP_LOCATION, - 'worker_service_account': GCP_SERVICE_ACCOUNT, - 'max_dataflow_workers': 100, - 'output_shards': 100, - 'output_metadata_file': True, - 'before_image_patterns': BEFORE_IMAGES, - 'after_image_patterns': AFTER_IMAGES, - } - - valid_config = True - for key, value in config_dict.items(): - if not value: - if key == 'buildings_file' and config_dict['buildings_method'] != 'file': - continue - print(f'Field {key} cannot be empty') - valid_config = False - if not valid_config: - return - - config_string = json.dumps(config_dict, indent=2) - print(f'Example Generation configuration written to {path}:') - print() - print(config_string) - with tf.io.gfile.GFile(path, 'w') as f: - f.write(config_string) - -write_example_generation_config(EXAMPLE_GENERATION_CONFIG_PATH) - - -# %% cellView="form" -# @title Run Example Generation Job -def run_example_generation(config_file_path: str): - """Runs example generation pipeline.""" - - script = textwrap.dedent(f''' - cd {SKAI_CODE_DIR}/src - export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY} - python generate_examples_main.py \ - --configuration_path={config_file_path} \ - --output_metadata_file - ''') - - script_path = '/content/example_generation.sh' - with open(script_path, 'w') as f: - f.write(script) - # !bash {script_path} - -run_example_generation(EXAMPLE_GENERATION_CONFIG_PATH) - - -# %% cellView="form" -# @title Visualize Generated Examples -def visualize_generated_examples(pattern: str, num: int): - images = [] - paths = tf.io.gfile.glob(pattern) - for record in tf.data.TFRecordDataset([paths[0]]).take(num): - example = tf.train.Example() - example.ParseFromString(record.numpy()) - pre_image = plt.imread(io.BytesIO( - example.features.feature['pre_image_png_large'].bytes_list.value[0])) - post_image = plt.imread(io.BytesIO( - example.features.feature['post_image_png_large'].bytes_list.value[0])) - images.append((pre_image, post_image)) - visualize_images(images) - -visualize_generated_examples(UNLABELED_TFRECORD_PATTERN, 3) - - -# %% cellView="form" -# @title Run Zero Shot Model -def run_zero_shot_model(): - """Runs zero-shot model inference.""" - script = textwrap.dedent(f''' - export PYTHONPATH={SKAI_CODE_DIR}/src:$PYTHONPATH - export GOOGLE_CLOUD_PROJECT={GCP_PROJECT} - export GOOGLE_CLOUD_BUCKET_NAME={GCP_BUCKET} - export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY} - cd {SKAI_CODE_DIR}/src - - xmanager launch skai/model/xm_vlm_zero_shot_vertex.py -- \ - --example_patterns={UNLABELED_TFRECORD_PATTERN} \ - --output_dir={ZERO_SHOT_DIR} - ''') - - print( - 'Starting zero shot model inference. Scores will be written to' - f' {ZERO_SHOT_SCORES}' - ) - script_path = '/content/zero_shot_model.sh' - with open(script_path, 'w') as f: - f.write(script) - # !bash {script_path} - -run_zero_shot_model() - -# %% cellView="form" -# @title View Zero Shot Assessment -DAMAGE_SCORE_THRESHOLD = 0.5 # @param {type:"number"} - -make_download_button( - ZERO_SHOT_SCORES, - f'{ASSESSMENT_NAME}_zero_shot_assessment.csv', - 'Download CSV') -show_inference_stats(AOI_PATH, ZERO_SHOT_SCORES, DAMAGE_SCORE_THRESHOLD) -show_assessment_heatmap( - AOI_PATH, ZERO_SHOT_SCORES, DAMAGE_SCORE_THRESHOLD, True -) - -# %% [markdown] -# # Labeling - -# %% cellView="form" -# @title Create Labeling Images -MAX_LABELING_IMAGES = 1000 # @param {"type":"integer"} - - -def visualize_labeling_images(images_dir: str, num: int): - """Displays a small sample of labeling images.""" - pre_image_paths = sorted( - tf.io.gfile.glob(os.path.join(images_dir, '*_pre.png')) - ) - post_image_paths = sorted( - tf.io.gfile.glob(os.path.join(images_dir, '*_post.png')) - ) - assert len(pre_image_paths) == len(post_image_paths), ( - f'Number of pre images ({len(pre_image_paths)}) does not match number of' - f' post images ({len(post_image_paths)}).' - ) - images = [] - for pre_image_path, post_image_path in list( - zip(pre_image_paths, post_image_paths) - )[:num]: - with _open_file(pre_image_path, 'rb') as f: - pre_image = plt.imread(f) - with _open_file(post_image_path, 'rb') as f: - post_image = plt.imread(f) - images.append((pre_image, post_image)) - visualize_images(images) - - -def create_labeling_images( - examples_pattern: str, - scores_file: str, - output_dir: str, - max_images: int, -): - """Creates labeling images.""" - if not tf.io.gfile.glob(examples_pattern): - print( - f'No files match "{examples_pattern}". Please run example generation' - ' first.' - ) - return - - existing_metadata_files = find_labeling_image_metadata_files(output_dir) - if existing_metadata_files: - print( - 'The following labeling image metadata files have already been' - ' generated:' - ) - print('\n'.join(f' {p}' for p in existing_metadata_files)) - response = input( - 'Do you want to generate a new set of labeling images (y/n)? ' - ) - if response.lower() not in ['y', 'yes']: - return - - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - images_dir = os.path.join(output_dir, timestamp) - metadata_csv = os.path.join(images_dir, 'image_metadata.csv') - - num_images = labeling.create_labeling_images( - examples_pattern, - max_images, - set(), - set(), - images_dir, - True, - None, - 4, - 70.0, - { - (0, 0.25): 0.25, - (0.25, 0.5): 0.25, - (0.5, 0.75): 0.25, - (0.75, 1.0): 0.25, - }, - scores_path=scores_file, - filter_by_column='is_cloudy', - ) - print('Number of labeling images:', num_images) - print( - 'Please create a new project in the SKAI labeling tool with the following' - ' metadata CSV:' - ) - print(metadata_csv) - visualize_labeling_images(images_dir, 3) - -create_labeling_images( - UNLABELED_TFRECORD_PATTERN, - ZERO_SHOT_SCORES, - LABELING_IMAGES_DIR, - MAX_LABELING_IMAGES, -) - - -# %% [markdown] -# When the labeling project is complete, download the CSV from the labeling tool -# and upload it to your assessment directory using the following cell. -# -# You may upload multiple CSV files at once, in case you wish to combine labels -# from multiple rounds of labeling. - -# %% cellView="form" -# @title Upload Label CSV -def upload_label_csvs(output_path: str): - """Lets the user upload the labeling CSV file from their computer.""" - uploaded = files.upload() - dfs = [] - for filename in uploaded.keys(): - f = io.BytesIO(uploaded[filename]) - df = pd.read_csv(f) - if 'example_id' not in df.columns: - print('"example_id" column not found in {filename}') - return - if 'string_label' not in df.columns: - print('"string_label" column not found in {filename}') - return - dfs.append(df) - print(f'Read {len(df)} rows from {filename}') - - combined = pd.concat(dfs, ignore_index=True) - - with tf.io.gfile.GFile(output_path, 'wb') as f: - f.closed = False - combined.to_csv(f, index=False) - -upload_label_csvs(LABELS_CSV) - -# %% cellView="form" -# @title Create Labeled Examples -TEST_FRACTION = 0.2 # @param {"type":"number"} -MINOR_IS_0 = False # @param {"type":"boolean"} - - -def create_labeled_examples( - examples_pattern: str, - labels_csv: str, - test_fraction: float, - train_path: str, - test_path: str, - minor_is_0: bool): - """Creates labeled train and test TFRecords files.""" - - minor_damage_float_label = (0 if minor_is_0 else 1) - label_mapping = [ - 'bad_example=0', - 'no_damage=0', - f'minor_damage={minor_damage_float_label}', - 'major_damage=1', - 'destroyed=1', - ] - - labeling.create_labeled_examples( - label_file_paths=[labels_csv], - string_to_numeric_labels=label_mapping, - example_patterns=[examples_pattern], - test_fraction=test_fraction, - train_output_path=train_path, - test_output_path=test_path, - connecting_distance_meters=70.0, - use_multiprocessing=False, - multiprocessing_context=None, - max_processes=1, - ) - print(f'Train TFRecord: {train_path}') - print(f'Test TFRecord: {test_path}') - -create_labeled_examples( - LABELING_EXAMPLES_TFRECORD_PATTERN, - LABELS_CSV, - TEST_FRACTION, - TRAIN_TFRECORD, - TEST_TFRECORD, - MINOR_IS_0) - - -# %% cellView="form" -# @title Show Label Stats -def _load_examples_into_df( - train_tfrecords: str, - test_tfrecords: str, -) -> pd.DataFrame: - """Loads examples from TFRecords into a DataFrame. - """ - feature_config = { - 'example_id': tf.io.FixedLenFeature([], tf.string), - 'coordinates': tf.io.FixedLenFeature([2], tf.float32), - 'string_label': tf.io.FixedLenFeature([], tf.string, 'unlabeled'), - 'label': tf.io.FixedLenFeature([], tf.float32), - } - - def _parse_examples(record_bytes): - return tf.io.parse_single_example(record_bytes, feature_config) - - columns = collections.defaultdict(list) - longitudes = [] - latitudes = [] - for path in [train_tfrecords, test_tfrecords]: - for features in tqdm.notebook.tqdm( - tf.data.TFRecordDataset([path]) - .map(_parse_examples, num_parallel_calls=tf.data.AUTOTUNE) - .prefetch(tf.data.AUTOTUNE) - .as_numpy_iterator(), - desc=path, - ): - longitudes.append(features['coordinates'][0]) - latitudes.append(features['coordinates'][1]) - columns['example_id'].append(features['example_id'].decode()) - columns['string_label'].append(features['string_label'].decode()) - columns['label'].append(features['label']) - columns['source_path'].append(path) - - return pd.DataFrame(columns) - -def _format_counts_table(df: pd.DataFrame): - for column in df.columns: - if column != 'All': - df[column] = [ - f'{x} ({x/t * 100:0.2f}%)' for x, t in zip(df[column], df['All']) - ] - -def show_label_stats(train_tfrecord: str, test_tfrecord: str): - """Displays tables showing label count stats.""" - df = _load_examples_into_df(train_tfrecord, test_tfrecord) - counts = df.pivot_table( - index='source_path', - columns='string_label', - aggfunc='count', - values='example_id', - margins=True, - fill_value=0) - _format_counts_table(counts) - - print('String Label Counts') - display(data_table.DataTable(counts)) - - float_counts = df.pivot_table( - index='source_path', - columns='label', - aggfunc='count', - values='example_id', - margins=True, - fill_value=0.0) - _format_counts_table(float_counts) - print('Float Label Counts') - display(data_table.DataTable(float_counts)) - - -show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD) - -# %% [markdown] -# # Fine Tuning - -# %% cellView="form" -# @title Train model - -NUM_EPOCHS = 20 # @param {type:"integer"} - - -def run_training( - experiment_name: str, - train_path: str, - test_path: str, - output_dir: str, - num_epochs: int): - """Runs training job.""" - if not tf.io.gfile.exists(train_path): - raise ValueError( - f'Train TFRecord {train_path} does not exist. Did you run the "Create' - ' Labeled Examples" cell?' - ) - if not tf.io.gfile.exists(test_path): - raise ValueError( - f'Test TFRecord {test_path} does not exist. Did you run the "Create' - ' Labeled Examples" cell?' - ) - - print(f'Train data: {train_path}') - print(f'Test data: {test_path}') - print(f'Model dir: {output_dir}') - job_args = { - 'config': 'src/skai/model/configs/skai_two_tower_config.py', - 'config.data.tfds_dataset_name': 'skai_dataset', - 'config.data.adhoc_config_name': 'adhoc_dataset', - 'config.data.labeled_train_pattern': train_path, - 'config.data.validation_pattern': test_path, - 'config.output_dir': output_dir, - 'config.training.num_epochs': num_epochs, - 'accelerator': 'V100', - 'experiment_name': experiment_name, - } - job_arg_str = ' '.join(f'--{f}={v}' for f, v in job_args.items()) - sh = textwrap.dedent(f''' - export GOOGLE_CLOUD_PROJECT={GCP_PROJECT} - export GOOGLE_CLOUD_BUCKET_NAME={GCP_BUCKET} - export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY} - export LOCATION={GCP_LOCATION} - - cd {SKAI_CODE_DIR} - - xmanager launch src/skai/model/xm_launch_single_model_vertex.py -- \ - --xm_wrap_late_bindings \ - --xm_upgrade_db=True \ - --cloud_location=$LOCATION \ - --accelerator_count=1 {job_arg_str}''') - - with open('script.sh', 'w') as file: - file.write(sh) - - # !bash script.sh - -run_training( - ASSESSMENT_NAME, - TRAIN_TFRECORD, - TEST_TFRECORD, - MODEL_DIR, - NUM_EPOCHS) - - -# %% cellView="form" -# @title View Tensorboard -def start_tensorboard(model_root: str): - """Shows Tensorboard visualization.""" - tensorboard_dirs = tf.io.gfile.glob( - os.path.join(model_root, '*/*/tensorboard') - ) - if not tensorboard_dirs: - print( - 'No Tensorboard directories found. Either you have not trained a model' - ' yet or a running job has not written any tensorboard log events yet.' - ) - return - - dir_selection_widget = widgets.Dropdown( - options=tensorboard_dirs, - description='Choose a tensorboard dir:', - layout={'width': 'initial'}, - ) - dir_selection_widget.style.description_width = 'initial' - - def run_tensorboard(_): - # pylint:disable=unused-variable - tensorboard_dir = dir_selection_widget.value - # %tensorboard --load_fast=false --logdir $tensorboard_dir - # pylint:enable=unused-variable - - start_button = widgets.Button( - description='Start', - ) - start_button.on_click(run_tensorboard) - - display(dir_selection_widget) - display(start_button) - -start_tensorboard(MODEL_DIR) - - -# %% cellView="form" -# @title Run inference -def get_best_checkpoint(model_dir: str): - checkpoint_dirs = tf.io.gfile.glob(os.path.join(model_dir, 'epoch-*-aucpr-*')) - best_checkpoint = None - best_aucpr = 0 - for checkpoint in checkpoint_dirs: - aucpr = float(checkpoint.split('-')[-1]) - if aucpr > best_aucpr: - best_checkpoint = checkpoint - best_aucpr = aucpr - return best_checkpoint - - -def run_inference( - examples_pattern: str, - model_dir: str, - output_dir: str, - output_path: str, - cloud_project: str, - cloud_region: str, - service_account: str) -> None: - """Starts model inference job.""" - temp_dir = os.path.join(output_dir, 'inference_temp') - print( - f'Running inference with model checkpoint "{model_dir}" on examples' - f' matching "{examples_pattern}"' - ) - print(f'Output will be written to {output_path}') - - # accelerator_flags = ' '.join([ - # '--worker_machine_type=n1-highmem-8', - # '--accelerator=nvidia-tesla-t4', - # '--accelerator_count=1']) - - # Currently, Colab only supports Python 3.10. However, the docker images we - # need for GPU acceleration are based on Tensorflow 2.14.0 images, which are - # based on Python 3.11. If we try to launch an inference job with GPU - # acceleration, Dataflow will complain about a Python version mismatch. - # Therefore, we can only use CPU inference until Colab upgrades to Python 3.11 - # (which should be sometime within 2024). - accelerator_flags = '' - - script = textwrap.dedent(f''' - cd {SKAI_CODE_DIR}/src - export GOOGLE_CLOUD_PROJECT={cloud_project} - export GOOGLE_APPLICATION_CREDENTIALS={SERVICE_ACCOUNT_KEY} - python skai/model/inference.py \ - --examples_pattern='{examples_pattern}' \ - --image_model_dir='{model_dir}' \ - --output_path='{output_path}' \ - --use_dataflow \ - --cloud_project='{cloud_project}' \ - --cloud_region='{cloud_region}' \ - --dataflow_temp_dir='{temp_dir}' \ - --worker_service_account='{service_account}' \ - --threshold=0.5 \ - --high_precision_threshold=0.75 \ - --high_recall_threshold=0.4 \ - --max_dataflow_workers=4 {accelerator_flags} - ''') - - script_path = '/content/inference_script.sh' - with open(script_path, 'w') as f: - f.write(script) - # !bash {script_path} - - -def do_inference(model_root: str): - """Runs model inference.""" - model_dirs = find_model_dirs(model_root) - if not model_dirs: - print( - f'No models found in directory {model_root}. Please train a model' - ' first.' - ) - return - - model_selection_widget = widgets.Dropdown( - options=model_dirs, - description='Choose a model:', - layout={'width': 'initial'}, - ) - model_selection_widget.style.description_width = 'initial' - - def start_clicked(_): - model_dir = os.path.join(model_selection_widget.value, 'model') - checkpoint = get_best_checkpoint(model_dir) - if not checkpoint: - print('Model directory does not contain a valid checkpoint directory.') - return - run_inference( - UNLABELED_TFRECORD_PATTERN, - checkpoint, - OUTPUT_DIR, - INFERENCE_CSV, - GCP_PROJECT, - GCP_LOCATION, - GCP_SERVICE_ACCOUNT, - ) - - start_button = widgets.Button( - description='Start', - ) - start_button.on_click(start_clicked) - - display(model_selection_widget) - display(start_button) - -do_inference(MODEL_DIR) - -# %% cellView="form" -# @title Get assessment stats -DAMAGE_SCORE_THRESHOLD = 0.5 # @param {type:"number"} - -make_download_button( - INFERENCE_CSV, - f'{ASSESSMENT_NAME}_assessment.csv', - 'Download CSV') -show_inference_stats(AOI_PATH, INFERENCE_CSV, DAMAGE_SCORE_THRESHOLD) -show_assessment_heatmap(AOI_PATH, INFERENCE_CSV, DAMAGE_SCORE_THRESHOLD, False) diff --git a/src/colab/sync_notebook_source.py b/src/colab/sync_notebook_source.py index 497924c8..e618d3e2 100644 --- a/src/colab/sync_notebook_source.py +++ b/src/colab/sync_notebook_source.py @@ -32,10 +32,7 @@ 'GCP_LOCATION': '', 'GCP_BUCKET': '', 'GCP_SERVICE_ACCOUNT': '', - 'SERVICE_ACCOUNT_KEY': '', 'BUILDING_SEGMENTATION_MODEL_PATH': '', - 'BUILDINGS_METHOD': 'open_buildings', - 'USER_BUILDINGS_FILE': '', 'ASSESSMENT_NAME': '', 'EVENT_DATE': '', 'OUTPUT_DIR': '', @@ -59,7 +56,6 @@ 'AFTER_IMAGE_7': '', 'AFTER_IMAGE_8': '', 'AFTER_IMAGE_9': '', - 'DAMAGE_SCORE_THRESHOLD': 0.5, } diff --git a/src/detect_buildings_main.py b/src/detect_buildings_main.py index 941e51ac..05128acb 100644 --- a/src/detect_buildings_main.py +++ b/src/detect_buildings_main.py @@ -67,9 +67,6 @@ 'worker_service_account', None, 'Service account that will launch Dataflow workers. If unset, workers will ' 'run with the project\'s default Compute Engine service account.') -flags.DEFINE_integer( - 'min_dataflow_workers', 10, 'Minimum number of dataflow workers' -) flags.DEFINE_integer( 'max_dataflow_workers', None, 'Maximum number of dataflow workers' ) @@ -113,7 +110,6 @@ def main(args): FLAGS.cloud_project, FLAGS.cloud_region, temp_dir, - FLAGS.min_dataflow_workers, FLAGS.max_dataflow_workers, FLAGS.worker_service_account, machine_type=FLAGS.worker_machine_type, diff --git a/src/generate_examples_main.py b/src/generate_examples_main.py index 935e7ced..c6a5a758 100644 --- a/src/generate_examples_main.py +++ b/src/generate_examples_main.py @@ -51,9 +51,6 @@ 'worker_service_account', None, 'Service account that will launch Dataflow workers. If unset, workers will ' 'run with the project\'s default Compute Engine service account.') -flags.DEFINE_integer( - 'min_dataflow_workers', None, 'Minimum number of dataflow workers' -) flags.DEFINE_integer( 'max_dataflow_workers', None, 'Maximum number of dataflow workers' ) diff --git a/src/skai/beam_utils.py b/src/skai/beam_utils.py index 717aedc6..2204479b 100644 --- a/src/skai/beam_utils.py +++ b/src/skai/beam_utils.py @@ -108,7 +108,6 @@ def get_pipeline_options( project: str, region: str, temp_dir: str, - min_workers: int, max_workers: int, worker_service_account: str | None, machine_type: str | None, @@ -124,7 +123,6 @@ def get_pipeline_options( project: GCP project. region: GCP region. temp_dir: Temporary data location. - min_workers: Minimum number of Dataflow workers. max_workers: Maximum number of Dataflow workers. worker_service_account: Email of the service account will launch workers. If None, uses the project's default Compute Engine service account @@ -166,11 +164,8 @@ def get_pipeline_options( if machine_type: options['machine_type'] = machine_type - service_options = [ - f'min_num_workers={min_workers}', - ] if accelerator: - service_options.extend([ + options['dataflow_service_options'] = ';'.join([ f'worker_accelerator=type:{accelerator}', f'count:{accelerator_count}', 'install-nvidia-driver', @@ -179,5 +174,4 @@ def get_pipeline_options( else: options['sdk_container_image'] = _get_dataflow_container_image('cpu') - options['dataflow_service_options'] = ';'.join(service_options) return PipelineOptions.from_dictionary(options) diff --git a/src/skai/generate_examples.py b/src/skai/generate_examples.py index 21d69726..015a870b 100644 --- a/src/skai/generate_examples.py +++ b/src/skai/generate_examples.py @@ -14,6 +14,7 @@ """Pipeline for generating tensorflow examples from satellite images.""" import binascii +import csv import dataclasses import hashlib import itertools @@ -27,8 +28,6 @@ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple import apache_beam as beam -import apache_beam.dataframe.convert -import apache_beam.dataframe.io import cv2 import geopandas as gpd import numpy as np @@ -120,8 +119,6 @@ class ExamplesGenerationConfig: use_dataflow: If true, execute pipeline in Cloud Dataflow. output_metadata_file: Output a CSV metadata file for all generated examples. worker_service_account: If using Dataflow, the service account to run as. - min_dataflow_workers: If using Dataflow, the minimum number of workers to - instantiate. max_dataflow_workers: If using Dataflow, the number of workers to instantiate. example_patch_size: Size of the example image. @@ -174,7 +171,6 @@ class ExamplesGenerationConfig: use_dataflow: bool = False output_metadata_file: bool = True worker_service_account: Optional[str] = None - min_dataflow_workers: int = 10 max_dataflow_workers: int = 20 example_patch_size: int = 64 large_patch_size: int = 256 @@ -874,7 +870,6 @@ def _generate_examples_pipeline( cloud_project: Optional[str], cloud_region: Optional[str], worker_service_account: Optional[str], - min_workers: int, max_workers: int, wait_for_dataflow_job: bool, cloud_detector_model_path: Optional[str], @@ -898,7 +893,6 @@ def _generate_examples_pipeline( cloud_project: Cloud project name. cloud_region: Cloud region, e.g. us-central1. worker_service_account: Email of service account that will launch workers. - min_workers: Minimum number of workers to use. max_workers: Maximum number of workers to use. wait_for_dataflow_job: If true, wait for dataflow job to complete before returning. @@ -914,7 +908,6 @@ def _generate_examples_pipeline( cloud_project, cloud_region, temp_dir, - min_workers, max_workers, worker_service_account, machine_type=None, @@ -945,14 +938,28 @@ def _generate_examples_pipeline( num_shards=num_output_shards)) if output_metadata_file: - rows = ( + field_names = [ + 'example_id', + 'encoded_coordinates', + 'longitude', + 'latitude', + 'post_image_id', + 'pre_image_id', + 'plus_code', + ] + _ = ( examples - | 'extract_metadata_rows' >> beam.Map(_get_example_metadata) - | 'remove_duplicates' >> beam.Distinct() + | 'convert_metadata_examples_to_dict' >> beam.Map(_get_example_metadata) + | 'combine_to_list' >> beam.combiners.ToList() + | 'write_metadata_to_file' + >> beam.ParDo( + WriteMetadataToCSVFn( + metadata_output_file_path=( + f'{output_dir}/examples/metadata_examples.csv' + ), field_names=field_names + ) + ) ) - df = apache_beam.dataframe.convert.to_dataframe(rows) - output_prefix = f'{output_dir}/examples/metadata/metadata.csv' - apache_beam.dataframe.io.to_csv(df, output_prefix, index=False) result = pipeline.run() if wait_for_dataflow_job: @@ -1082,7 +1089,6 @@ def run_example_generation( config.cloud_project, config.cloud_region, config.worker_service_account, - config.min_dataflow_workers, config.max_dataflow_workers, wait_for_dataflow, config.cloud_detector_model_path, @@ -1090,6 +1096,27 @@ def run_example_generation( ) +class WriteMetadataToCSVFn(beam.DoFn): + """DoFn to write meta data of examples to csv file. + + Attributes: + metadata_output_file_path: File path to output meta data of all examples. + field_names: Field names to be included in output file. + """ + + def __init__(self, metadata_output_file_path: str, field_names: List[str]): + self.metadata_output_file_path = metadata_output_file_path + self.field_names = field_names + + def process(self, element): + with tf.io.gfile.GFile( + self.metadata_output_file_path, 'w' + ) as csv_output_file: + csv_writer = csv.DictWriter(csv_output_file, fieldnames=self.field_names) + csv_writer.writeheader() + csv_writer.writerows(element) + + class ExampleType(typing.NamedTuple): example_id: str encoded_coordinates: str @@ -1102,16 +1129,21 @@ class ExampleType(typing.NamedTuple): @beam.typehints.with_output_types(ExampleType) def _get_example_metadata(example: tf.train.Example) -> ExampleType: - return ExampleType( - example_id=utils.get_bytes_feature(example, 'example_id')[0].decode(), - encoded_coordinates=utils.get_bytes_feature( - example, 'encoded_coordinates' - )[0].decode(), - longitude=utils.get_float_feature(example, 'coordinates')[0], - latitude=utils.get_float_feature(example, 'coordinates')[1], - post_image_id=utils.get_bytes_feature(example, 'post_image_id')[ - 0 - ].decode(), - pre_image_id=utils.get_bytes_feature(example, 'pre_image_id')[0].decode(), - plus_code=utils.get_bytes_feature(example, 'plus_code')[0].decode(), - ) + example_id = utils.get_bytes_feature(example, 'example_id')[0].decode() + encoded_coordinates = utils.get_bytes_feature(example, 'encoded_coordinates')[ + 0 + ].decode() + longitude, latitude = utils.get_float_feature(example, 'coordinates') + post_image_id = utils.get_bytes_feature(example, 'post_image_id')[0].decode() + pre_image_id = utils.get_bytes_feature(example, 'pre_image_id')[0].decode() + plus_code = utils.get_bytes_feature(example, 'plus_code')[0].decode() + + return dict({ + 'example_id': example_id, + 'encoded_coordinates': encoded_coordinates, + 'longitude': longitude, + 'latitude': latitude, + 'post_image_id': post_image_id, + 'pre_image_id': pre_image_id, + 'plus_code': plus_code, + }) diff --git a/src/skai/generate_examples_test.py b/src/skai/generate_examples_test.py index 7c4c1ca9..fb6c175c 100644 --- a/src/skai/generate_examples_test.py +++ b/src/skai/generate_examples_test.py @@ -14,7 +14,6 @@ """Tests for generate_examples.py.""" -import glob import os import pathlib import tempfile @@ -478,7 +477,6 @@ def testGenerateExamplesPipeline(self): cloud_project=None, cloud_region=None, worker_service_account=None, - min_workers=0, max_workers=0, wait_for_dataflow_job=True, cloud_detector_model_path=None, @@ -517,7 +515,6 @@ def testGenerateExamplesWithOutputMetaDataFile(self): cloud_project=None, cloud_region=None, worker_service_account=None, - min_workers=0, max_workers=0, wait_for_dataflow_job=True, cloud_detector_model_path=None, @@ -527,27 +524,26 @@ def testGenerateExamplesWithOutputMetaDataFile(self): tfrecords = os.listdir( os.path.join(output_dir, 'examples', 'unlabeled-large') ) - metadata_pattern = os.path.join( - output_dir, 'examples', 'metadata', 'metadata.csv-*-of-*' + df_metadata_contents = pd.read_csv( + os.path.join(output_dir, 'examples', 'metadata_examples.csv') ) - metadata = pd.concat([pd.read_csv(p) for p in glob.glob(metadata_pattern)]) # No assert for example_id as each example_id depends on the image path # which varies with platforms where this test is run self.assertEqual( - metadata.encoded_coordinates[0], 'A17B32432A1085C1' + df_metadata_contents.encoded_coordinates[0], 'A17B32432A1085C1' ) self.assertAlmostEqual( - metadata.latitude[0], -16.632892608642578 + df_metadata_contents.latitude[0], -16.632892608642578 ) self.assertAlmostEqual( - metadata.longitude[0], 178.48292541503906 + df_metadata_contents.longitude[0], 178.48292541503906 ) - self.assertEqual(metadata.pre_image_id[0], self.test_image_path) + self.assertEqual(df_metadata_contents.pre_image_id[0], self.test_image_path) self.assertEqual( - metadata.post_image_id[0], self.test_image_path + df_metadata_contents.post_image_id[0], self.test_image_path ) - self.assertEqual(metadata.plus_code[0], '5VMW9F8M+R5V8F4') + self.assertEqual(df_metadata_contents.plus_code[0], '5VMW9F8M+R5V8F4') self.assertSameElements(tfrecords, ['unlabeled-00000-of-00001.tfrecord']) def testConfigLoadedCorrectlyFromJsonFile(self): diff --git a/src/skai/labeling.py b/src/skai/labeling.py index b7a9288f..0358869e 100644 --- a/src/skai/labeling.py +++ b/src/skai/labeling.py @@ -220,26 +220,6 @@ def sample_with_buffer( return sample -def _read_sharded_csvs(pattern: str) -> pd.DataFrame: - """Reads CSV shards matching pattern and merges them.""" - paths = tf.io.gfile.glob(pattern) - if not paths: - raise ValueError(f'File pattern {pattern} did not match any files.') - dfs = [] - expected_columns = None - for path in paths: - with tf.io.gfile.GFile(path, 'r') as f: - df = pd.read_csv(f) - if expected_columns is None: - expected_columns = set(df.columns) - else: - actual_columns = set(df.columns) - if actual_columns != expected_columns: - raise ValueError(f'Inconsistent columns in file {path}') - dfs.append(df) - return pd.concat(dfs, ignore_index=True) - - def get_buffered_example_ids( examples_pattern: str, buffered_sampling_radius: float, @@ -258,23 +238,25 @@ def get_buffered_example_ids( Returns: Set of allowed example ids. """ - root_dir = '/'.join(examples_pattern.split('/')[:-2]) - single_csv_pattern = str(os.path.join(root_dir, 'metadata_examples.csv')) - if tf.io.gfile.exists(single_csv_pattern): - metadata = _read_sharded_csvs(single_csv_pattern) - else: - sharded_csv_pattern = str( - os.path.join( - root_dir, - 'metadata', - 'metadata.csv-*-of-*', - ) - ) - metadata = _read_sharded_csvs(sharded_csv_pattern) - - metadata = metadata[ - ~metadata['example_id'].isin(excluded_example_ids) - ].reset_index(drop=True) + metadata_path = str( + os.path.join( + '/'.join(examples_pattern.split('/')[:-2]), + 'metadata_examples.csv', + ) + ) + with tf.io.gfile.GFile(metadata_path, 'r') as f: + try: + df_metadata = pd.read_csv(f) + df_metadata = df_metadata[ + ~df_metadata['example_id'].isin(excluded_example_ids) + ].reset_index(drop=True) + except tf.errors.NotFoundError as error: + raise SystemExit( + f'\ntf.errors.NotFoundError: {metadata_path} was not found\nUse' + ' examples_to_csv module to generate metadata_examples.csv and/or' + ' put metadata_examples.csv in the appropriate directory that is' + ' PATH_DIR/examples/' + ) from error logging.info( 'Randomly searching for buffered samples with buffer radius %.2f' @@ -283,11 +265,11 @@ def get_buffered_example_ids( ) points = utils.convert_to_utm( gpd.GeoSeries( - gpd.points_from_xy(metadata['longitude'], metadata['latitude']), + gpd.points_from_xy(df_metadata['longitude'], df_metadata['latitude']), crs=4326, ) ) - gpd_df = gpd.GeoDataFrame(metadata, geometry=points) + gpd_df = gpd.GeoDataFrame(df_metadata, geometry=points) max_examples = len(gpd_df) if max_examples is None else max_examples df_buffered_samples = sample_with_buffer( gpd_df, max_examples, buffered_sampling_radius diff --git a/src/skai/labeling_test.py b/src/skai/labeling_test.py index da15981e..5cda474a 100644 --- a/src/skai/labeling_test.py +++ b/src/skai/labeling_test.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for labeling.""" + import os import random import tempfile from absl.testing import absltest -from absl.testing import parameterized import geopandas as gpd import numpy as np import pandas as pd @@ -79,12 +80,9 @@ def _read_tfrecords(path: str) -> list[Example]: return examples -class LabelingTest(parameterized.TestCase): +class LabelingTest(absltest.TestCase): - @parameterized.parameters( - dict(sharded_example_metadata=True), dict(sharded_example_metadata=False) - ) - def test_create_buffered_tfrecords(self, sharded_example_metadata: bool): + def test_create_buffered_tfrecords(self): """Tests create_buffered_tfrecords.""" # Create 5 unlabeled examples in 3 tfrecords. with tempfile.TemporaryDirectory() as examples_dir: @@ -95,6 +93,9 @@ def test_create_buffered_tfrecords(self, sharded_example_metadata: bool): examples_pattern = os.path.join( examples_dir, 'examples', 'unlabeled', '*' ) + metadata_examples_path = os.path.join( + examples_dir, 'examples', 'metadata_examples.csv' + ) filtered_tfrecords_output_dir = os.path.join( examples_dir, 'filtered', ) @@ -113,28 +114,7 @@ def test_create_buffered_tfrecords(self, sharded_example_metadata: bool): columns=['example_id', 'longitude', 'latitude'], ) df_metadata = df_metadata.sample(frac=1) - if sharded_example_metadata: - metadata_dir = os.path.join(examples_dir, 'examples', 'metadata') - os.mkdir(metadata_dir) - df_metadata.iloc[:2].to_csv( - os.path.join( - metadata_dir, - 'metadata.csv-00000-of-00002', - ), - index=False, - ) - df_metadata.iloc[2:].to_csv( - os.path.join( - metadata_dir, - 'metadata.csv-00001-of-00002', - ), - index=False, - ) - else: - metadata_examples_path = os.path.join( - examples_dir, 'examples', 'metadata_examples.csv' - ) - df_metadata.to_csv(metadata_examples_path, index=False) + df_metadata.to_csv(metadata_examples_path, index=False) example_id_lon_lat_create_tfrecords = { '001': [('a', [92.850449, 20.148951]), ('b', [92.889694, 20.157515])], diff --git a/src/skai/model/configs/base_config.py b/src/skai/model/configs/base_config.py index 09aea5f3..0254aa7e 100644 --- a/src/skai/model/configs/base_config.py +++ b/src/skai/model/configs/base_config.py @@ -72,7 +72,7 @@ def get_data_config(): # model is trained on a random subsample of the dataset. Split guarantees # each point to be in the exact number of splits defined by the ood ratio. # Filtering only guarantees this in expectation. - config.use_splits = True + config.use_splits = False config.use_filtering = False # The following arguments are only used when use_filtering=True diff --git a/src/skai/model/data.py b/src/skai/model/data.py index 4da8842e..b1ac1476 100644 --- a/src/skai/model/data.py +++ b/src/skai/model/data.py @@ -23,10 +23,12 @@ import collections import dataclasses +import enum +import functools import os from typing import Any, Iterator import uuid - +import ml_collections import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -52,6 +54,11 @@ CROP_PADDING = 32 +class ReadingRecordStratgey(enum.Enum): + DIRECT_FROM_TFRECORD = 'direct_from_tfrecord' + USING_TFDS = 'using_tfds' + + def register_dataset(name: str): """Provides decorator to register functions that return dataset.""" @@ -74,8 +81,8 @@ def get_dataset(name: str): class Dataloader: num_subgroups: int # Number of subgroups in data. subgroup_sizes: dict[str, int] # Number of examples by subgroup. - train_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. - val_splits: tf.data.Dataset # Result of tfds.load with 'split' arg. + train_splits: tf.data.Dataset | None # Result of tfds.load with 'split' arg. + val_splits: tf.data.Dataset | None # Result of tfds.load with 'split' arg. train_ds: tf.data.Dataset # Dataset with all the train splits combined. eval_ds: dict[str, tf.data.Dataset] # Validation and/or test datasets. num_train_examples: int | None = 0 # Number of training examples. @@ -161,12 +168,14 @@ def apply_batch(dataloader, batch_size): """Apply batching to dataloader.""" # TODO(jlee24): Support making splits divisible by batch_size # so that a remainder is not dropped for every split. - dataloader.train_splits = [ - data.batch(batch_size) for data in dataloader.train_splits - ] - dataloader.val_splits = [ - data.batch(batch_size) for data in dataloader.val_splits - ] + if dataloader.train_splits: + dataloader.train_splits = [ + data.batch(batch_size) for data in dataloader.train_splits + ] + if dataloader.val_splits: + dataloader.val_splits = [ + data.batch(batch_size) for data in dataloader.val_splits + ] dataloader.train_ds = dataloader.train_ds.batch( batch_size, drop_remainder=True ) @@ -595,6 +604,93 @@ def decode_and_resize_image( ) +def decode_skai_record_bytes( + record_bytes, image_size, use_post_disaster_only, load_small_images +): + """Decode bytes into a dictionary of features and their tensor values.""" + example = tf.io.parse_single_example( + record_bytes, + { + 'coordinates': tf.io.FixedLenFeature([2], dtype=tf.float32), + 'encoded_coordinates': tf.io.FixedLenFeature([], dtype=tf.string), + 'int64_id': tf.io.FixedLenFeature([], dtype=tf.int64), + 'pre_image_png_large': tf.io.FixedLenFeature([], dtype=tf.string), + 'pre_image_png': tf.io.FixedLenFeature([], dtype=tf.string), + 'post_image_png_large': tf.io.FixedLenFeature([], dtype=tf.string), + 'post_image_png': tf.io.FixedLenFeature([], dtype=tf.string), + 'label': tf.io.FixedLenFeature([], dtype=tf.float32), + }, + ) + + features = {'input_feature': {}} + large_image_concat = decode_and_resize_image( + example['post_image_png_large'], image_size + ) + small_image_concat = decode_and_resize_image( + example['post_image_png'], image_size + ) + + if not use_post_disaster_only: + before_image = decode_and_resize_image( + example['pre_image_png_large'], image_size + ) + before_image_small = decode_and_resize_image( + example['pre_image_png'], image_size + ) + large_image_concat = tf.concat([before_image, large_image_concat], axis=-1) + small_image_concat = tf.concat( + [before_image_small, small_image_concat], axis=-1 + ) + features['input_feature']['large_image'] = large_image_concat + if load_small_images: + features['input_feature']['small_image'] = small_image_concat + features['label'] = tf.cast(example['label'], tf.int64) + features['example_id'] = example['int64_id'] + features['subgroup_label'] = features['label'] + features['coordinates'] = example['coordinates'] + return features + + +def read_skai_dataset_from_tfrecord( + pattern: str, + image_size: int, + use_post_disaster_only: bool, + load_small_images: bool, +) -> tf.data.Dataset: + """Create SKAI dataset from tfrecord. + + Args: + pattern: The file path pattern of the tfrecord. + image_size: The size for which the image will be resized. + use_post_disaster_only: If True the pre disaster images will not be loaded. + load_small_images: If True a smaller cropped version of the images will be + loaded. + + Returns: + tf.data.Dataset that read and decode from the tfrecord. + + Raises: + FileNotFoundError: if the pattern does exist. + """ + decode_records = functools.partial( + decode_skai_record_bytes, + image_size=image_size, + use_post_disaster_only=use_post_disaster_only, + load_small_images=load_small_images, + ) + paths = tf.io.gfile.glob(pattern) + if not paths: + raise FileNotFoundError( + f'File pattern "{pattern}" does not match any files.' + ) + dataset = ( + tf.data.TFRecordDataset(paths) + .map(decode_records, num_parallel_calls=tf.data.AUTOTUNE) + .prefetch(tf.data.AUTOTUNE) + ) + return dataset + + class SkaiDataset(tfds.core.GeneratorBasedBuilder): """TFDS dataset for SKAI. @@ -686,66 +782,15 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager): ) return splits - def _decode_record(self, record_bytes): - - example = tf.io.parse_single_example( - record_bytes, - { - 'coordinates': tf.io.FixedLenFeature([2], dtype=tf.float32), - 'encoded_coordinates': tf.io.FixedLenFeature([], dtype=tf.string), - 'int64_id': tf.io.FixedLenFeature([], dtype=tf.int64), - 'pre_image_png_large': tf.io.FixedLenFeature([], dtype=tf.string), - 'pre_image_png': tf.io.FixedLenFeature([], dtype=tf.string), - 'post_image_png_large': tf.io.FixedLenFeature( - [], dtype=tf.string - ), - 'post_image_png': tf.io.FixedLenFeature([], dtype=tf.string), - 'label': tf.io.FixedLenFeature([], dtype=tf.float32), - }, - ) - - features = { - 'input_feature': {} - } - large_image_concat = decode_and_resize_image( - example['post_image_png_large'], self.builder_config.image_size - ) - small_image_concat = decode_and_resize_image( - example['post_image_png'], self.builder_config.image_size - ) - - if not self.builder_config.use_post_disaster_only: - before_image = decode_and_resize_image( - example['pre_image_png_large'], self.builder_config.image_size - ) - before_image_small = decode_and_resize_image( - example['pre_image_png'], self.builder_config.image_size - ) - large_image_concat = tf.concat( - [before_image, large_image_concat], axis=-1 - ) - small_image_concat = tf.concat( - [before_image_small, small_image_concat], axis=-1 - ) - features['input_feature']['large_image'] = large_image_concat - if self.builder_config.load_small_images: - features['input_feature']['small_image'] = small_image_concat - features['label'] = tf.cast(example['label'], tf.int64) - features['example_id'] = example['int64_id'] - features['subgroup_label'] = features['label'] - features['coordinates'] = example['coordinates'] - return features - def _generate_examples(self, pattern: str): if not pattern: return - paths = tf.io.gfile.glob(pattern) - if not paths: - raise FileNotFoundError( - f'File pattern "{pattern}" does not match any files.' - ) - ds = tf.data.TFRecordDataset(paths).map( - self._decode_record, num_parallel_calls=tf.data.AUTOTUNE) + ds = read_skai_dataset_from_tfrecord( + pattern, + self.builder_config.image_size, + self.builder_config.use_post_disaster_only, + self.builder_config.load_small_images, + ) if self.builder_config.max_examples: ds = ds.take(self.builder_config.max_examples) for features in ds.as_numpy_iterator(): @@ -1091,3 +1136,97 @@ def get_skai_dataset(num_splits: int, train_ds, train_sample_ds=None, eval_ds=eval_datasets) + + +def get_skai_dataloader_from_tfrecord( + labeled_train_pattern: str, + labeled_val_pattern: str, + image_size: int, + use_post_disaster_only: bool, + load_small_images: bool, +): + """Create Dataloader directly from tfrecord.""" + train_ds = read_skai_dataset_from_tfrecord( + labeled_train_pattern, + image_size, + use_post_disaster_only, + load_small_images, + ) + val_ds = read_skai_dataset_from_tfrecord( + labeled_val_pattern, + image_size, + use_post_disaster_only, + load_small_images, + ) + subgroup_sizes = get_subgroup_sizes(train_ds) + return Dataloader( + num_subgroups=2, + subgroup_sizes=subgroup_sizes, + train_splits=None, + val_splits=None, + train_ds=train_ds, + eval_ds={'val': val_ds, 'test': val_ds}, + train_sample_ds=None, + ) + + +class SkaiDatasetFactory: + """Factory for Skai datasets.""" + + @staticmethod + def get_dataloader( + source: str, config: ml_collections.ConfigDict + ) -> Dataloader: + """Create SKAI dataloader based on the source.""" + if source == ReadingRecordStratgey.USING_TFDS: + dataset_builder = get_dataset(config.data.name) + ds_kwargs = {} + ds_kwargs.update({ + 'tfds_dataset_name': config.data.tfds_dataset_name, + 'data_dir': config.data.tfds_data_dir, + 'adhoc_config_name': config.data.adhoc_config_name, + 'labeled_train_pattern': config.data.labeled_train_pattern, + 'unlabeled_train_pattern': config.data.unlabeled_train_pattern, + 'validation_pattern': config.data.validation_pattern, + 'use_post_disaster_only': config.data.use_post_disaster_only, + 'load_small_images': config.data.load_small_images, + }) + if config.data.use_post_disaster_only: + config.model.num_channels = 3 + if config.upsampling.do_upsampling: + ds_kwargs.update({ + 'upsampling_lambda': config.upsampling.lambda_value, + 'upsampling_signal': config.upsampling.signal, + }) + get_split_config = lambda x: x if config.data.use_splits else 1 + if config.round_idx == 0: + dataloader = dataset_builder( + num_splits=get_split_config(config.data.num_splits), + initial_sample_proportion=get_split_config( + config.data.initial_sample_proportion + ), + subgroup_ids=config.data.subgroup_ids, + subgroup_proportions=config.data.subgroup_proportions, + **ds_kwargs, + ) + else: + # If latter round, keep track of split generated in last round of active + # sampling + dataloader = dataset_builder( + config.data.num_splits, + initial_sample_proportion=1, + subgroup_ids=(), + subgroup_proportions=(), + **ds_kwargs, + ) + return dataloader + elif source == ReadingRecordStratgey.DIRECT_FROM_TFRECORD: + return get_skai_dataloader_from_tfrecord( + config.data.labeled_train_pattern, + config.data.validation_pattern, + RESNET_IMAGE_SIZE, + config.data.use_post_disaster_only, + config.data.load_small_images, + ) + else: + raise ValueError(f'Source {source} is not supported.') diff --git a/src/skai/model/data_loader_test.py b/src/skai/model/data_loader_test.py index 8c623361..50d4b0a6 100644 --- a/src/skai/model/data_loader_test.py +++ b/src/skai/model/data_loader_test.py @@ -19,6 +19,7 @@ from typing import List from absl.testing import absltest +import ml_collections as mlc import numpy as np from skai.model import data import tensorflow as tf @@ -259,6 +260,158 @@ def test_upsample_subgroup(self): 1 * lambda_value, ) + def test_dataloader_from_tfrecords_post_only(self): + dataloader = data.get_skai_dataloader_from_tfrecord( + self.labeled_train_path, + self.labeled_test_path, + RESNET_IMAGE_SIZE, + True, + False, + ) + ds = dataloader.train_ds + features = next(ds.as_numpy_iterator()) + self.assertIn('input_feature', features) + self.assertIn('large_image', features['input_feature']) + input_feature = features['input_feature']['large_image'] + self.assertEqual( + input_feature.shape, + (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 3), + ) + self.assertEqual(input_feature.dtype, np.float32) + np.testing.assert_equal(input_feature, 1.0) + + def test_dataloader_from_tfrecord_pre_post(self): + dataloader = data.get_skai_dataloader_from_tfrecord( + self.labeled_train_path, + self.labeled_test_path, + RESNET_IMAGE_SIZE, + False, + True, + ) + ds = dataloader.train_ds + features = next(ds.as_numpy_iterator()) + self.assertIn('input_feature', features) + self.assertIn('large_image', features['input_feature']) + input_feature = features['input_feature']['large_image'] + self.assertEqual( + input_feature.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(input_feature.dtype, np.float32) + np.testing.assert_equal(input_feature[:, :, :3], 0.0) + np.testing.assert_equal(input_feature[:, :, 3:], 1.0) + + def test_dataloader_from_tfrecord_small_images(self): + dataloader = data.get_skai_dataloader_from_tfrecord( + self.labeled_train_path, + self.labeled_test_path, + RESNET_IMAGE_SIZE, + False, + True, + ) + ds = dataloader.train_ds + features = next(ds.as_numpy_iterator()) + self.assertIn('input_feature', features) + self.assertIn('small_image', features['input_feature']) + self.assertIn('large_image', features['input_feature']) + small_image = features['input_feature']['small_image'] + large_image = features['input_feature']['large_image'] + self.assertEqual( + small_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(small_image[:, :, :3], 0.0) + np.testing.assert_equal(small_image[:, :, 3:], 1.0) + + self.assertEqual( + large_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(large_image[:, :, :3], 0.0) + np.testing.assert_equal(large_image[:, :, 3:], 1.0) + + def test_dataset_factory_getting_dataloader_from_tfrecord(self): + + def _get_config_tfrecord_dataloader(): + config = mlc.config_dict.ConfigDict() + config.data = mlc.config_dict.ConfigDict() + config.data.labeled_train_pattern = self.labeled_train_path + config.data.validation_pattern = self.labeled_test_path + config.data.use_post_disaster_only = False + config.data.load_small_images = True + return config + + config = _get_config_tfrecord_dataloader() + dataloader = data.SkaiDatasetFactory.get_dataloader( + data.ReadingRecordStratgey.DIRECT_FROM_TFRECORD, config + ) + ds = dataloader.train_ds + features = next(ds.as_numpy_iterator()) + self.assertIn('input_feature', features) + self.assertIn('small_image', features['input_feature']) + self.assertIn('large_image', features['input_feature']) + small_image = features['input_feature']['small_image'] + large_image = features['input_feature']['large_image'] + self.assertEqual( + small_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(small_image[:, :, :3], 0.0) + np.testing.assert_equal(small_image[:, :, 3:], 1.0) + + self.assertEqual( + large_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(large_image[:, :, :3], 0.0) + np.testing.assert_equal(large_image[:, :, 3:], 1.0) + + def test_dataset_factory_getting_dataloader_from_tensorflow_dataset(self): + + def _get_config_tensorflow_dataset_dataloader(): + config = mlc.config_dict.ConfigDict() + config.data = mlc.config_dict.ConfigDict() + config.data.name = 'skai' + config.data.tfds_dataset_name = 'skai_dataset' + config.data.tfds_data_dir = _make_temp_dir() + config.data.adhoc_config_name = 'skai_dataset' + config.data.labeled_train_pattern = self.labeled_train_path + config.data.validation_pattern = self.labeled_test_path + config.data.unlabeled_train_pattern = '' + config.data.use_post_disaster_only = False + config.data.load_small_images = True + config.data.num_splits = 1 + config.data.use_splits = False + config.round_idx = -1 + + config.upsampling = mlc.config_dict.ConfigDict() + config.upsampling.do_upsampling = False + return config + + config = _get_config_tensorflow_dataset_dataloader() + dataloader = data.SkaiDatasetFactory.get_dataloader( + data.ReadingRecordStratgey.USING_TFDS, config + ) + ds = dataloader.train_ds + features = next(ds.as_numpy_iterator()) + self.assertIn('input_feature', features) + self.assertIn('small_image', features['input_feature']) + self.assertIn('large_image', features['input_feature']) + small_image = features['input_feature']['small_image'] + large_image = features['input_feature']['large_image'] + self.assertEqual( + small_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(small_image[:, :, :3], 0.0) + np.testing.assert_equal(small_image[:, :, 3:], 1.0) + + self.assertEqual( + large_image.shape, (RESNET_IMAGE_SIZE, RESNET_IMAGE_SIZE, 6) + ) + self.assertEqual(small_image.dtype, np.float32) + np.testing.assert_equal(large_image[:, :, :3], 0.0) + np.testing.assert_equal(large_image[:, :, 3:], 1.0) + if __name__ == '__main__': absltest.main() diff --git a/src/skai/model/data_test.py b/src/skai/model/data_test.py index 2b42c9c1..2755dbf4 100644 --- a/src/skai/model/data_test.py +++ b/src/skai/model/data_test.py @@ -127,9 +127,9 @@ def setUpClass(cls): name='test_config', labeled_train_pattern=labeled_train_path, labeled_test_pattern=labeled_test_path, - unlabeled_pattern=unlabeled_path) + unlabeled_pattern=unlabeled_path, + ) ] - if __name__ == '__main__': tfds.testing.test_main() diff --git a/src/skai/model/train.py b/src/skai/model/train.py index 7dd2d822..db691d0e 100644 --- a/src/skai/model/train.py +++ b/src/skai/model/train.py @@ -31,7 +31,6 @@ from skai.model import generate_bias_table_lib ###COPYBARA_PLACEHOLDER_01 from skai.model import models -from skai.model import sampling_policies from skai.model import train_lib from skai.model.configs import base_config from skai.model.train_strategy import get_strategy @@ -64,6 +63,14 @@ 'tpu', '', 'The BNS address of the first TPU worker.' ) +_DATA_READING_STRATEGY = flags.DEFINE_enum_class( + 'data_reading_strategy', + 'using_tfds', + data.ReadingRecordStratgey, + 'Specify how to read the dataset, either directly using' + ' tf.data.TRecordDataset or using tensorflow_datasets.s', +) + def get_model_dir(root_dir: str) -> str: if FLAGS.is_vertex and FLAGS.trial_name: @@ -91,53 +98,9 @@ def main(_) -> None: stream_handler = native_logging.StreamHandler(stream) logging.get_absl_logger().addHandler(stream_handler) - dataset_builder = data.get_dataset(config.data.name) - ds_kwargs = {} - if config.data.name == 'waterbirds10k': - ds_kwargs = {'corr_strength': config.data.corr_strength} - elif config.data.name == 'skai': - ds_kwargs.update({ - 'tfds_dataset_name': config.data.tfds_dataset_name, - 'data_dir': config.data.tfds_data_dir, - 'adhoc_config_name': config.data.adhoc_config_name, - 'labeled_train_pattern': config.data.labeled_train_pattern, - 'unlabeled_train_pattern': config.data.unlabeled_train_pattern, - 'validation_pattern': config.data.validation_pattern, - 'use_post_disaster_only': config.data.use_post_disaster_only, - 'load_small_images': config.data.load_small_images, - }) - if config.data.use_post_disaster_only: - config.model.num_channels = 3 - if config.upsampling.do_upsampling: - ds_kwargs.update({ - 'upsampling_lambda': config.upsampling.lambda_value, - 'upsampling_signal': config.upsampling.signal, - }) - - logging.info('Running Round %d of Training.', config.round_idx) - get_split_config = lambda x: x if config.data.use_splits else 1 - if config.round_idx == 0: - dataloader = dataset_builder( - num_splits=get_split_config(config.data.num_splits), - initial_sample_proportion=get_split_config( - config.data.initial_sample_proportion), - subgroup_ids=config.data.subgroup_ids, - subgroup_proportions=config.data.subgroup_proportions, **ds_kwargs) - else: - # If latter round, keep track of split generated in last round of active - # sampling - dataloader = dataset_builder(config.data.num_splits, - initial_sample_proportion=1, - subgroup_ids=(), - subgroup_proportions=(), - **ds_kwargs) - - # Filter each split to only have examples from example_ids_table - dataloader.train_splits = [ - dataloader.train_ds.filter( - generate_bias_table_lib.filter_ids_fn(ids_tab)) for - ids_tab in sampling_policies.convert_ids_to_table(config.ids_dir)] - + dataloader = data.SkaiDatasetFactory.get_dataloader( + _DATA_READING_STRATEGY.value, config + ) global_batch_size = config.data.batch_size * strategy.num_replicas_in_sync model_params = models.ModelTrainingParameters( @@ -210,8 +173,8 @@ def main(_) -> None: training=False ) else: - raise ValueError( - 'In `config.data`, one of `(use_splits, use_filtering)` must be True.') + new_train_ds = dataloader.train_ds + val_ds = dataloader.eval_ds['val'] dataloader.train_ds = new_train_ds dataloader.eval_ds['val'] = val_ds diff --git a/src/skai/model/vlm_zero_shot_lib.py b/src/skai/model/vlm_zero_shot_lib.py index 605a1884..0c5aa57e 100644 --- a/src/skai/model/vlm_zero_shot_lib.py +++ b/src/skai/model/vlm_zero_shot_lib.py @@ -280,23 +280,6 @@ def image_preprocessing(example): return dataset -def _dedup_predictions(predictions: pd.DataFrame): - if 'is_cloudy' in predictions.columns: - non_cloudy = predictions[predictions['is_cloudy'] == 0] - else: - non_cloudy = predictions - return non_cloudy.groupby('building_id').agg({ - 'damage_score': 'mean', - 'cloud_score': 'mean', - 'longitude': 'mean', - 'latitude': 'mean', - 'example_id': 'first', - 'int64_id': 'first', - 'plus_code': 'first', - 'label': 'first', - }) - - def generate_zero_shot_assessment( model_config: ml_collections.ConfigDict, damage_label_file_path: str, @@ -396,9 +379,3 @@ def generate_zero_shot_assessment( f'{output_dir}/{dataset_name}_output.csv', 'w' ) as output_csv_file: output_df.to_csv(output_csv_file, index=False) - - deduped = _dedup_predictions(output_df) - with tf.io.gfile.GFile( - f'{output_dir}/{dataset_name}_deduped.csv', 'w' - ) as deduped_file: - deduped.to_csv(deduped_file, index=False) diff --git a/src/skai/model/xm_launch_single_model_vertex.py b/src/skai/model/xm_launch_single_model_vertex.py index 8e085c7e..335035f4 100644 --- a/src/skai/model/xm_launch_single_model_vertex.py +++ b/src/skai/model/xm_launch_single_model_vertex.py @@ -34,6 +34,7 @@ from absl import flags from google.cloud import aiplatform_v1beta1 as aip from ml_collections import config_flags +from skai.model import data from skai.model import docker_instructions from xmanager import xm from xmanager import xm_local @@ -71,15 +72,13 @@ ), ) flags.DEFINE_string( - 'cloud_location', - None, - 'Google Cloud location (region) for vizier jobs' + 'cloud_location', None, 'Google Cloud location (region) for vizier jobs' ) flags.DEFINE_enum( 'accelerator', default=None, help='Accelerator to use for faster computations.', - enum_values=['P100', 'V100', 'P4', 'T4', 'A100', 'TPU_V2', 'TPU_V3'] + enum_values=['P100', 'V100', 'P4', 'T4', 'A100', 'TPU_V2', 'TPU_V3'], ) flags.DEFINE_integer( 'accelerator_count', @@ -100,6 +99,13 @@ ' docker image.', ) flags.DEFINE_string('docker_image', None, 'Pre-built docker image to use.') +flags.DEFINE_enum_class( + 'data_reading_strategy', + 'using_tfds', + data.ReadingRecordStratgey, + 'Specify how to read the dataset, either directly using' + ' tf.data.TRecordDataset or using tensorflow_datasets.s', +) config_flags.DEFINE_config_file('config') @@ -199,8 +205,10 @@ def main(_) -> None: 'is_vertex': True, 'accelerator_type': accelerator_type, 'tpu': FLAGS.tpu, - 'config.output_dir': os.path.join(config.output_dir, - str(experiment.experiment_id)), + 'data_reading_strategy': FLAGS.data_reading_strategy, + 'config.output_dir': os.path.join( + config.output_dir, str(experiment.experiment_id) + ), 'config.train_bias': config.train_bias, 'config.train_stage_2_as_ensemble': False, 'config.round_idx': 0, @@ -257,10 +265,8 @@ def main(_) -> None: executable=train_executable, executor=executor, args=xm_args ), study_factory=vizier_cloud.NewStudy( - study_config=get_study_config(), - location=FLAGS.cloud_location + study_config=get_study_config(), location=FLAGS.cloud_location ), - num_trials_total=100, num_parallel_trial_runs=3, ).launch()