Model with stack does not work with int8 target type #359

Open · spacycoder opened this issue Sep 24, 2024 · 15 comments · May be fixed by #360
Labels: bug (Something isn't working)

Comments

spacycoder commented Sep 24, 2024

Converting this dummy model with quantize_target_type="int8" and per_tensor=True throws an error in TFLite:

import torch.nn as nn
import torch
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter

class StackModel(nn.Module):

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.stack([-x, x], dim=-1)


def _main():
    dummy_input = torch.rand(1, 60,  60, 256).float()

    model = StackModel()

    qat_config = {
        "backend": "qnnpack",
        "per_tensor": True,
        "disable_requantization_for_cat": True
    }
    quantizer = PostQuantizer(
        model, (dummy_input,), work_dir="stack_model", config=qat_config
    )

    ptq_coarse_matcher = quantizer.quantize()
    ptq_coarse_matcher(dummy_input)

    with torch.no_grad():
        ptq_coarse_matcher.eval()
        ptq_coarse_matcher.cpu()

        ptq_coarse_matcher = quantizer.convert(ptq_coarse_matcher)
        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(
            ptq_coarse_matcher,
            (dummy_input),
            "stack_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8"
        )
        converter.convert()

if __name__ == '__main__':
    _main()

TFLite error:

    return self._interpreter.AllocateTensors()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: /tensorflow/tensorflow/lite/kernels/concatenation.cc:184 t->params.zero_point != output->params.zero_point (-1 != 0)Node number 3 (CONCATENATION) failed to prepare.

Note that the model works fine if I remove the "negative x" and stack the same tensor twice instead, and it also works with uint8.
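
For reference, a minimal sketch of the variant that converts cleanly under the same setup (no negation, the same tensor stacked twice); the class name is just illustrative:

class StackSameModel(nn.Module):

    def forward(self, x: torch.Tensor):
        # Both stack inputs are literally the same tensor, so they share
        # q-params and the CONCATENATION constraint is satisfied.
        return torch.stack([x, x], dim=-1)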

spacycoder changed the title from "Model with stack does not work with int8 quantization" to "Model with stack does not work with int8 target type" on Sep 24, 2024
peterjc123 added the bug label on Sep 24, 2024
@peterjc123 (Collaborator)

Well, we need to apply the same logic to stack.

@spacycoder (Author)

Seems to also happen if I turn it into a cat op:

class CatModel(nn.Module):

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.cat([-x.unsqueeze(-1), x.unsqueeze(-1)], dim=-1)

@peterjc123 (Collaborator)

@spacycoder
What about this?

class CatModel(nn.Module):

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        z = x.unsqueeze(-1)
        return torch.cat([-z, z], dim=-1)

@spacycoder (Author)

That also fails

peterjc123 commented Sep 24, 2024

Or this?

class CatModel(nn.Module):

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [N, H, W, C]
        """
        return torch.cat([-x, x], dim=-1).view(x.shape[:-1] + (-1, 2))

@spacycoder (Author)

Nope, doesn't work either

@peterjc123 (Collaborator)

Okay, will look into it tomorrow.

@peterjc123 (Collaborator)

@spacycoder It seems that the problem is with mul_scalar. The q-params for this op are calculated on the fly.
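
To illustrate the zero-point conflict (a rough sketch of generic asymmetric int8 quantization, not TinyNN's actual q-param calculation): negating a tensor flips its zero point, so "-x" cannot simply reuse the q-params of "x", while TFLite's CONCATENATION kernel requires all inputs and the output to share scale/zero_point unless requantization is inserted.

import torch

def int8_qparams(t: torch.Tensor):
    # asymmetric int8: q = round(x / scale) + zero_point, with q in [-128, 127]
    lo = min(t.min().item(), 0.0)
    hi = max(t.max().item(), 0.0)
    scale = (hi - lo) / 255.0
    zero_point = round(-128 - lo / scale)
    return scale, zero_point

x = torch.rand(1, 60, 60, 256)
print(int8_qparams(x))   # roughly (1/255, -128) for values in [0, 1]
print(int8_qparams(-x))  # roughly (1/255,  127) for values in [-1, 0]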

@peterjc123 (Collaborator)

@spacycoder Things should work with #360

peterjc123 linked pull request #360 on Sep 25, 2024 that will close this issue

spacycoder commented Sep 25, 2024

This also fails with the same concatenation error:

import torch.nn as nn
import torch
from tinynn.graph.quantization.quantizer import PostQuantizer
from tinynn.converter import TFLiteConverter

class EncoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int = 256
    ):
        super().__init__()
        self.mlp0 = nn.Linear(d_model, d_model, bias=False)
        self.mlp1 = nn.Linear(d_model * 2, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
    ):
        x = x.permute(0, 2, 3, 1)
        m = self.mlp0(x)
        m = torch.cat([x, m], dim=-1)
        m = self.mlp1(m)
        return x + m

class Dummy(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder(x)
        y = self.encoder(y)
        return x, y

def _main():
    dummy_input0 = torch.rand(1, 256, 60,  60).float()
    dummy_input1 = torch.rand(1, 256, 60,  60).float()

    model = Dummy()

    ptq_config = {
        "backend": "qnnpack",
        "per_tensor": True,
        "disable_requantization_for_cat": True
    }
    quantizer = PostQuantizer(
        model, (dummy_input0, dummy_input1), work_dir="cat_model", config=ptq_config
    )

    ptq_model = quantizer.quantize()
    ptq_model(dummy_input0, dummy_input1)

    with torch.no_grad():
        ptq_model.eval()
        ptq_model.cpu()

        ptq_model = quantizer.convert(ptq_model)
        torch.backends.quantized.engine = quantizer.backend
        converter = TFLiteConverter(
            ptq_model,
            (dummy_input0, dummy_input1),
            "cat_model.tflite",
            fuse_quant_dequant=True,
            quantize_target_type="int8"
        )
        converter.convert()

if __name__ == '__main__':
    _main()


spacycoder commented Sep 25, 2024

FYI, having two separate encoders works (but I need them to be the same):

class Dummy(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder0 = EncoderLayer(256)
        self.encoder1 = EncoderLayer(256)

    def forward(self, x, y):
        x = self.encoder0(x)
        y = self.encoder1(y)
        return x, y
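
If "the same" only means identical weights (rather than one shared module instance), a hypothetical variant of the snippet above is to copy the state dict between the two encoders; this is just a sketch and has not been checked against the converter:

class Dummy(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder0 = EncoderLayer(256)
        self.encoder1 = EncoderLayer(256)
        # hypothetical: make the two encoders start with identical weights
        # while keeping them as separate modules
        self.encoder1.load_state_dict(self.encoder0.state_dict())

    def forward(self, x, y):
        return self.encoder0(x), self.encoder1(y)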


peterjc123 commented Sep 26, 2024

(quoting the previous comment: "This also fails with the same concatenation error:" and the same code example)

Okay, I guess it is because we refuse to traverse into the same nodes in the computation graph again. We need to refine the constraints a little bit.
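
A generic sketch of the failure mode described here, not TinyNN's actual implementation: if the pass that unifies q-params across cat inputs keeps a visited set, the second path into the shared encoder stops early and its cat is never revisited.

class Node:
    def __init__(self, name, outputs=None):
        self.name = name
        self.outputs = outputs or []

def unify_cat_qparams(node, visited):
    if id(node) in visited:                # shared submodule reached again
        return                             # -> its cat is skipped this time
    visited.add(id(node))
    print("unifying q-params around", node.name)   # stand-in for the real work
    for nxt in node.outputs:
        unify_cat_qparams(nxt, visited)

shared_cat = Node("cat_in_shared_encoder")
visited = set()
unify_cat_qparams(Node("x_path", [shared_cat]), visited)
unify_cat_qparams(Node("y_path", [shared_cat]), visited)  # cat not processed again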

@spacycoder (Author)

This seems to be a decent workaround for the moment:

class Dummy(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = EncoderLayer(256)

    def forward(self, x, y):
        x_cat = torch.cat([x, y], dim=0)
        x_cat = self.encoder(x_cat)
        x, y = torch.chunk(x_cat, 2, dim=0)
        return x, y
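
A quick float-level sanity check (a sketch, reusing the EncoderLayer definition above) that the cat/chunk workaround matches calling the shared encoder on each input separately:

enc = EncoderLayer(256).eval()
x = torch.rand(1, 256, 60, 60)
y = torch.rand(1, 256, 60, 60)

with torch.no_grad():
    ref_x, ref_y = enc(x), enc(y)
    out_x, out_y = torch.chunk(enc(torch.cat([x, y], dim=0)), 2, dim=0)

# both comparisons should print True
print(torch.allclose(ref_x, out_x, atol=1e-6), torch.allclose(ref_y, out_y, atol=1e-6))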

@peterjc123 (Collaborator)

@spacycoder I'm glad it works and it looks cleaner.
