Add the ability to load directly from tfrecord #292

Open · wants to merge 1 commit into base: main
1,364 changes: 0 additions & 1,364 deletions src/colab/skai_assessment_notebook_custom_vm.ipynb

This file was deleted.

1,171 changes: 0 additions & 1,171 deletions src/colab/skai_assessment_notebook_custom_vm.py

This file was deleted.

4 changes: 0 additions & 4 deletions src/colab/sync_notebook_source.py
@@ -32,10 +32,7 @@
'GCP_LOCATION': '',
'GCP_BUCKET': '',
'GCP_SERVICE_ACCOUNT': '',
'SERVICE_ACCOUNT_KEY': '',
'BUILDING_SEGMENTATION_MODEL_PATH': '',
'BUILDINGS_METHOD': 'open_buildings',
'USER_BUILDINGS_FILE': '',
'ASSESSMENT_NAME': '',
'EVENT_DATE': '',
'OUTPUT_DIR': '',
@@ -59,7 +56,6 @@
'AFTER_IMAGE_7': '',
'AFTER_IMAGE_8': '',
'AFTER_IMAGE_9': '',
'DAMAGE_SCORE_THRESHOLD': 0.5,
}


4 changes: 0 additions & 4 deletions src/detect_buildings_main.py
@@ -67,9 +67,6 @@
'worker_service_account', None,
'Service account that will launch Dataflow workers. If unset, workers will '
'run with the project\'s default Compute Engine service account.')
flags.DEFINE_integer(
'min_dataflow_workers', 10, 'Minimum number of dataflow workers'
)
flags.DEFINE_integer(
'max_dataflow_workers', None, 'Maximum number of dataflow workers'
)
@@ -113,7 +110,6 @@ def main(args):
FLAGS.cloud_project,
FLAGS.cloud_region,
temp_dir,
FLAGS.min_dataflow_workers,
FLAGS.max_dataflow_workers,
FLAGS.worker_service_account,
machine_type=FLAGS.worker_machine_type,
3 changes: 0 additions & 3 deletions src/generate_examples_main.py
@@ -51,9 +51,6 @@
'worker_service_account', None,
'Service account that will launch Dataflow workers. If unset, workers will '
'run with the project\'s default Compute Engine service account.')
flags.DEFINE_integer(
'min_dataflow_workers', None, 'Minimum number of dataflow workers'
)
flags.DEFINE_integer(
'max_dataflow_workers', None, 'Maximum number of dataflow workers'
)
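Note on the flag removal above: with min_dataflow_workers gone from both entry points, only the maximum worker count is exposed and the lower bound is left to Dataflow's defaults. Below is a minimal, hypothetical absl-flags sketch of what remains; the module layout and main body are illustrative, not taken from the patch.

```python
# Hypothetical sketch: only the surviving worker-count flag is declared.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_integer(
    'max_dataflow_workers', None, 'Maximum number of dataflow workers'
)


def main(argv):
  del argv  # Unused.
  # Downstream code forwards FLAGS.max_dataflow_workers to
  # beam_utils.get_pipeline_options, which no longer takes a minimum.
  print(f'max_dataflow_workers={FLAGS.max_dataflow_workers}')


if __name__ == '__main__':
  app.run(main)
```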
8 changes: 1 addition & 7 deletions src/skai/beam_utils.py
@@ -108,7 +108,6 @@ def get_pipeline_options(
project: str,
region: str,
temp_dir: str,
min_workers: int,
max_workers: int,
worker_service_account: str | None,
machine_type: str | None,
@@ -124,7 +123,6 @@
project: GCP project.
region: GCP region.
temp_dir: Temporary data location.
min_workers: Minimum number of Dataflow workers.
max_workers: Maximum number of Dataflow workers.
worker_service_account: Email of the service account that will launch workers.
If None, uses the project's default Compute Engine service account
@@ -166,11 +164,8 @@
if machine_type:
options['machine_type'] = machine_type

service_options = [
f'min_num_workers={min_workers}',
]
if accelerator:
service_options.extend([
options['dataflow_service_options'] = ';'.join([
f'worker_accelerator=type:{accelerator}',
f'count:{accelerator_count}',
'install-nvidia-driver',
@@ -179,5 +174,4 @@
else:
options['sdk_container_image'] = _get_dataflow_container_image('cpu')

options['dataflow_service_options'] = ';'.join(service_options)
return PipelineOptions.from_dictionary(options)
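To make the simplification in get_pipeline_options easier to follow: with the min_num_workers service option removed, dataflow_service_options is now only populated on the GPU path, and the options dictionary goes straight to PipelineOptions.from_dictionary. The following is a standalone sketch of that control flow, not the SKAI function itself; the helper name, the option keys shown, and the sample values are illustrative assumptions.

```python
# Illustrative sketch of the simplified flow; unrelated setup is omitted.
from apache_beam.options.pipeline_options import PipelineOptions


def build_options_sketch(
    project: str,
    region: str,
    temp_dir: str,
    max_workers: int,
    accelerator: str | None = None,
    accelerator_count: int = 0,
) -> PipelineOptions:
  options = {
      'project': project,
      'region': region,
      'temp_location': temp_dir,
      'max_num_workers': max_workers,
  }
  if accelerator:
    # GPU path: the only remaining place that sets dataflow_service_options.
    options['dataflow_service_options'] = ';'.join([
        f'worker_accelerator=type:{accelerator}',
        f'count:{accelerator_count}',
        'install-nvidia-driver',
    ])
  # The CPU path would instead select a CPU SDK container image (omitted).
  return PipelineOptions.from_dictionary(options)


# Example call with made-up values:
opts = build_options_sketch(
    'my-project', 'us-central1', 'gs://my-bucket/tmp',
    max_workers=20, accelerator='nvidia-tesla-t4', accelerator_count=1,
)
```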
88 changes: 60 additions & 28 deletions src/skai/generate_examples.py
@@ -14,6 +14,7 @@
"""Pipeline for generating tensorflow examples from satellite images."""

import binascii
import csv
import dataclasses
import hashlib
import itertools
@@ -27,8 +28,6 @@
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import apache_beam as beam
import apache_beam.dataframe.convert
import apache_beam.dataframe.io
import cv2
import geopandas as gpd
import numpy as np
@@ -120,8 +119,6 @@ class ExamplesGenerationConfig:
use_dataflow: If true, execute pipeline in Cloud Dataflow.
output_metadata_file: Output a CSV metadata file for all generated examples.
worker_service_account: If using Dataflow, the service account to run as.
min_dataflow_workers: If using Dataflow, the minimum number of workers to
instantiate.
max_dataflow_workers: If using Dataflow, the number of workers to
instantiate.
example_patch_size: Size of the example image.
@@ -174,7 +171,6 @@ class ExamplesGenerationConfig:
use_dataflow: bool = False
output_metadata_file: bool = True
worker_service_account: Optional[str] = None
min_dataflow_workers: int = 10
max_dataflow_workers: int = 20
example_patch_size: int = 64
large_patch_size: int = 256
@@ -874,7 +870,6 @@ def _generate_examples_pipeline(
cloud_project: Optional[str],
cloud_region: Optional[str],
worker_service_account: Optional[str],
min_workers: int,
max_workers: int,
wait_for_dataflow_job: bool,
cloud_detector_model_path: Optional[str],
@@ -898,7 +893,6 @@
cloud_project: Cloud project name.
cloud_region: Cloud region, e.g. us-central1.
worker_service_account: Email of service account that will launch workers.
min_workers: Minimum number of workers to use.
max_workers: Maximum number of workers to use.
wait_for_dataflow_job: If true, wait for dataflow job to complete before
returning.
@@ -914,7 +908,6 @@
cloud_project,
cloud_region,
temp_dir,
min_workers,
max_workers,
worker_service_account,
machine_type=None,
@@ -945,14 +938,28 @@
num_shards=num_output_shards))

if output_metadata_file:
rows = (
field_names = [
'example_id',
'encoded_coordinates',
'longitude',
'latitude',
'post_image_id',
'pre_image_id',
'plus_code',
]
_ = (
examples
| 'extract_metadata_rows' >> beam.Map(_get_example_metadata)
| 'remove_duplicates' >> beam.Distinct()
| 'convert_metadata_examples_to_dict' >> beam.Map(_get_example_metadata)
| 'combine_to_list' >> beam.combiners.ToList()
| 'write_metadata_to_file'
>> beam.ParDo(
WriteMetadataToCSVFn(
metadata_output_file_path=(
f'{output_dir}/examples/metadata_examples.csv'
), field_names=field_names
)
)
)
df = apache_beam.dataframe.convert.to_dataframe(rows)
output_prefix = f'{output_dir}/examples/metadata/metadata.csv'
apache_beam.dataframe.io.to_csv(df, output_prefix, index=False)

result = pipeline.run()
if wait_for_dataflow_job:
@@ -1082,14 +1089,34 @@ def run_example_generation(
config.cloud_project,
config.cloud_region,
config.worker_service_account,
config.min_dataflow_workers,
config.max_dataflow_workers,
wait_for_dataflow,
config.cloud_detector_model_path,
config.output_metadata_file
)


class WriteMetadataToCSVFn(beam.DoFn):
"""DoFn to write meta data of examples to csv file.

Attributes:
metadata_output_file_path: File path to output meta data of all examples.
field_names: Field names to be included in output file.
"""

def __init__(self, metadata_output_file_path: str, field_names: List[str]):
self.metadata_output_file_path = metadata_output_file_path
self.field_names = field_names

def process(self, element):
with tf.io.gfile.GFile(
self.metadata_output_file_path, 'w'
) as csv_output_file:
csv_writer = csv.DictWriter(csv_output_file, fieldnames=self.field_names)
csv_writer.writeheader()
csv_writer.writerows(element)


class ExampleType(typing.NamedTuple):
example_id: str
encoded_coordinates: str
@@ -1102,16 +1129,21 @@ class ExampleType(typing.NamedTuple):

@beam.typehints.with_output_types(ExampleType)
def _get_example_metadata(example: tf.train.Example) -> ExampleType:
return ExampleType(
example_id=utils.get_bytes_feature(example, 'example_id')[0].decode(),
encoded_coordinates=utils.get_bytes_feature(
example, 'encoded_coordinates'
)[0].decode(),
longitude=utils.get_float_feature(example, 'coordinates')[0],
latitude=utils.get_float_feature(example, 'coordinates')[1],
post_image_id=utils.get_bytes_feature(example, 'post_image_id')[
0
].decode(),
pre_image_id=utils.get_bytes_feature(example, 'pre_image_id')[0].decode(),
plus_code=utils.get_bytes_feature(example, 'plus_code')[0].decode(),
)
example_id = utils.get_bytes_feature(example, 'example_id')[0].decode()
encoded_coordinates = utils.get_bytes_feature(example, 'encoded_coordinates')[
0
].decode()
longitude, latitude = utils.get_float_feature(example, 'coordinates')
post_image_id = utils.get_bytes_feature(example, 'post_image_id')[0].decode()
pre_image_id = utils.get_bytes_feature(example, 'pre_image_id')[0].decode()
plus_code = utils.get_bytes_feature(example, 'plus_code')[0].decode()

return dict({
'example_id': example_id,
'encoded_coordinates': encoded_coordinates,
'longitude': longitude,
'latitude': latitude,
'post_image_id': post_image_id,
'pre_image_id': pre_image_id,
'plus_code': plus_code,
})
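Summary of the metadata change above: instead of converting rows to a Beam dataframe and writing sharded metadata/metadata.csv-*-of-* files, the pipeline now maps each example to a dict of metadata fields, gathers them into a single element with beam.combiners.ToList(), and writes one examples/metadata_examples.csv through the WriteMetadataToCSVFn DoFn. Below is a runnable, self-contained sketch of that write pattern on toy rows; the file path, field names, and values are made up.

```python
# Toy reproduction of the single-file CSV write: ToList() collapses the
# PCollection into one list element, so a single DoFn call writes one file.
import csv

import apache_beam as beam

FIELD_NAMES = ['example_id', 'longitude', 'latitude']


class WriteRowsToCSVFn(beam.DoFn):
  """Writes the collected list of row dicts to a single CSV file."""

  def __init__(self, output_path: str):
    self._output_path = output_path

  def process(self, rows):
    with open(self._output_path, 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=FIELD_NAMES)
      writer.writeheader()
      writer.writerows(rows)


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([
          {'example_id': 'a', 'longitude': 178.48, 'latitude': -16.63},
          {'example_id': 'b', 'longitude': 178.49, 'latitude': -16.64},
      ])
      | 'combine_to_list' >> beam.combiners.ToList()
      | 'write_csv' >> beam.ParDo(WriteRowsToCSVFn('/tmp/metadata_examples.csv'))
  )
```

One trade-off worth noting (an observation about the pattern, not a claim from the patch): because the full list is materialized in a single bundle, the metadata for all examples has to fit in one worker's memory, whereas the old dataframe write sharded the output.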
20 changes: 8 additions & 12 deletions src/skai/generate_examples_test.py
@@ -14,7 +14,6 @@

"""Tests for generate_examples.py."""

import glob
import os
import pathlib
import tempfile
@@ -478,7 +477,6 @@ def testGenerateExamplesPipeline(self):
cloud_project=None,
cloud_region=None,
worker_service_account=None,
min_workers=0,
max_workers=0,
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
@@ -517,7 +515,6 @@ def testGenerateExamplesWithOutputMetaDataFile(self):
cloud_project=None,
cloud_region=None,
worker_service_account=None,
min_workers=0,
max_workers=0,
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
@@ -527,27 +524,26 @@
tfrecords = os.listdir(
os.path.join(output_dir, 'examples', 'unlabeled-large')
)
metadata_pattern = os.path.join(
output_dir, 'examples', 'metadata', 'metadata.csv-*-of-*'
df_metadata_contents = pd.read_csv(
os.path.join(output_dir, 'examples', 'metadata_examples.csv')
)
metadata = pd.concat([pd.read_csv(p) for p in glob.glob(metadata_pattern)])

# No assert for example_id as each example_id depends on the image path
# which varies with platforms where this test is run
self.assertEqual(
metadata.encoded_coordinates[0], 'A17B32432A1085C1'
df_metadata_contents.encoded_coordinates[0], 'A17B32432A1085C1'
)
self.assertAlmostEqual(
metadata.latitude[0], -16.632892608642578
df_metadata_contents.latitude[0], -16.632892608642578
)
self.assertAlmostEqual(
metadata.longitude[0], 178.48292541503906
df_metadata_contents.longitude[0], 178.48292541503906
)
self.assertEqual(metadata.pre_image_id[0], self.test_image_path)
self.assertEqual(df_metadata_contents.pre_image_id[0], self.test_image_path)
self.assertEqual(
metadata.post_image_id[0], self.test_image_path
df_metadata_contents.post_image_id[0], self.test_image_path
)
self.assertEqual(metadata.plus_code[0], '5VMW9F8M+R5V8F4')
self.assertEqual(df_metadata_contents.plus_code[0], '5VMW9F8M+R5V8F4')
self.assertSameElements(tfrecords, ['unlabeled-00000-of-00001.tfrecord'])

def testConfigLoadedCorrectlyFromJsonFile(self):
60 changes: 21 additions & 39 deletions src/skai/labeling.py
@@ -220,26 +220,6 @@ def sample_with_buffer(
return sample


def _read_sharded_csvs(pattern: str) -> pd.DataFrame:
"""Reads CSV shards matching pattern and merges them."""
paths = tf.io.gfile.glob(pattern)
if not paths:
raise ValueError(f'File pattern {pattern} did not match any files.')
dfs = []
expected_columns = None
for path in paths:
with tf.io.gfile.GFile(path, 'r') as f:
df = pd.read_csv(f)
if expected_columns is None:
expected_columns = set(df.columns)
else:
actual_columns = set(df.columns)
if actual_columns != expected_columns:
raise ValueError(f'Inconsistent columns in file {path}')
dfs.append(df)
return pd.concat(dfs, ignore_index=True)


def get_buffered_example_ids(
examples_pattern: str,
buffered_sampling_radius: float,
@@ -258,23 +238,25 @@ def get_buffered_example_ids(
Returns:
Set of allowed example ids.
"""
root_dir = '/'.join(examples_pattern.split('/')[:-2])
single_csv_pattern = str(os.path.join(root_dir, 'metadata_examples.csv'))
if tf.io.gfile.exists(single_csv_pattern):
metadata = _read_sharded_csvs(single_csv_pattern)
else:
sharded_csv_pattern = str(
os.path.join(
root_dir,
'metadata',
'metadata.csv-*-of-*',
)
)
metadata = _read_sharded_csvs(sharded_csv_pattern)

metadata = metadata[
~metadata['example_id'].isin(excluded_example_ids)
].reset_index(drop=True)
metadata_path = str(
os.path.join(
'/'.join(examples_pattern.split('/')[:-2]),
'metadata_examples.csv',
)
)
with tf.io.gfile.GFile(metadata_path, 'r') as f:
try:
df_metadata = pd.read_csv(f)
df_metadata = df_metadata[
~df_metadata['example_id'].isin(excluded_example_ids)
].reset_index(drop=True)
except tf.errors.NotFoundError as error:
raise SystemExit(
f'\ntf.errors.NotFoundError: {metadata_path} was not found\nUse'
' examples_to_csv module to generate metadata_examples.csv and/or'
' put metadata_examples.csv in the appropriate directory that is'
' PATH_DIR/examples/'
) from error

logging.info(
'Randomly searching for buffered samples with buffer radius %.2f'
@@ -283,11 +265,11 @@
)
points = utils.convert_to_utm(
gpd.GeoSeries(
gpd.points_from_xy(metadata['longitude'], metadata['latitude']),
gpd.points_from_xy(df_metadata['longitude'], df_metadata['latitude']),
crs=4326,
)
)
gpd_df = gpd.GeoDataFrame(metadata, geometry=points)
gpd_df = gpd.GeoDataFrame(df_metadata, geometry=points)
max_examples = len(gpd_df) if max_examples is None else max_examples
df_buffered_samples = sample_with_buffer(
gpd_df, max_examples, buffered_sampling_radius
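To make the simplified lookup in get_buffered_example_ids concrete: the code above strips the last two components from examples_pattern and expects metadata_examples.csv directly under the assessment's examples/ directory, which matches where _generate_examples_pipeline now writes it ({output_dir}/examples/metadata_examples.csv). A small worked example of the path derivation, with a hypothetical bucket layout:

```python
# Worked example of the metadata path derivation; the GCS paths are made up.
import os

examples_pattern = 'gs://my-bucket/assessment/examples/unlabeled-large/*.tfrecord'
metadata_path = str(
    os.path.join(
        '/'.join(examples_pattern.split('/')[:-2]),
        'metadata_examples.csv',
    )
)
print(metadata_path)
# -> gs://my-bucket/assessment/examples/metadata_examples.csv
```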