Skip to content

Commit

Permalink
Refactored projections
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed May 18, 2023
1 parent d266efe commit 72882b4
Show file tree
Hide file tree
Showing 3 changed files with 581 additions and 31 deletions.
94 changes: 71 additions & 23 deletions backend/kangas/datatypes/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,34 @@
######################################################

import json
import random
import time

from ..server.utils import pickle_dumps
from ..server.utils import Cache, pickle_dumps
from .base import Asset
from .utils import flatten, get_color, get_file_extension, is_valid_file_path
from .utils import get_color, get_file_extension, is_valid_file_path

PROJECTION_DIMENSIONS = 50

SAMPLE_CACHE = Cache(100)


def prepare_embedding(embedding, dimensions, seed):
    """Reduce an embedding vector to at most ``dimensions`` components.

    A deterministic random subset of component indices is chosen from
    ``seed`` so that every embedding sampled with the same seed, the same
    target dimensionality, and the same vector length keeps exactly the
    same components — a requirement for the projection (PCA/t-SNE/UMAP)
    to be meaningful across rows.

    Args:
        embedding: (list) the raw vector of numbers.
        dimensions: (int) maximum number of components to keep.
        seed: the sampling seed shared by all embeddings in a column.

    Returns:
        The original list if it already fits within ``dimensions``,
        otherwise a new list containing only the sampled components
        (in their original order).
    """
    if len(embedding) <= dimensions:
        return embedding

    # The vector length is part of the key: embeddings of different
    # lengths must never share a sampled index set, or the shorter one
    # could end up with fewer than `dimensions` surviving indices.
    key = (seed, dimensions, len(embedding))
    if not SAMPLE_CACHE.contains(key):
        # Use a local RNG instance so we never disturb the global
        # `random` state of the surrounding application.
        rng = random.Random(seed)
        indices = list(range(len(embedding)))
        rng.shuffle(indices)
        # Store as a set for O(1) membership tests below.
        SAMPLE_CACHE.put(key, set(indices[:dimensions]))

    indices = SAMPLE_CACHE.get(key)

    return [v for i, v in enumerate(embedding) if i in indices]


class Embedding(Asset):
Expand All @@ -37,6 +61,7 @@ def __init__(
metadata=None,
source=None,
unserialize=False,
dimensions=PROJECTION_DIMENSIONS,
):
"""
Create an embedding vector.
Expand All @@ -53,6 +78,7 @@ def __init__(
include: (bool) whether to include this vector when determining the
projection. Useful if you want to see one part of the datagrid in
the project of another.
dimensions: (int) maximum number of dimensions
Example:
Expand Down Expand Up @@ -88,6 +114,7 @@ def __init__(
self.metadata["color"] = color
self.metadata["projection"] = projection
self.metadata["include"] = include
self.metadata["dimensions"] = dimensions

if file_name:
if is_valid_file_path(file_name):
Expand All @@ -98,6 +125,7 @@ def __init__(
"name": name,
"color": color,
"text": text,
"dimensions": dimensions,
}
)
self.metadata["extension"] = get_file_extension(file_name)
Expand All @@ -106,7 +134,14 @@ def __init__(
raise ValueError("file not found: %r" % file_name)
else:
self.asset_data = json.dumps(
{"vector": embedding, "name": name, "color": color, "text": text}
{
"vector": embedding,
"name": name,
"color": color,
"text": text,
"dimensions": dimensions,
"include": include,
}
)
if metadata:
self.metadata.update(metadata)
Expand Down Expand Up @@ -135,39 +170,48 @@ def get_statistics(cls, datagrid, col_name, field_name):
stddev = None
other = None
name = col_name
seed = time.time() # set the same for all embeddings

projection = None
batch = []
for row in datagrid.conn.execute(
"""SELECT {field_name} as assetId, asset_data, json_extract(asset_metadata, '$.projection'), json_extract(asset_metadata, '$.include') from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
"""SELECT {field_name} as assetId, asset_data, asset_metadata from datagrid JOIN assets ON assetId = assets.asset_id;""".format(
field_name=field_name
)
):
asset_id, asset_data_json, asset_metadata_json = row
if not asset_metadata_json:
continue

asset_metdata = json.loads(asset_metadata_json)
projection = asset_metdata["projection"]
include = asset_metdata["include"]
dimensions = asset_metdata["dimensions"]

# Skip if explicitly False
if row[3] is False:
if not include:
continue

embedding = json.loads(row[1])
vectors = embedding["vector"]
vector = flatten(vectors)
asset_data = json.loads(asset_data_json)
vector = prepare_embedding(asset_data["vector"], dimensions, seed)

batch.append(vector)
if row[2] is None or row[2] == "pca":
if projection is None or projection == "pca":
projection_name = "pca"
elif row[2] == "t-sne":
elif projection == "t-sne":
projection_name = "t-sne"
elif row[2] == "umap":
elif projection == "umap":
projection_name = "umap"

if projection_name == "pca":
from sklearn.decomposition import PCA

projection = PCA()
embedding = projection.fit_transform(np.array(batch))
x_max = float(embedding[:, 0].max())
x_min = float(embedding[:, 0].min())
y_max = float(embedding[:, 1].max())
y_min = float(embedding[:, 1].min())
projection = PCA(n_components=2)
transformed = projection.fit_transform(np.array(batch))
x_max = float(transformed[:, 0].max())
x_min = float(transformed[:, 0].min())
y_max = float(transformed[:, 1].max())
y_min = float(transformed[:, 1].min())
x_span = abs(x_max - x_min)
x_max += x_span * 0.1
x_min -= x_span * 0.1
Expand All @@ -181,17 +225,19 @@ def get_statistics(cls, datagrid, col_name, field_name):
"projection": projection_name,
"x_range": [x_min, x_max],
"y_range": [y_min, y_max],
"dimensions": dimensions,
"seed": seed,
}
)
elif projection_name == "t-sne":
from openTSNE import TSNE

