Skip to content

Commit

Permalink
Output deduped zero-shot CSV.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685729466
  • Loading branch information
jzxu authored and copybara-github committed Oct 14, 2024
1 parent ba05ed5 commit fa37891
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/skai/model/vlm_zero_shot_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,23 @@ def image_preprocessing(example):
return dataset


def _dedup_predictions(predictions: pd.DataFrame):
if 'is_cloudy' in predictions.columns:
non_cloudy = predictions[predictions['is_cloudy'] == 0]
else:
non_cloudy = predictions
return non_cloudy.groupby('building_id').agg({
'damage_score': 'mean',
'cloud_score': 'mean',
'longitude': 'mean',
'latitude': 'mean',
'example_id': 'first',
'int64_id': 'first',
'plus_code': 'first',
'label': 'first',
})


def generate_zero_shot_assessment(
model_config: ml_collections.ConfigDict,
damage_label_file_path: str,
Expand Down Expand Up @@ -379,3 +396,9 @@ def generate_zero_shot_assessment(
f'{output_dir}/{dataset_name}_output.csv', 'w'
) as output_csv_file:
output_df.to_csv(output_csv_file, index=False)

deduped = _dedup_predictions(output_df)
with tf.io.gfile.GFile(
f'{output_dir}/{dataset_name}_deduped.csv', 'w'
) as deduped_file:
deduped.to_csv(deduped_file, index=False)

0 comments on commit fa37891

Please sign in to comment.