From 1d6aa98847ed3012dd53d03edd04c6a9d30ba694 Mon Sep 17 00:00:00 2001 From: Sanjay Date: Tue, 20 Aug 2024 15:24:41 +0530 Subject: [PATCH] Added bilinear function --- .../frontends/paddle/nn/functional/common.py | 25 +++++++++ .../test_nn/test_functional/test_common.py | 56 +++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/ivy/functional/frontends/paddle/nn/functional/common.py b/ivy/functional/frontends/paddle/nn/functional/common.py index 882ce9a543f64..f3c078be45133 100644 --- a/ivy/functional/frontends/paddle/nn/functional/common.py +++ b/ivy/functional/frontends/paddle/nn/functional/common.py @@ -195,3 +195,28 @@ def zeropad2d(x, padding, data_format="NCHW", name=None): else: raise ValueError(f"Unknown data_format: {data_format}") return ivy.pad(x, padding, mode="constant", constant_values=0.0) + +@to_ivy_arrays_and_back +@with_supported_dtypes({"2.6.0 and below": ("float32", "float64")}, "paddle") +def bilinear(x1, x2, weight, bias=None, name=None): + x1_shape = ivy.shape(x1) + x2_shape = ivy.shape(x2) + + if len(x1_shape) == 2: + x1 = ivy.expand_dims(x1, axis=1) + if len(x2_shape) == 2: + x2 = ivy.expand_dims(x2, axis=1) + + output_shape = list(ivy.shape(x1)) + output_shape[-1] = ivy.shape(weight)[0] + + x1 = ivy.expand_dims(x1, axis=-1) + x2 = ivy.expand_dims(x2, axis=-2) + + output = ivy.matmul(x1, ivy.matmul(weight, x2)) + output = ivy.squeeze(output, axis=[-1, -2]) + + if bias is not None: + output = ivy.add(output, bias) + + return ivy.reshape(output, output_shape) \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py index 871bbe9ea7092..4ca5f8c20c751 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_common.py @@ -514,3 +514,59 @@ def test_paddle_zeropad2d( padding=padding, data_format=dataformat, ) + +@handle_frontend_test( + fn_tree="paddle.nn.functional.common.bilinear", + dtype_and_inputs=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=3, + shared_dtype=True, + min_value=-1.0, + max_value=1.0, + min_num_dims=2, + max_num_dims=3, + min_dim_size=2, + max_dim_size=5, + ), + with_bias=st.booleans(), +) +def test_paddle_bilinear( + *, + dtype_and_inputs, + with_bias, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, inputs = dtype_and_inputs + x1, x2, weight = inputs + + if len(x1.shape) == 2: + output_size = weight.shape[0] + weight = ivy.reshape(weight, (output_size, x1.shape[1], x2.shape[1])) + else: + output_size = weight.shape[0] + + if with_bias: + bias = ivy.random_uniform( + shape=(output_size,), + dtype=input_dtype[0], + device=on_device, + ) + else: + bias = None + + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x1=x1, + x2=x2, + weight=weight, + bias=bias, + ) \ No newline at end of file