projection = TSNE()
embedding = projection.fit(np.array(batch))
x_max = float(embedding[:, 0].max())
x_min = float(embedding[:, 0].min())
y_max = float(embedding[:, 1].max())
y_min = float(embedding[:, 1].min())
transformed = projection.fit(np.array(batch))
x_max = float(transformed[:, 0].max())
x_min = float(transformed[:, 0].min())
y_max = float(transformed[:, 1].max())
y_min = float(transformed[:, 1].min())
x_span = abs(x_max - x_min)
x_max += x_span * 0.1
x_min -= x_span * 0.1
Expand All @@ -201,9 +247,11 @@ def get_statistics(cls, datagrid, col_name, field_name):
other = json.dumps(
{
"projection": projection_name,
"embedding": pickle_dumps(embedding),
"pickled_projection": pickle_dumps(transformed),
"x_range": [x_min, x_max],
"y_range": [y_min, y_max],
"dimensions": dimensions,
"seed": seed,
}
)
elif projection_name == "umap":
Expand Down
35 changes: 27 additions & 8 deletions backend/kangas/server/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,10 +2239,13 @@ def process_projection_asset_ids(
size,
default_color,
color_override=None,
projection_dimensions=None,
projection_seed=None,
):
from ..datatypes.embedding import prepare_embedding

# asset_ids is a list of str
# side-effect: adds to traces

# Turn to string:
values = "(" + (",".join(["'%s'" % asset_id for asset_id in asset_ids])) + ")"
if values == "()":
Expand All @@ -2253,10 +2256,14 @@ def process_projection_asset_ids(
)

trace_data = {}
print("using seed", projection_seed)

for asset_data_row in cur.execute(sql):
asset_data_raw = asset_data_row[0]
asset_data = json.loads(asset_data_raw)
vector = asset_data["vector"]
vector_reduced = prepare_embedding(
asset_data["vector"], projection_dimensions, projection_seed
)
if color_override:
color = color_override
elif asset_data["color"]:
Expand All @@ -2283,7 +2290,7 @@ def process_projection_asset_ids(
}

trace_data[trace_name]["texts"].append(asset_data.get("text"))
trace_data[trace_name]["vectors"].append(vector)
trace_data[trace_name]["vectors"].append(vector_reduced)
trace_data[trace_name]["colors"].append(color)
trace_data[trace_name]["customdata"].append(row_id)

Expand Down Expand Up @@ -2333,6 +2340,8 @@ def select_projection_data(
where_expr,
computed_columns,
):
from ..datatypes.embedding import prepare_embedding

conn = get_database_connection(dgid)
cur = conn.cursor()
unify_computed_columns(computed_columns)
Expand All @@ -2350,14 +2359,14 @@ def select_projection_data(

pca_eigen_vectors = metadata[column_name]["other"]["pca_eigen_vectors"]
pca_mean = metadata[column_name]["other"]["pca_mean"]
projection = PCA()
projection = PCA(n_components=2)
projection.components_ = np.array(pca_eigen_vectors)
projection.mean_ = np.array(pca_mean)
elif projection_name == "t-sne":
# FIXME: Trying to prevent an error on first load; race condition?
from openTSNE import TSNE # noqa

ascii_string = metadata[column_name]["other"]["embedding"]
ascii_string = metadata[column_name]["other"]["pickled_projection"]
if not PROJECTION_EMBEDDING_CACHE.contains(ascii_string):
PROJECTION_EMBEDDING_CACHE.put(
ascii_string, pickle_loads_embedding_unsafe(ascii_string)
Expand All @@ -2370,6 +2379,8 @@ def select_projection_data(
return

default_color = get_color(column_name)
projection_dimensions = metadata[column_name]["other"]["dimensions"]
projection_seed = metadata[column_name]["other"]["seed"]

traces = []
if asset_id:
Expand Down Expand Up @@ -2407,6 +2418,8 @@ def select_projection_data(
3,
default_color,
"lightgray",
projection_dimensions,
projection_seed,
)
PROJECTION_TRACE_CACHE.put(key, traces)
# Traces contains projection data; make copy:
Expand All @@ -2415,7 +2428,10 @@ def select_projection_data(
# Next, add the selected asset:
asset_data_raw = select_asset(dgid, asset_id)
asset_data = json.loads(asset_data_raw)
vector = projection.transform(np.array([asset_data["vector"]]))
vector_reduced = prepare_embedding(
asset_data["vector"], projection_dimensions, projection_seed
)
transformed = projection.transform(np.array([vector_reduced]))
if asset_data["color"]:
color = asset_data["color"]
else:
Expand All @@ -2427,8 +2443,8 @@ def select_projection_data(
text = asset_data.get("text", column_name)

trace = {
"x": [round(vector[0][0], 3)],
"y": [round(vector[0][1], 3)],
"x": [transformed[0][0]],
"y": [transformed[0][1]],
"text": text,
"name": text,
"type": "scatter",
Expand Down Expand Up @@ -2473,6 +2489,9 @@ def select_projection_data(
traces,
3,
default_color,
None,
projection_dimensions,
projection_seed,
)
PROJECTION_TRACE_CACHE.put(key, traces)
# Traces contains projection data; make copy:
Expand Down
Loading

0 comments on commit 72882b4

Please sign in to comment.