Add the ability to load directly from tfrecord
PiperOrigin-RevId: 686137245
Mohamedelfatih Mohamedkhair authored and copybara-github committed Oct 18, 2024
1 parent 1f52a97 commit d2345a7
Showing 17 changed files with 487 additions and 2,811 deletions.
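Note: the hunks rendered below cover the metadata CSV rework and the removal of the `min_dataflow_workers` flag; the TFRecord loading change named in the commit title does not appear in the hunks shown here. Purely as an illustration of the pattern the title refers to, and assuming a shard naming like the one asserted in `generate_examples_test.py`, a direct read of generated examples might look like this sketch (the path is hypothetical):

```python
import tensorflow as tf

# Hypothetical location; the real layout is produced by generate_examples.py
# under <output_dir>/examples/.
EXAMPLES_PATTERN = '/tmp/output/examples/unlabeled-large/unlabeled-*-of-*.tfrecord'


def load_examples(pattern: str) -> tf.data.Dataset:
  """Reads serialized tf.train.Example records straight from TFRecord shards."""
  files = tf.data.Dataset.list_files(pattern, shuffle=False)
  return tf.data.TFRecordDataset(files)


for serialized in load_examples(EXAMPLES_PATTERN).take(1):
  example = tf.train.Example.FromString(serialized.numpy())
  print(sorted(example.features.feature.keys()))
```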
1,364 changes: 0 additions & 1,364 deletions src/colab/skai_assessment_notebook_custom_vm.ipynb

This file was deleted.

1,171 changes: 0 additions & 1,171 deletions src/colab/skai_assessment_notebook_custom_vm.py

This file was deleted.

4 changes: 0 additions & 4 deletions src/colab/sync_notebook_source.py
@@ -32,10 +32,7 @@
'GCP_LOCATION': '',
'GCP_BUCKET': '',
'GCP_SERVICE_ACCOUNT': '',
'SERVICE_ACCOUNT_KEY': '',
'BUILDING_SEGMENTATION_MODEL_PATH': '',
'BUILDINGS_METHOD': 'open_buildings',
'USER_BUILDINGS_FILE': '',
'ASSESSMENT_NAME': '',
'EVENT_DATE': '',
'OUTPUT_DIR': '',
@@ -59,7 +56,6 @@
'AFTER_IMAGE_7': '',
'AFTER_IMAGE_8': '',
'AFTER_IMAGE_9': '',
'DAMAGE_SCORE_THRESHOLD': 0.5,
}


4 changes: 0 additions & 4 deletions src/detect_buildings_main.py
@@ -67,9 +67,6 @@
'worker_service_account', None,
'Service account that will launch Dataflow workers. If unset, workers will '
'run with the project\'s default Compute Engine service account.')
flags.DEFINE_integer(
'min_dataflow_workers', 10, 'Minimum number of dataflow workers'
)
flags.DEFINE_integer(
'max_dataflow_workers', None, 'Maximum number of dataflow workers'
)
@@ -113,7 +110,6 @@ def main(args):
FLAGS.cloud_project,
FLAGS.cloud_region,
temp_dir,
FLAGS.min_dataflow_workers,
FLAGS.max_dataflow_workers,
FLAGS.worker_service_account,
machine_type=FLAGS.worker_machine_type,
3 changes: 0 additions & 3 deletions src/generate_examples_main.py
@@ -51,9 +51,6 @@
'worker_service_account', None,
'Service account that will launch Dataflow workers. If unset, workers will '
'run with the project\'s default Compute Engine service account.')
flags.DEFINE_integer(
'min_dataflow_workers', None, 'Minimum number of dataflow workers'
)
flags.DEFINE_integer(
'max_dataflow_workers', None, 'Maximum number of dataflow workers'
)
8 changes: 1 addition & 7 deletions src/skai/beam_utils.py
@@ -108,7 +108,6 @@ def get_pipeline_options(
project: str,
region: str,
temp_dir: str,
min_workers: int,
max_workers: int,
worker_service_account: str | None,
machine_type: str | None,
@@ -124,7 +123,6 @@
project: GCP project.
region: GCP region.
temp_dir: Temporary data location.
min_workers: Minimum number of Dataflow workers.
max_workers: Maximum number of Dataflow workers.
worker_service_account: Email of the service account will launch workers.
If None, uses the project's default Compute Engine service account
@@ -166,11 +164,8 @@
if machine_type:
options['machine_type'] = machine_type

service_options = [
f'min_num_workers={min_workers}',
]
if accelerator:
service_options.extend([
options['dataflow_service_options'] = ';'.join([
f'worker_accelerator=type:{accelerator}',
f'count:{accelerator_count}',
'install-nvidia-driver',
@@ -179,5 +174,4 @@
else:
options['sdk_container_image'] = _get_dataflow_container_image('cpu')

options['dataflow_service_options'] = ';'.join(service_options)
return PipelineOptions.from_dictionary(options)
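After this change, `min_num_workers` is no longer injected into the Dataflow service options, and `dataflow_service_options` is only populated when a GPU accelerator is requested. A condensed, self-contained sketch of the resulting assembly (container-image selection and the other pipeline options are elided; the helper name below is illustrative, not from the source):

```python
def build_service_options(
    options: dict[str, str],
    accelerator: str | None,
    accelerator_count: int,
) -> dict[str, str]:
  """Adds GPU-only Dataflow service options, mirroring the hunk above."""
  if accelerator:
    options['dataflow_service_options'] = ';'.join([
        f'worker_accelerator=type:{accelerator}',
        f'count:{accelerator_count}',
        'install-nvidia-driver',
    ])
  return options


# CPU-only runs leave dataflow_service_options unset entirely.
print(build_service_options({}, None, 0))
# GPU runs get a single semicolon-joined option string.
print(build_service_options({}, 'nvidia-tesla-t4', 1))
```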
88 changes: 60 additions & 28 deletions src/skai/generate_examples.py
@@ -14,6 +14,7 @@
"""Pipeline for generating tensorflow examples from satellite images."""

import binascii
import csv
import dataclasses
import hashlib
import itertools
@@ -27,8 +28,6 @@
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import apache_beam as beam
import apache_beam.dataframe.convert
import apache_beam.dataframe.io
import cv2
import geopandas as gpd
import numpy as np
@@ -120,8 +119,6 @@ class ExamplesGenerationConfig:
use_dataflow: If true, execute pipeline in Cloud Dataflow.
output_metadata_file: Output a CSV metadata file for all generated examples.
worker_service_account: If using Dataflow, the service account to run as.
min_dataflow_workers: If using Dataflow, the minimum number of workers to
instantiate.
max_dataflow_workers: If using Dataflow, the number of workers to
instantiate.
example_patch_size: Size of the example image.
@@ -174,7 +171,6 @@ class ExamplesGenerationConfig:
use_dataflow: bool = False
output_metadata_file: bool = True
worker_service_account: Optional[str] = None
min_dataflow_workers: int = 10
max_dataflow_workers: int = 20
example_patch_size: int = 64
large_patch_size: int = 256
@@ -874,7 +870,6 @@ def _generate_examples_pipeline(
cloud_project: Optional[str],
cloud_region: Optional[str],
worker_service_account: Optional[str],
min_workers: int,
max_workers: int,
wait_for_dataflow_job: bool,
cloud_detector_model_path: Optional[str],
@@ -898,7 +893,6 @@
cloud_project: Cloud project name.
cloud_region: Cloud region, e.g. us-central1.
worker_service_account: Email of service account that will launch workers.
min_workers: Minimum number of workers to use.
max_workers: Maximum number of workers to use.
wait_for_dataflow_job: If true, wait for dataflow job to complete before
returning.
@@ -914,7 +908,6 @@
cloud_project,
cloud_region,
temp_dir,
min_workers,
max_workers,
worker_service_account,
machine_type=None,
@@ -945,14 +938,28 @@
num_shards=num_output_shards))

if output_metadata_file:
rows = (
field_names = [
'example_id',
'encoded_coordinates',
'longitude',
'latitude',
'post_image_id',
'pre_image_id',
'plus_code',
]
_ = (
examples
| 'extract_metadata_rows' >> beam.Map(_get_example_metadata)
| 'remove_duplicates' >> beam.Distinct()
| 'convert_metadata_examples_to_dict' >> beam.Map(_get_example_metadata)
| 'combine_to_list' >> beam.combiners.ToList()
| 'write_metadata_to_file'
>> beam.ParDo(
WriteMetadataToCSVFn(
metadata_output_file_path=(
f'{output_dir}/examples/metadata_examples.csv'
), field_names=field_names
)
)
)
df = apache_beam.dataframe.convert.to_dataframe(rows)
output_prefix = f'{output_dir}/examples/metadata/metadata.csv'
apache_beam.dataframe.io.to_csv(df, output_prefix, index=False)

result = pipeline.run()
if wait_for_dataflow_job:
@@ -1082,14 +1089,34 @@ def run_example_generation(
config.cloud_project,
config.cloud_region,
config.worker_service_account,
config.min_dataflow_workers,
config.max_dataflow_workers,
wait_for_dataflow,
config.cloud_detector_model_path,
config.output_metadata_file
)


class WriteMetadataToCSVFn(beam.DoFn):
"""DoFn to write meta data of examples to csv file.
Attributes:
metadata_output_file_path: File path to output meta data of all examples.
field_names: Field names to be included in output file.
"""

def __init__(self, metadata_output_file_path: str, field_names: List[str]):
self.metadata_output_file_path = metadata_output_file_path
self.field_names = field_names

def process(self, element):
with tf.io.gfile.GFile(
self.metadata_output_file_path, 'w'
) as csv_output_file:
csv_writer = csv.DictWriter(csv_output_file, fieldnames=self.field_names)
csv_writer.writeheader()
csv_writer.writerows(element)


class ExampleType(typing.NamedTuple):
example_id: str
encoded_coordinates: str
Expand All @@ -1102,16 +1129,21 @@ class ExampleType(typing.NamedTuple):

@beam.typehints.with_output_types(ExampleType)
def _get_example_metadata(example: tf.train.Example) -> ExampleType:
return ExampleType(
example_id=utils.get_bytes_feature(example, 'example_id')[0].decode(),
encoded_coordinates=utils.get_bytes_feature(
example, 'encoded_coordinates'
)[0].decode(),
longitude=utils.get_float_feature(example, 'coordinates')[0],
latitude=utils.get_float_feature(example, 'coordinates')[1],
post_image_id=utils.get_bytes_feature(example, 'post_image_id')[
0
].decode(),
pre_image_id=utils.get_bytes_feature(example, 'pre_image_id')[0].decode(),
plus_code=utils.get_bytes_feature(example, 'plus_code')[0].decode(),
)
example_id = utils.get_bytes_feature(example, 'example_id')[0].decode()
encoded_coordinates = utils.get_bytes_feature(example, 'encoded_coordinates')[
0
].decode()
longitude, latitude = utils.get_float_feature(example, 'coordinates')
post_image_id = utils.get_bytes_feature(example, 'post_image_id')[0].decode()
pre_image_id = utils.get_bytes_feature(example, 'pre_image_id')[0].decode()
plus_code = utils.get_bytes_feature(example, 'plus_code')[0].decode()

return dict({
'example_id': example_id,
'encoded_coordinates': encoded_coordinates,
'longitude': longitude,
'latitude': latitude,
'post_image_id': post_image_id,
'pre_image_id': pre_image_id,
'plus_code': plus_code,
})
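The metadata now reaches disk through `beam.combiners.ToList` followed by a `ParDo` over the single collected list, rather than the Beam dataframe CSV writer. A self-contained sketch of that pattern with toy rows and a local path (the real pipeline writes to `<output_dir>/examples/metadata_examples.csv` via `tf.io.gfile`; the class name below is illustrative):

```python
import csv

import apache_beam as beam


class WriteRowsToCSVFn(beam.DoFn):
  """Writes the single collected list of row dicts to one CSV file."""

  def __init__(self, output_path: str, field_names: list[str]):
    self.output_path = output_path
    self.field_names = field_names

  def process(self, rows):
    # `rows` is the one element emitted by beam.combiners.ToList().
    with open(self.output_path, 'w', newline='') as f:
      writer = csv.DictWriter(f, fieldnames=self.field_names)
      writer.writeheader()
      writer.writerows(rows)


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([
          {'example_id': 'a', 'longitude': 178.48, 'latitude': -16.63},
          {'example_id': 'b', 'longitude': 178.49, 'latitude': -16.64},
      ])
      | beam.combiners.ToList()
      | beam.ParDo(
          WriteRowsToCSVFn('/tmp/metadata_examples.csv',
                           ['example_id', 'longitude', 'latitude'])))
```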
20 changes: 8 additions & 12 deletions src/skai/generate_examples_test.py
@@ -14,7 +14,6 @@

"""Tests for generate_examples.py."""

import glob
import os
import pathlib
import tempfile
@@ -478,7 +477,6 @@ def testGenerateExamplesPipeline(self):
cloud_project=None,
cloud_region=None,
worker_service_account=None,
min_workers=0,
max_workers=0,
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
@@ -517,7 +515,6 @@ def testGenerateExamplesWithOutputMetaDataFile(self):
cloud_project=None,
cloud_region=None,
worker_service_account=None,
min_workers=0,
max_workers=0,
wait_for_dataflow_job=True,
cloud_detector_model_path=None,
@@ -527,27 +524,26 @@
tfrecords = os.listdir(
os.path.join(output_dir, 'examples', 'unlabeled-large')
)
metadata_pattern = os.path.join(
output_dir, 'examples', 'metadata', 'metadata.csv-*-of-*'
df_metadata_contents = pd.read_csv(
os.path.join(output_dir, 'examples', 'metadata_examples.csv')
)
metadata = pd.concat([pd.read_csv(p) for p in glob.glob(metadata_pattern)])

# No assert for example_id as each example_id depends on the image path
# which varies with platforms where this test is run
self.assertEqual(
metadata.encoded_coordinates[0], 'A17B32432A1085C1'
df_metadata_contents.encoded_coordinates[0], 'A17B32432A1085C1'
)
self.assertAlmostEqual(
metadata.latitude[0], -16.632892608642578
df_metadata_contents.latitude[0], -16.632892608642578
)
self.assertAlmostEqual(
metadata.longitude[0], 178.48292541503906
df_metadata_contents.longitude[0], 178.48292541503906
)
self.assertEqual(metadata.pre_image_id[0], self.test_image_path)
self.assertEqual(df_metadata_contents.pre_image_id[0], self.test_image_path)
self.assertEqual(
metadata.post_image_id[0], self.test_image_path
df_metadata_contents.post_image_id[0], self.test_image_path
)
self.assertEqual(metadata.plus_code[0], '5VMW9F8M+R5V8F4')
self.assertEqual(df_metadata_contents.plus_code[0], '5VMW9F8M+R5V8F4')
self.assertSameElements(tfrecords, ['unlabeled-00000-of-00001.tfrecord'])

def testConfigLoadedCorrectlyFromJsonFile(self):
60 changes: 21 additions & 39 deletions src/skai/labeling.py
@@ -220,26 +220,6 @@ def sample_with_buffer(
return sample


def _read_sharded_csvs(pattern: str) -> pd.DataFrame:
"""Reads CSV shards matching pattern and merges them."""
paths = tf.io.gfile.glob(pattern)
if not paths:
raise ValueError(f'File pattern {pattern} did not match any files.')
dfs = []
expected_columns = None
for path in paths:
with tf.io.gfile.GFile(path, 'r') as f:
df = pd.read_csv(f)
if expected_columns is None:
expected_columns = set(df.columns)
else:
actual_columns = set(df.columns)
if actual_columns != expected_columns:
raise ValueError(f'Inconsistent columns in file {path}')
dfs.append(df)
return pd.concat(dfs, ignore_index=True)


def get_buffered_example_ids(
examples_pattern: str,
buffered_sampling_radius: float,
@@ -258,23 +238,25 @@ def get_buffered_example_ids(
Returns:
Set of allowed example ids.
"""
root_dir = '/'.join(examples_pattern.split('/')[:-2])
single_csv_pattern = str(os.path.join(root_dir, 'metadata_examples.csv'))
if tf.io.gfile.exists(single_csv_pattern):
metadata = _read_sharded_csvs(single_csv_pattern)
else:
sharded_csv_pattern = str(
os.path.join(
root_dir,
'metadata',
'metadata.csv-*-of-*',
)
)
metadata = _read_sharded_csvs(sharded_csv_pattern)

metadata = metadata[
~metadata['example_id'].isin(excluded_example_ids)
].reset_index(drop=True)
metadata_path = str(
os.path.join(
'/'.join(examples_pattern.split('/')[:-2]),
'metadata_examples.csv',
)
)
with tf.io.gfile.GFile(metadata_path, 'r') as f:
try:
df_metadata = pd.read_csv(f)
df_metadata = df_metadata[
~df_metadata['example_id'].isin(excluded_example_ids)
].reset_index(drop=True)
except tf.errors.NotFoundError as error:
raise SystemExit(
f'\ntf.errors.NotFoundError: {metadata_path} was not found\nUse'
' examples_to_csv module to generate metadata_examples.csv and/or'
' put metadata_examples.csv in the appropriate directory that is'
' PATH_DIR/examples/'
) from error

logging.info(
'Randomly searching for buffered samples with buffer radius %.2f'
@@ -283,11 +265,11 @@
)
points = utils.convert_to_utm(
gpd.GeoSeries(
gpd.points_from_xy(metadata['longitude'], metadata['latitude']),
gpd.points_from_xy(df_metadata['longitude'], df_metadata['latitude']),
crs=4326,
)
)
gpd_df = gpd.GeoDataFrame(metadata, geometry=points)
gpd_df = gpd.GeoDataFrame(df_metadata, geometry=points)
max_examples = len(gpd_df) if max_examples is None else max_examples
df_buffered_samples = sample_with_buffer(
gpd_df, max_examples, buffered_sampling_radius
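The lookup above expects `metadata_examples.csv` two levels up from the examples pattern, i.e. directly under the `examples/` directory, as the new error message also states. A small sketch of that path derivation, mirroring the split in the hunk above with an illustrative pattern (the real code then opens the file with `tf.io.gfile.GFile` and pandas):

```python
import os


def metadata_path_for(examples_pattern: str) -> str:
  """Maps an examples glob to the expected metadata_examples.csv location."""
  root_dir = '/'.join(examples_pattern.split('/')[:-2])
  return os.path.join(root_dir, 'metadata_examples.csv')


print(metadata_path_for('gs://bucket/run1/examples/unlabeled-large/*.tfrecord'))
# gs://bucket/run1/examples/metadata_examples.csv
```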