
Commit

merge MNN-3.0.1
huangzhengxiang committed Dec 2, 2024
2 parents 74d63e1 + 98ba45b commit 9f116b8
Showing 94 changed files with 26,362 additions and 12,695 deletions.
Binary file modified doc/dingdingmnn3.png
2 changes: 2 additions & 0 deletions docs/inference/expr.md
@@ -26,6 +26,8 @@ The C++ environment defaults to Defer mode and the Python environment defaults to Eager mode; it can be switched via the current
### Executor
When building a model or running computation, expressions use the same executor (Executor) as the [Module API](module.md); it can be used to configure the execution mode of expressions, the compute resources to use, and so on.

Expression-related interfaces are not thread-safe. When calling them from multiple threads, make sure each thread binds its own independent Executor, as in the sketch below.
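
A minimal sketch of binding one Executor per thread (this assumes the `Executor::newExecutor` and `ExecutorScope` interfaces from MNN's Express headers; check the exact signatures against your MNN version):

```cpp
#include <MNN/MNNForwardType.h>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExecutorScope.hpp>
#include <thread>

static void worker() {
    MNN::BackendConfig config;
    // Each thread creates its own Executor and binds it to the current scope,
    // so expression building and evaluation in this thread never shares executor state.
    auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, config, 1);
    MNN::Express::ExecutorScope scope(executor);
    // ... build and evaluate expressions here ...
}

int main() {
    std::thread t0(worker), t1(worker);
    t0.join();
    t1.join();
    return 0;
}
```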

## Expression API capabilities
### Model loading, saving and modification
- Model loading
62 changes: 50 additions & 12 deletions docs/transformers/llm.md
@@ -48,19 +48,22 @@ python llmexport.py \
```

### Features
- First convert the model to an onnx model with `--export onnx`, then use the ./MNNConvert tool to convert the onnx model to an mnn model: ./MNNConvert --modelFile ../transformers/llm/export/model/onnx/llm.onnx --MNNModel llm.mnn --keepInputFormat --weightQuantBits=4 -f ONNX --transformerFuse=1 --allowCustomOp
- The faster way: convert directly to an mnn model with `--export mnn`. Note that you must either install pymnn or specify the path of the MNNConvert tool with the --mnnconvert option; at least one of the two conditions must be met. If pymnn is not installed and no MNNConvert path is given via --mnnconvert, the llmexport.py script looks for the MNNConvert tool under the directory "../../../build/", so make sure an MNNConvert file exists in that directory.
- Convert directly to an mnn model with `--export mnn`. Note that you must either install pymnn or specify the path of the MNNConvert tool with the `--mnnconvert` option; at least one of the two conditions must be met. If pymnn is not installed and no MNNConvert path is given via `--mnnconvert`, the llmexport.py script looks for the MNNConvert tool under the directory "../../../build/", so make sure an MNNConvert file exists in that directory. This path currently supports exporting 4-bit and 8-bit models.
- If converting directly to an mnn model runs into problems, or you need other quantization bit widths (such as 5bit/6bit), first convert the model to an onnx model with `--export onnx`, then use the ./MNNConvert tool to convert the onnx model to an mnn model:

```
./MNNConvert --modelFile ../transformers/llm/export/model/onnx/llm.onnx --MNNModel llm.mnn --keepInputFormat --weightQuantBits=4 --weightQuantBlock=128 -f ONNX --transformerFuse=1 --allowCustomOp --saveExternalData
```

- Supports a chat test of the model; `--test $query` returns the llm's reply
- onnx-slim is used by default to optimize the onnx model; skip this step with `--skip_slim`
- Supports merging lora weights before export; specify the lora weight directory with `--lora_path`
- Specify the quantization bit width with `--quant_bit` and the quantization block size with `--quant_block`
- Use `--lm_quant_bit` to specify the quantization bit width of the lm_head layer weights; if not specified, the `--quant_bit` value is used
- Supports using a self-compiled `MNNConvert`, via `--mnnconvert`

### Parameters
```
usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT]
[--skip_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT]
[--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT]
[--mnnconvert MNNCONVERT]
llm_exporter
@@ -78,7 +81,6 @@ options:
--dst_path DST_PATH export onnx/mnn model to path, defaut is `./model`.
--test TEST test model inference with query `TEST`.
--export EXPORT export model to an onnx/mnn model.
--skip_slim Whether or not to skip onnx-slim.
--quant_bit QUANT_BIT
mnn quant bit, 4 or 8, default is 4.
--quant_block QUANT_BLOCK
@@ -94,9 +96,17 @@ options:
### Build

[Build from source](../compile/other.html#id4)
Add the required compile macros to the existing build process: -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
Add the required compile macros to the existing build process:
```
-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
```

- mac / linux / windows
- When vision support is needed, add the related compile macros
```
-DLLM_SUPPORT_VISION=true -DMNN_BUILD_OPENCV=true -DMNN_IMGCODECS=true
```

#### mac / linux / windows

Taking mac / linux as an example:
```
@@ -106,26 +116,48 @@ cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_L
make -j16
```

For the x86 architecture, additionally add the MNN_AVX512 macro:
For the x86 architecture, additionally add the `MNN_AVX512` macro:
```
mkdir build
cd build
cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_AVX512=true
make -j16
```

- Android: additionally add the MNN_ARM82 macro
#### Android: additionally add the `MNN_ARM82` and `MNN_OPENCL` macros
```
cd project/android
mkdir build_64
../build_64.sh "-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true"
../build_64.sh "-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_USE_LOGCAT=true"
```

- iOS: see transformers/llm/engine/ios/README.md
#### iOS: see transformers/llm/engine/ios/README.md
```
sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_BUILD_LLM=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true"
```

#### Web
For environment setup, see https://mnn-docs.readthedocs.io/en/latest/compile/engine.html#web

- Build the libraries, producing `libMNN.a`, `libMNN_Express.a` and `libllm.a`

```
mkdir buildweb
emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msimd128 -msse4.1" -DMNN_FORBID_MULTI_THREAD=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_USE_SSE=ON -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true
make -j16
```

- Build the demo

```
emcc ../transformers/llm/engine/llm_demo.cpp -DCMAKE_CXX_FLAGS="-msimd128 -msse4.1" -I ../include -I ../transformers/llm/engine/include libMNN.a libllm.a express/libMNN_Express.a --preload-file ~/qwen2.0_1.5b/ -s ALLOW_MEMORY_GROWTH=1 -o llm_demo.js
```

Test with the following command:
```
node llm_demo.js ~/qwen2.0_1.5b/config.json ~/qwen2.0_1.5b/prompt.txt
```

### Usage
#### Runtime configuration

@@ -244,11 +276,17 @@ Direct inference on PC
./llm_demo model_dir/llm.mnn prompt.txt
```
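
To embed the engine in your own program instead of running llm_demo, a minimal C++ sketch is shown below. It assumes the `Llm` interface declared in transformers/llm/engine/include/llm/llm.hpp (`createLLM` / `load` / `response`); names and signatures may differ across versions, so treat this as an outline rather than a reference:

```cpp
#include <memory>
#include <string>
#include "llm/llm.hpp"   // from transformers/llm/engine/include

using MNN::Transformer::Llm;

int main(int argc, char* argv[]) {
    // config.json is the runtime configuration generated next to llm.mnn by the exporter
    std::string config_path = argc > 1 ? argv[1] : "model_dir/config.json";
    std::unique_ptr<Llm> llm(Llm::createLLM(config_path));
    llm->load();             // load weights and build the runtime
    llm->response("Hello");  // by default the reply is streamed to stdout
    return 0;
}
```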

<<<<<<< HEAD
Usage for adb inference on a phone:
```bash
# Use adb push to push the shared libraries to the phone
adb shell mkdir /data/local/tmp/llm
adb push llm_demo ppl_demo libllm.so libMNN_CL.so libMNN_Express.so libMNN.so tools/cv/libMNNOpenCV.so /data/local/tmp/llm
=======
- For vision LLMs, embed the image input in the prompt
```
<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>Describe the content of this image
>>>>>>> alibaba/master
```
#### Loading GPTQ weights
10 changes: 10 additions & 0 deletions express/MathOp.cpp
@@ -527,6 +527,16 @@ VARP _Sigmoid(VARP x) {
return _Unary(x, UnaryOpOperation_SIGMOID);
}

/*Computes SiLU of x element-wise, i.e. x * sigmoid(x).
Args:
x: A variable. Must be one of the following types: Halide_Type_Float
Returns:
A variable. Has the same type as x.
*/
VARP _Silu(VARP x) {
return _Unary(x, UnaryOpOperation_SILU);
}


/*Computes ((exponential of x) - 1) element-wise.
Args:
2 changes: 1 addition & 1 deletion include/MNN/MNNDefine.h
@@ -75,6 +75,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 3
#define MNN_VERSION_MINOR 0
#define MNN_VERSION_PATCH 0
#define MNN_VERSION_PATCH 1
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
1 change: 1 addition & 0 deletions include/MNN/expr/MathOp.hpp
@@ -68,6 +68,7 @@ MNN_PUBLIC VARP _Erfc(VARP x);
MNN_PUBLIC VARP _Erfinv(VARP x);
MNN_PUBLIC VARP _Expm1(VARP x);
MNN_PUBLIC VARP _Hardswish(VARP x);
MNN_PUBLIC VARP _Silu(VARP x);

//ReduceOPs
MNN_PUBLIC VARP _ReduceSum(VARP input_variable, INTS axis = {}, bool keepDims = false);
17 changes: 11 additions & 6 deletions schema/current/TensorflowOp_generated.h
@@ -400,11 +400,12 @@ enum UnaryOpOperation {
UnaryOpOperation_HARDSWISH = 31,
UnaryOpOperation_GELU = 32,
UnaryOpOperation_GELU_STANDARD = 33,
UnaryOpOperation_SILU = 34,
UnaryOpOperation_MIN = UnaryOpOperation_ABS,
UnaryOpOperation_MAX = UnaryOpOperation_GELU_STANDARD
UnaryOpOperation_MAX = UnaryOpOperation_SILU
};

inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[34] {
inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[35] {
static const UnaryOpOperation values[] = {
UnaryOpOperation_ABS,
UnaryOpOperation_NEG,
@@ -439,7 +440,8 @@ inline const UnaryOpOperation (&EnumValuesUnaryOpOperation())[34] {
UnaryOpOperation_TANH,
UnaryOpOperation_HARDSWISH,
UnaryOpOperation_GELU,
UnaryOpOperation_GELU_STANDARD
UnaryOpOperation_GELU_STANDARD,
UnaryOpOperation_SILU
};
return values;
}
@@ -480,13 +482,14 @@ inline const char * const *EnumNamesUnaryOpOperation() {
"HARDSWISH",
"GELU",
"GELU_STANDARD",
"SILU",
nullptr
};
return names;
}

inline const char *EnumNameUnaryOpOperation(UnaryOpOperation e) {
if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_GELU_STANDARD) return "";
if (e < UnaryOpOperation_ABS || e > UnaryOpOperation_SILU) return "";
const size_t index = static_cast<int>(e);
return EnumNamesUnaryOpOperation()[index];
}
@@ -5063,6 +5066,7 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 },
{ flatbuffers::ET_INT, 0, 0 }
};
static const flatbuffers::TypeFunction type_refs[] = {
@@ -5102,10 +5106,11 @@ inline const flatbuffers::TypeTable *UnaryOpOperationTypeTable() {
"TANH",
"HARDSWISH",
"GELU",
"GELU_STANDARD"
"GELU_STANDARD",
"SILU"
};
static const flatbuffers::TypeTable tt = {
flatbuffers::ST_ENUM, 34, type_codes, type_refs, nullptr, names
flatbuffers::ST_ENUM, 35, type_codes, type_refs, nullptr, names
};
return &tt;
}
1 change: 1 addition & 0 deletions schema/default/TensorflowOp.fbs
@@ -154,6 +154,7 @@ enum UnaryOpOperation : int {
HARDSWISH = 31,
GELU = 32,
GELU_STANDARD = 33,
SILU = 34,
}

table UnaryOp {
9 changes: 9 additions & 0 deletions source/backend/arm82/Arm82Unary.cpp
@@ -151,6 +151,13 @@ struct _Sigmoid {
MNNSigmoidLowp(out, inp, realSize);
}
};
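// SiLU in low precision (fp16): delegates to MNNSiLuLowp, i.e. out = inp * sigmoid(inp).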
struct _SiLu {
void operator()(void* outRaw, const void* inpRaw, int realSize) const {
auto out = (float*)outRaw;
auto inp = (const float*)inpRaw;
MNNSiLuLowp(out, inp, realSize);
}
};

void FP16GELU(void* outRaw, const void* inpRaw, int realSize) {
int sizeQuad = realSize / 8;
@@ -245,6 +252,8 @@ MNNUnaryExecute Arm82Unary::select(int type, int precision) {
return _Wrap<_Unary<UnarySin<float>, float>>;
case UnaryOpOperation_SIGMOID:
return _Wrap<_Sigmoid>;
case UnaryOpOperation_SILU:
return _Wrap<_SiLu>;
case UnaryOpOperation_TANH:
return _Wrap<_Tanh>;
case UnaryOpOperation_TAN:
@@ -279,6 +279,7 @@ mov x10, x1 // dst
mov x12, x3 // src_depth_quad
sub x13, x7, #128 // src_step - 128
sub x14, x6, #32 // dst_step - 32
sub x15, x6, #64

// quant_scale: v10, v11
//ld1 {v10.8h}, [x2], #16
23 changes: 23 additions & 0 deletions source/backend/coreml/execution/CoreMLUnary.cpp
@@ -218,6 +218,29 @@ ErrorCode CoreMLUnary::onResize(const std::vector<Tensor *> &inputs, const std::
mCoreMLBackend->addLayer(mLayer_);
return NO_ERROR;
}
case UnaryOpOperation_SILU:
{
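// SiLU(x) = x * sigmoid(x): build a sigmoid activation layer first, then multiply its output by the original input.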
auto sigmoidLayer = mCoreMLBackend->create<CoreML__Specification__NeuralNetworkLayer>();
core_ml__specification__neural_network_layer__init(sigmoidLayer);
mCoreMLBackend->setLayerName(sigmoidLayer, "silu-sigmoid");
sigmoidLayer->layer_case = CORE_ML__SPECIFICATION__NEURAL_NETWORK_LAYER__LAYER_ACTIVATION;
sigmoidLayer->activation = mCoreMLBackend->create<CoreML__Specification__ActivationParams>();
core_ml__specification__activation_params__init(sigmoidLayer->activation);
sigmoidLayer->activation->nonlinearity_type_case = CORE_ML__SPECIFICATION__ACTIVATION_PARAMS__NONLINEARITY_TYPE_SIGMOID;
sigmoidLayer->activation->sigmoid = mCoreMLBackend->create<CoreML__Specification__ActivationSigmoid>();
core_ml__specification__activation_sigmoid__init(sigmoidLayer->activation->sigmoid);
std::string sigOutput = inputName + "-sigmoid";
setLayerInputsAndOutputs(sigmoidLayer, {inputName}, {sigOutput});
mCoreMLBackend->addLayer(sigmoidLayer);

mLayer_->layer_case = CORE_ML__SPECIFICATION__NEURAL_NETWORK_LAYER__LAYER_MULTIPLY;
mLayer_->multiply = mCoreMLBackend->create<CoreML__Specification__MultiplyLayerParams>();
core_ml__specification__multiply_layer_params__init(mLayer_->multiply);
setLayerInputsAndOutputs(mLayer_, {sigOutput, inputName}, {mCoreMLBackend->getTensorName(outputs[0])});
mCoreMLBackend->addLayer(mLayer_);

return NO_ERROR;
}
case UnaryOpOperation_GELU:
case UnaryOpOperation_GELU_STANDARD:
mLayer_->layer_case = CORE_ML__SPECIFICATION__NEURAL_NETWORK_LAYER__LAYER_GELU;
29 changes: 20 additions & 9 deletions source/backend/cpu/CPUAttention.cpp
@@ -24,9 +24,12 @@
#define FLOAT16_T float
#endif

<<<<<<< HEAD
// reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow
#define FP16_QSCALE 0.25

=======
>>>>>>> alibaba/master
namespace MNN {

template <typename T>
@@ -155,10 +158,10 @@ ErrorCode CPUAttention::onResize(const std::vector<Tensor*>& inputs, const std::
}
auto query = inputs[0];
auto key = inputs[1];
int seq_len = query->shape()[1];
mNumHead = query->shape()[2];
mHeadDim = query->shape()[3];
mKvNumHead = key->shape()[2];
int seq_len = query->length(1);
mNumHead = query->length(2);
mHeadDim = query->length(3);
mKvNumHead = key->length(2);
mKVCacheManager->onResize(mKvNumHead, mHeadDim);
if (mUseGemmInt8) {
mPackQ.reset(Tensor::createDevice<int8_t>({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8}));
@@ -191,11 +194,10 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
auto key = inputs[1];
auto value = inputs[2];
auto mask = inputs[3];
auto mask_shape = mask->shape();
bool float_mask = (mask->getType() == halide_type_of<float>());
int mask_seqlen = mask_shape[2];
int mask_kvlen = mask_shape[3];
int seq_len = query->shape()[1];
int mask_seqlen = mask->length(2);
int mask_kvlen = mask->length(3);
int seq_len = query->length(1);
MNN_ASSERT(seq_len == mask_seqlen);
mIsPrefill = (seq_len > 1);
// isPrefill and mask is Square Matrix, is FirstPrefill
@@ -206,7 +208,16 @@ ErrorCode CPUAttention::onExecute(const std::vector<Tensor*>& inputs, const std:
float mScale = 1.0 / sqrt(mHeadDim);
float q_scale = 1.0;
if (bytes == 2) {
q_scale = FP16_QSCALE;
// reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow
FLOAT16_T minValue;
FLOAT16_T maxValue;
core->MNNCountMaxMinValue(query->host<float>(), (float*)(&minValue), (float*)(&maxValue), query->elementSize());
float maxV = maxValue;
float minV = minValue;
float absMax = ALIMAX(fabsf(maxV), fabsf(minV));
if (absMax > 1.0f) {
q_scale = 1.0f / absMax;
}
mScale /= q_scale;
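// Note: q_scale (1/absMax when absMax > 1) shrinks the query so the fp16 Q*K products stay in range,
// and dividing mScale by q_scale compensates, so the effective score Q*K/sqrt(head_dim) is unchanged.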
}

13 changes: 10 additions & 3 deletions source/backend/cpu/CPUBackend.cpp
@@ -48,8 +48,8 @@ ErrorCode CastWrapExecution::onExecute(const std::vector<Tensor*>& inputs, const
CPUCastCreator::cast(inputs[0], outputs[0], cpuBackend, convertType);
return NO_ERROR;
}
void CPUBackend::computeDivideSizes(int size, int* dst) const {
if (mGroupWithComputeRate.size() <= 1) {
void CPUBackend::computeDivideSizes(int size, int* dst, float avgDiv) const {
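// avgDiv: average amount of work per thread; when it is positive and below mComputeI,
// rate-weighted splitting across core groups is skipped and the work is divided evenly.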
if (mGroupWithComputeRate.size() <= 1 || (avgDiv > 0 && avgDiv < mComputeI)) {
// Avg divide
int length = UP_DIV(size, mThreadNumber);
int cur = length;
@@ -60,6 +60,7 @@ void CPUBackend::computeDivideSizes(int size, int* dst) const {
}
return;
}

int cur = 0;
int curPos = 0;
for (auto& group : mGroupWithComputeRate) {
@@ -379,11 +380,17 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p
if (cpuInfo->groups.size() < 2) {
break;
}
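// mComputeI appears to model relative per-core int8 GEMM throughput (i8mm > sdot > plain NEON);
// computeDivideSizes() compares the average work per thread against it when splitting work across core groups.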
if (cpuInfo->i8mm) {
mComputeI = 28.f;
} else if (cpuInfo->dot) {
mComputeI = 14.f;
} else {
mComputeI = 7.f;
}
mGroupWithComputeRate.clear();
float decreaseRate = (float)(rate) / 100.0f;
int validCpuSize = (int)(cpuInfo->groups[cpuInfo->groups.size()-1].ids.size());
int groupIndex = (int)cpuInfo->groups.size()-2;
float maxFreq = (float)cpuInfo->groups[cpuInfo->groups.size()-1].maxFreq;
validCpuSize = ALIMIN(validCpuSize, mThreadNumber);
float totalComputeRate = 1.0f * validCpuSize;
mGroupWithComputeRate.emplace_back(std::make_pair(totalComputeRate, validCpuSize));
