Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple writers #140

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/koza/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def next_row():
"""
raise NextRowException

def write(self, *entities):
def write(self, *entities, split: bool = False):
# If a schema/validator is defined, validate before writing
# if self.validate:
if hasattr(self, 'schema'):
Expand All @@ -168,7 +168,7 @@ def write(self, *entities):
for edge in edges:
validate(instance=edge, target_class=self.edge_type, schema=self.schema, strict=True)

self.writer.write(entities)
self.writer.write(entities, split=split)

def _get_writer(self) -> Union[TSVWriter, JSONLWriter]:
writer_params = [
Expand Down
4 changes: 2 additions & 2 deletions src/koza/io/writer/jsonl_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def __init__(
if edge_properties:
self.edgeFH = open(f"{output_dir}/{source_name}_edges.jsonl", "w")

def write(self, entities: Iterable):
def write(self, entities: Iterable, split: bool = False) -> None:
(nodes, edges) = self.converter.convert(entities)

# TODO: implement split
if nodes:
for n in nodes:
node = json.dumps(n, ensure_ascii=False)
Expand Down
63 changes: 51 additions & 12 deletions src/koza/io/writer/tsv_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# NOTE - May want to rename to KGXWriter at some point, if we develop writers for other models non biolink/kgx specific

from pathlib import Path
import shutil
from typing import Dict, Iterable, List, Literal, Set, Union

from ordered_set import OrderedSet
Expand All @@ -13,6 +14,8 @@


class TSVWriter(KozaWriter):
_splits_cleanup_done = False

def __init__(
self,
output_dir: Union[str, Path],
Expand Down Expand Up @@ -44,39 +47,75 @@ def __init__(
self.edgeFH = open(self.edges_file_name, "w")
self.edgeFH.write(self.delimiter.join(self.edge_columns) + "\n")

def write(self, entities: Iterable) -> None:
def write(self, entities: Iterable, split: bool = False) -> None:
"""Write an entities object to separate node and edge .tsv files"""

nodes, edges = self.converter.convert(entities)

if nodes:
for node in nodes:
self.write_row(node, record_type="node")
self.write_row(node, record_type="node", split=split)

if edges:
for edge in edges:
if self.sssom_config:
edge = self.sssom_config.apply_mapping(edge)
self.write_row(edge, record_type="edge")
self.write_row(edge, record_type="edge", split=split)

def write_row(self, record: Dict, record_type: Literal["node", "edge"]) -> None:
def write_row(self, record: Dict, record_type: Literal["node", "edge"], split: bool = False) -> None:
"""Write a row to the underlying store.

