Skip to content

Commit

Permalink
Color Transfer: add Initial Reference Image parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Splendide-Imaginarius committed Aug 30, 2024
1 parent c852759 commit 1215c36
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
5 changes: 4 additions & 1 deletion backend/src/nodes/impl/color_transfer/mean_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def scale_array(
def mean_std_transfer(
img: np.ndarray,
ref_img: np.ndarray,
init_img: np.ndarray,
colorspace: TransferColorSpace,
overflow_method: OverflowMethod,
valid_indices: np.ndarray,
Expand Down Expand Up @@ -118,12 +119,14 @@ def mean_std_transfer(
c_clip_min, c_clip_max = (-127, 127)
img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2LAB)
init_img = cv2.cvtColor(init_img, cv2.COLOR_BGR2LAB)
elif colorspace == TransferColorSpace.RGB:
a_clip_min, a_clip_max = (0, 1)
b_clip_min, b_clip_max = (0, 1)
c_clip_min, c_clip_max = (0, 1)
img = img[:, :, :3]
ref_img = ref_img[:, :, :3]
init_img = init_img[:, :, :3]
else:
raise ValueError(f"Invalid color space {colorspace}")

Expand All @@ -135,7 +138,7 @@ def mean_std_transfer(
b_std_tar,
c_mean_tar,
c_std_tar,
) = image_stats(img[valid_indices])
) = image_stats(init_img[valid_indices])
(
a_mean_src,
a_std_src,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ class TransferColorAlgorithm(Enum):
icon="MdInput",
inputs=[
ImageInput("Image", channels=[3, 4]),
ImageInput("Reference Image", channels=[3, 4]),
ImageInput("Goal Reference Image", channels=[3, 4]),
EnumInput(
TransferColorAlgorithm,
label="Algorithm",
option_labels=TRANSFER_COLOR_ALGORITHM_LABELS,
default=TransferColorAlgorithm.MEAN_STD,
).with_id(5),
if_enum_group(5, TransferColorAlgorithm.MEAN_STD)(
ImageInput("Initial Reference Image", channels=[3, 4]).make_optional().with_id(6),
EnumInput(
TransferColorSpace,
label="Colorspace",
Expand All @@ -65,10 +66,14 @@ def color_transfer_node(
img: np.ndarray,
ref_img: np.ndarray,
algorithm: TransferColorAlgorithm,
init_img: np.ndarray | None,
colorspace: TransferColorSpace,
overflow_method: OverflowMethod,
reciprocal_scale: bool,
) -> np.ndarray:
if init_img is None:
init_img = img

_, _, img_c = get_h_w_c(img)

# Preserve alpha
Expand All @@ -77,6 +82,13 @@ def color_transfer_node(
alpha = img[:, :, 3]
bgr_img = img[:, :, :3]

_, _, init_img_c = get_h_w_c(init_img)

init_alpha = None
if init_img_c == 4:
init_alpha = init_img[:, :, 3]
bgr_init_img = init_img[:, :, :3]

_, _, ref_img_c = get_h_w_c(ref_img)

ref_alpha = None
Expand All @@ -86,9 +98,9 @@ def color_transfer_node(

# Don't process RGB data if the pixel is fully transparent, since
# such RGB data is indeterminate.
valid_rgb_indices = np.ones(img.shape[:-1], dtype=bool)
if alpha is not None:
valid_rgb_indices = alpha > 0
init_valid_rgb_indices = np.ones(init_img.shape[:-1], dtype=bool)
if init_alpha is not None:
init_valid_rgb_indices = init_alpha > 0

ref_valid_rgb_indices = np.ones(ref_img.shape[:-1], dtype=bool)
if ref_alpha is not None:
Expand All @@ -99,19 +111,20 @@ def color_transfer_node(
transfer = mean_std_transfer(
bgr_img,
bgr_ref_img,
bgr_init_img,
colorspace,
overflow_method,
reciprocal_scale=reciprocal_scale,
valid_indices=valid_rgb_indices,
valid_indices=init_valid_rgb_indices,
ref_valid_indices=ref_valid_rgb_indices,
)
elif algorithm == TransferColorAlgorithm.LINEAR_HISTOGRAM:
transfer = linear_histogram_transfer(
bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices
bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices
)
elif algorithm == TransferColorAlgorithm.PRINCIPAL_COLOR:
transfer = principal_color_transfer(
bgr_img, bgr_ref_img, valid_rgb_indices, ref_valid_rgb_indices
bgr_img, bgr_ref_img, init_valid_rgb_indices, ref_valid_rgb_indices
)

if alpha is not None:
Expand Down

0 comments on commit 1215c36

Please sign in to comment.