diff --git a/src/koza/cli_utils.py b/src/koza/cli_utils.py index edc1393..e62b7e9 100644 --- a/src/koza/cli_utils.py +++ b/src/koza/cli_utils.py @@ -3,6 +3,7 @@ """ from pathlib import Path +import os from typing import Dict, Literal, Optional, Union import yaml @@ -126,6 +127,68 @@ def _check_row_count(type: Literal["node", "edge"]): _check_row_count("edge") +def split_file(file: str, + fields: str, + format: OutputFormat = OutputFormat.tsv, + remove_prefixes: bool = False, + output_dir: str = "./output"): + db = duckdb.connect(":memory:") + + #todo: validate that each of the fields is actually a column in the file + if format == OutputFormat.tsv: + read_file = f"read_csv('{file}')" + elif format == OutputFormat.json: + read_file = f"read_json('{file}')" + else: + raise ValueError(f"Format {format} not supported") + + values = db.execute(f'SELECT DISTINCT {fields} FROM {read_file};').fetchall() + keys = fields.split(',') + list_of_value_dicts = [dict(zip(keys, v)) for v in values] + + def clean_value_for_filename(value): + if remove_prefixes and ':' in value: + value = value.split(":")[-1] + + return value.replace("biolink:", "").replace(" ", "_").replace(":", "_") + + def generate_filename_from_row(row): + return "_".join([clean_value_for_filename(row[k]) for k in keys if row[k] is not None]) + + def get_filename_prefix(name): + # get just the filename part of the path + name = os.path.basename(name) + if name.endswith('_edges.tsv'): + return name[:-9] + elif name.endswith('_nodes.tsv'): + return name[:-9] + else: + raise ValueError(f"Unexpected file name {name}, not sure how to make am output prefix for it") + + def get_filename_suffix(name): + if name.endswith('_edges.tsv'): + return '_edges.tsv' + elif name.endswith('_nodes.tsv'): + return '_nodes.tsv' + else: + raise ValueError(f"Unexpected file name {name}, not sure how to make am output prefix for it") + + # create output/split if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + for row in list_of_value_dicts: + # export to a tsv file named with the values of the pivot fields + where_clause = ' AND '.join([f"{k} = '{row[k]}'" for k in keys]) + file_name = output_dir + "/" + get_filename_prefix(file) + generate_filename_from_row(row) + get_filename_suffix(file) + print(f"writing {file_name}") + db.execute(f""" + COPY ( + SELECT * FROM {read_file} + WHERE {where_clause} + ) TO '{file_name}' (HEADER, DELIMITER '\t'); + """) + + def validate_file( file: str, format: FormatType = FormatType.csv, diff --git a/src/koza/main.py b/src/koza/main.py index fe4008b..fa4a27b 100755 --- a/src/koza/main.py +++ b/src/koza/main.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Optional -from koza.cli_utils import transform_source, validate_file +from koza.cli_utils import transform_source, validate_file, split_file from koza.model.config.source_config import FormatType, OutputFormat import typer @@ -65,6 +65,15 @@ def validate( """Validate a source file""" validate_file(file, format, delimiter, header_delimiter, skip_blank_lines) +@typer_app.command() +def split( + file: str = typer.Argument(..., help="Path to the source kgx file to be split"), + fields: str = typer.Argument(..., help="Comma separated list of fields to split on"), + remove_prefixes: bool = typer.Option(False, help="Remove prefixes from the file names for values from the specified fields. (e.g, NCBIGene:9606 becomes 9606"), + output_dir: str = typer.Option(default="output", help="Path to output directory"), +): + """Split a file by fields""" + split_file(file, fields,remove_prefixes=remove_prefixes, output_dir=output_dir) if __name__ == "__main__": typer_app()