Args:
record: Dict - A node or edge record
record_type: Literal["node", "edge"] - The record_type of record
"""

def get_new_fh_path(base_dir, filename, category):
new_dir = base_dir / "splits"
if not self._splits_cleanup_done:
shutil.rmtree(new_dir, ignore_errors=True)
self._splits_cleanup_done = True
new_dir.mkdir(parents=True, exist_ok=True)
return new_dir / filename.replace(record_type + "s", f"{category}_{record_type}s")

fh = self.nodeFH if record_type == "node" else self.edgeFH
columns = self.node_columns if record_type == "node" else self.edge_columns
row = build_export_row(record, list_delimiter=self.list_delimiter)
values = []
if record_type == "node":
row["id"] = record["id"]
for c in columns:
if c in row:
values.append(str(row[c]))

if split:
base_dir, filename = Path(fh.name).parent, getattr(self, f"{record_type}s_file_name").name
if record_type == "node":
category = record.get("category", [""])[0].split(":")[-1]
else:
values.append("")
subject_category = (
record.get("subject_category", "").split(":")[-1]
if record.get("subject_category")
else "UnknownCategory"
)

object_category = (
record.get("object_category", "").split(":")[-1]
if record.get("object_category")
else "UnknownCategory"
)

edge_category = (
record.get("category", [""])[0].split(":")[-1] if record.get("category") else "UnknownCategory"
)

category = subject_category + edge_category + object_category

new_fh_path = get_new_fh_path(base_dir, filename, category)

with open(new_fh_path, "a+") as new_fh:
if new_fh.tell() == 0: # Check if file is empty
new_fh.write(self.delimiter.join(columns) + "\n")
self._write_record_to_file(new_fh, record, columns)

self._write_record_to_file(fh, record, columns)

def _write_record_to_file(self, fh, record, columns):
row = build_export_row(record, list_delimiter=self.list_delimiter)
values = [str(row.get(c, "")) for c in columns]
fh.write(self.delimiter.join(values) + "\n")

def finalize(self):
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_tsvwriter_node_and_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,74 @@ def test_tsv_writer():
assert os.path.exists("{}/{}_nodes.tsv".format(outdir, outfile)) and os.path.exists(
"{}/{}_edges.tsv".format(outdir, outfile)
)


def test_tsv_writer_split():
"""
Writes a test tsv file
"""
g1 = Gene(id="HGNC:11603", name="TBX4", category=["biolink:Gene"])
d1 = Disease(id="MONDO:0005002", name="chronic obstructive pulmonary disease", category=["biolink:Disease"])
a1 = GeneToDiseaseAssociation(
id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1e",
subject=g1.id,
object=d1.id,
predicate="biolink:contributes_to",
knowledge_level="not_provided",
agent_type="not_provided",
subject_category="biolink:Gene",
object_category="biolink:Disease",
)
g2 = Gene(id="HGNC:11604", name="TBX5", category=["biolink:Gene"])
d2 = Disease(id="MONDO:0005003", name="asthma")
a2 = GeneToDiseaseAssociation(
id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1f",
subject=g2.id,
object=d2.id,
predicate="biolink:contributes_to",
knowledge_level="not_provided",
agent_type="not_provided",
)
g3 = Gene(id="HGNC:11605", name="TBX6")
d3 = Disease(id="MONDO:0005004", name="lung cancer", category=["biolink:Disease"])
a3 = GeneToDiseaseAssociation(
id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1g",
subject=g3.id,
object=d3.id,
predicate="biolink:contributes_to",
knowledge_level="not_provided",
agent_type="not_provided",
)
g4 = Gene(id="HGNC:11606", name="TBX7")
d4 = Disease(id="MONDO:0005005", name="pulmonary fibrosis")
a4 = GeneToDiseaseAssociation(
id="uuid:5b06e86f-d768-4cd9-ac27-abe31e95ab1h",
subject=g4.id,
object=d4.id,
predicate="biolink:contributes_to",
knowledge_level="not_provided",
agent_type="not_provided",
)

ents = [[g1, d1, a1], [g2, d2, a2], [g3, d3, a3], [g4, d4, a4]]

node_properties = ["id", "category", "symbol", "in_taxon", "provided_by", "source"]
edge_properties = ["id", "subject", "predicate", "object", "category" "qualifiers", "publications", "provided_by"]

outdir = "output/tests/split-examples"
outfile = "tsvwriter"
split_edge_file_substring = "UnknownCategoryGeneToDiseaseAssociationUnknownCategory"

t = TSVWriter(outdir, outfile, node_properties, edge_properties)
for ent in ents:
t.write(ent, split=True)

t.finalize()

assert os.path.exists("{}/splits/{}_Disease_nodes.tsv".format(outdir, outfile))
assert os.path.exists("{}/splits/{}_{}_edges.tsv".format(outdir, outfile, split_edge_file_substring))
assert os.path.exists("{}/splits/{}_Gene_nodes.tsv".format(outdir, outfile))

assert os.path.exists("{}/{}_nodes.tsv".format(outdir, outfile)) and os.path.exists(
"{}/{}_edges.tsv".format(outdir, outfile)
)