PTQ calibration shows bad results. #375
Comments
Hello, I have received your email. I will check it and reply to you as soon as possible.
Hi @taestaes, you can add a pseudo-quantization switch after this example code to analyze where the quantization error mainly comes from:

```python
# Disable the observers and enable fake quantization to validate the model with quantization error.
ptq_model.apply(torch.quantization.disable_observer)
ptq_model.apply(torch.quantization.enable_fake_quant)
if torch.cuda.device_count() > 1:
    ptq_model = ptq_model.module

# Disable the activation or weight quantization to check whether the
# quantization error comes from activations or from weights.
from tinynn.graph.quantization.fake_quantize import PTQFakeQuantize

for name, module in ptq_model.named_modules():
    if isinstance(module, PTQFakeQuantize):
        # Disable all weight quantization:
        # if name.endswith('weight_fake_quant'):
        #     module.apply(torch.quantization.disable_fake_quant)
        # Disable all activation quantization:
        if name.endswith('activation_post_process'):
            module.apply(torch.quantization.disable_fake_quant)

# Or, you could profile to identify which operator is causing the quantization to fail.
# First, disable all quantization, then gradually enable quantization for certain layers/ops,
# and pinpoint the layer with the highest quantization loss.
# The following is just an example; you should comment it out or modify it before profiling.
# ptq_model.apply(torch.quantization.disable_fake_quant)
# for name, child in ptq_model.named_modules():
#     # if name in ['body_0_ffn_act']:
#     if 'ffn_act' in name:
#         child.apply(torch.quantization.enable_fake_quant)

dummy_input = dummy_input.to(device)  # .to() is not in-place, so reassign
ptq_model(dummy_input)
print(ptq_model)
```

All quantized value ranges look normal. I think it is the loss of accuracy caused by PReLU's activation value quantization. You can first try to quantize only PReLU and check the quantization error.
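A minimal sketch of that "quantize only PReLU" check; the `'ffn_act'` substring is reused from the commented profiling example above as a placeholder matcher, so substitute whatever names the PReLU modules actually have in the rewritten model (see `print(ptq_model)`):

```python
import torch

# Keep everything in float, then re-enable fake quant only on the modules
# matched by the placeholder substring (assumed here to be the PReLU activations).
ptq_model.apply(torch.quantization.disable_fake_quant)
for name, module in ptq_model.named_modules():
    if 'ffn_act' in name:
        module.apply(torch.quantization.enable_fake_quant)
```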
The result files are in output_ptq.zip.
Isn't this severe degradation? I can't tell whether the PTQ worked well or not. How can I solve this issue?
When I apply quantization only to layers with 'ffn_act' in the name:
When I apply quantization only to layers with 'ffn_act' not in the name:
The two errors look similar, but in the "'ffn_act' not in name" case the result values are quantized in steps of 0.0042.
The result looks bad whether or not only PReLU is quantized; you can determine the quantization error by checking the results. BTW, if allowed, could you provide the input files so I can reproduce the issue?
@zk1998 What is your email? I can send it. The dummy input files are quite large.
I have sent the input files and the output folder. Can you find them? Thank you for your help; I really need your assistance and want to resolve this issue.
I have received it and will help you analyze it ASAP.
Hi @taestaes, could you please provide a post-processing function for the model outputs?
What do you mean by a post-processing function? My model's inputs are image values, and its outputs are also image values.
Also, I want to mention that without calibration, the output values are all 0 or 1.
The input and output value range of your model is float [0, 1], but the pixel value range of an image is int [0, 255]. It seems that your model is image-in, image-out, so your dataloader should have preprocessing and postprocessing steps to handle the images.
@taestaes Maybe your work relies on an open-source project; that would make it easier for me to reproduce your problem. I currently find it difficult to evaluate the quality of your model by observing numerical accuracy alone; it needs to be combined with the actual task.
That's right. My original raw input is a 12-bit image (max value 4096), so I normalize it to [0, 1]. For example, a real input has the range (-0.0483, 0.9995), and the model output has the range (-0.0289, 1.0635). I recover the output as an 8-bit image (multiplying by 255).
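For illustration, a minimal sketch of that preprocessing/postprocessing, assuming a NumPy array in and a PyTorch tensor out; the function names are hypothetical, while the 4096 scale and the x255 recovery follow the description above:

```python
import numpy as np
import torch

def preprocess(raw_12bit: np.ndarray) -> torch.Tensor:
    # Normalize a 12-bit raw image (max value 4096, per the description above) to float [0, 1].
    return torch.from_numpy(raw_12bit.astype(np.float32) / 4096.0)

def postprocess(output: torch.Tensor) -> np.ndarray:
    # Recover the float model output as an 8-bit image (clamp, then scale by 255).
    img = output.detach().cpu().numpy()
    return np.clip(img * 255.0, 0, 255).astype(np.uint8)
```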
Yes, that's right. Our code is based on https://github.com/XPixelGroup/BasicSR, but we changed it to a custom network and a denoising task.
@taestaes Which base model from https://github.com/XPixelGroup/BasicSR/blob/master/docs/ModelZoo.md do you use?
I'm using a custom model, so I don't think our model is based on any of the above.
https://github.com/swz30/Restormer is the more accurate base; our model is derived from it.
It seems I am unable to reproduce your work from the open-source repository. Could you provide the processing function from model outputs to images? That would allow me to visually assess the quality of the model's results. Alternatively, could you provide some direct metrics for evaluating output quality? Directly comparing the mean squared error (MSE) between floating-point outputs and quantized outputs is not meaningful, as your output values are numerous and mostly close to 0, so errors in the significant values are not reflected by MSE.
Can you test this function? It converts the output to a uint8 image, and you can use cv2.imwrite to view it. Also, I used _apply_val_data_gain to make the image 100x brighter, since the output images are very dark.
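Since the function itself is not attached here, a minimal sketch of what such a dump helper could look like, assuming a (1, C, H, W) or (C, H, W) float output in [0, 1]; the function name and gain handling are hypothetical, with the x100 gain mirroring the `_apply_val_data_gain` brightening mentioned above:

```python
import cv2
import numpy as np
import torch

def save_output_as_uint8(output: torch.Tensor, path: str, gain: float = 100.0) -> None:
    # Convert a float model output in [0, 1] to an 8-bit image file.
    img = output.detach().cpu().squeeze(0)   # drop the batch dim if present
    img = img.permute(1, 2, 0).numpy()       # CHW to HWC
    img = np.clip(img * gain, 0.0, 1.0)      # brighten (x100 by default) and clamp
    cv2.imwrite(path, (img * 255.0).astype(np.uint8))
```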
I did some ablation experiments using your evaluation method:

```python
quantizer = QATQuantizer(
    model, dummy_input, work_dir='out', config={
        'override_qconfig_func': set_ptq_fake_quantize,
        # 'force_overwrite': False,
        # 'rewrite_graph': False,
        'per_tensor': False,
    }
)
```

Ablation 1: enable only the weight fake quant by adding the following code afterwards:

```python
ptq_model.apply(torch.quantization.disable_fake_quant)
for name, module in ptq_model.named_modules():
    if isinstance(module, torch.quantization.FakeQuantize):
        # Enable only weight quantization; all activations stay in float.
        if name.endswith('weight_fake_quant'):
            module.apply(torch.quantization.enable_fake_quant)
            print(f"enable fake_quant {name}")
```

Ablation 2: additionally enable the activation fake quant of the first conv (patch_embed_proj):

```python
ptq_model.apply(torch.quantization.disable_fake_quant)
for name, module in ptq_model.named_modules():
    if isinstance(module, torch.quantization.FakeQuantize):
        # Enable weight quantization plus the first conv's activation quantization.
        if name.endswith('weight_fake_quant'):
            module.apply(torch.quantization.enable_fake_quant)
            print(f"enable fake_quant {name}")
        if name.endswith('activation_post_process') and 'patch_embed_proj' in name:
            module.apply(torch.quantization.enable_fake_quant)
```

All in all, it is difficult to quantize this denoising model; applying PTQ alone generates more noise. From a numerical perspective, quantization introduces a lot of numerical error, and after the picture is brightened the error is further amplified, which is why there is more noise in the picture above.
Thanks. I found that following your ablation 1, the output image looks good; there is some color tone difference, but the noise is not that severe.
I compared the ablation images, and the first conv's activation quantization induces a lot of noise.
How can I solve this issue? If I try QAT, can it produce good results? Or is there an option for higher activation accuracy, such as 16-bit quantization for activations?
Thanks. The problem is that the picture is in a very low-light scenario, and we need to denoise it in that scenario. I have tried QAT, but the validation accuracy did not increase, so I'm not sure yet. I checked your INT16 example, but it seems to be in TFLite format. Can I check the result of INT16 quantization with the Python code above? I don't know how to inspect TFLite results yet.
I tried INT16, and it worked very well. You only need to add one line after:
And to make the fake-quant pass work with int16 activation quantization, modify the line to:
Thanks for your answer. I found that int16 quantization produces good results, as shown below (8-bit / 16-bit / float). But do you know why there is some color tone difference after quantization? The input image is originally greenish, so I'm not sure why quantization changes the color tone.
I have no idea; maybe you can compare the pixel values of the RGB channels of the floating-point and quantized outputs individually to find the reason, then do some normalization to fix the color tone change.
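A minimal sketch of that per-channel comparison, plus a rough mean-matching correction; `float_output` and `quant_output` are assumed to be (N, C, H, W) tensors from the floating-point and quantized models, and the correction is only an illustration of "some normalization":

```python
import torch

def channel_means(x: torch.Tensor) -> torch.Tensor:
    # One mean per channel, averaged over the batch and spatial dimensions.
    return x.mean(dim=(0, 2, 3))

float_means = channel_means(float_output)   # floating-point model output
quant_means = channel_means(quant_output)   # quantized model output
print('float means:', float_means)
print('quant means:', quant_means)

# Rough fix: shift each quantized channel so its mean matches the float mean.
corrected = quant_output + (float_means - quant_means).view(1, -1, 1, 1)
```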
I have checked the R, G, G, B values of the input/output tensors and found that without quantization the channel-wise means stay similar, but with quantization the R and B means increase while the G mean decreases. Isn't the quantization performed per-channel? I can't understand why some channel means increase while others decrease.
Per-channel quantization refers to the granularity of weight quantization in conv layers and has nothing to do with the RGGB channels. Quantization introduces errors, and a generative task like yours is especially sensitive to them. I have no idea how to fix this color shift.
I followed https://github.com/alibaba/TinyNeuralNetwork/blob/main/examples/quantization/post_error_anaylsis.py and calibrated with 10 iterations on the test dataloader, but my network's PTQ outputs really bad results, as shown below. Why do the following layers have really low cosine similarity? I also found that the result image has values that are multiples of 0.004 (maybe because of 8-bit quantization?).
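For reference, a minimal sketch of the kind of calibration loop described here, assuming a prepared `ptq_model`, a `test_loader` that yields (input, target) batches, and a `device`; it is not the exact code of the linked example:

```python
import torch

ptq_model.eval()
with torch.no_grad():  # the observers only collect statistics, no gradients needed
    for i, (images, _) in enumerate(test_loader):
        if i >= 10:  # calibrate with 10 iterations
            break
        ptq_model(images.to(device))
```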