<返回更多

腾讯TNN推理引擎源码解读系列:优化策略管理器

2020-06-17    
加入收藏

 

TNN:由腾讯优图实验室打造,移动端高性能、轻量级推理框架,同时拥有跨平台、高性能、模型压缩、代码裁剪等众多突出优势。TNN框架在原有Rapidnet、ncnn框架的基础上进一步加强了移动端设备的支持以及性能优化,同时也借鉴了业界主流开源框架高性能和良好拓展性的优点。目前TNN已经在手Q、微视、P图等应用中落地,欢迎大家参与协同共建,促进TNN推理框架进一步完善。

腾讯TNN推理引擎源码解读系列:优化策略管理器

腾讯推理引擎TNN

本系列文章为对腾讯TNN的深度源码级别解读,希望通过对一个推理框架的完整描述,来增强读者对于神经网络设计、实现到优化的方方面面。

本节将对TNN中的优化策略管理执行过程做详细的分析介绍。


在神经网络优化管理器中主要定义了NetOptimizerManager和NetOptimizerRegister两个类,分别实现管理执行和注册的功能,接下来将对这两个类的具体功能做详细分析。

网络优化策略管理器(NetOptimizerManager)

这里先来看看优化策略(网络优化器net_optimizer)的定义。TNN中定义了多种神经网络优化器(优化策略),比如fuse_conv_relu可以将conv层合并达到减少内存/显存拷贝,减少计算和调用开销的目的。NetOptimizer是所有这些优化器基类定义了优化器需要具备的基础功能。具体的每种优化器优化方法后续文章依次说明。

class NetOptimizer {
    public:
        virtual ~NetOptimizer() {}

        //定义策略
        virtual std::string Strategy() = 0;
        //确认设备是否支持优化
        virtual bool SupportDevice(DeviceType device) = 0;
        //基于网络结构进行优化
        virtual Status Optimize(NetStructure *structure, NetResource *resource) = 0;
    };

优化策略方法说明

  • Strategy方法:返回对应优化器的字符串名称。

当前内部定义的一些名称如下:

static const std::string kNetOptimizerFuseConvRelu = "net_optimizer_fuse_conv_relu";

static const std::string kNetOptimizerInsertReformat = "net_optimizer_Insert_reformat";

static const std::string kNetOptimizerRemoveLayers = "net_optimizer_remove_layers";
  • SupportDevice方法:返回该优化器是否在设备上可用。
  • Optimize方法:定义了具体的优化方法。

接下来看一下本文的重点NetOptimizerManager的实现。

NetOptimizerManager定义了优化管理器的执行方法。通过map和vector来存储优化策略。其中map结构维护了strategy名称到优化器实现的映射。vector结构维护了优先级信息和strategy名称的映射。

    //@brief net optimize: fuse relu and relu6 to convolution
    class NetOptimizerManager {
    public:
        //执行优化
        static Status Optimize(NetStructure *structure, NetResource *resource, DeviceType device);
        //静态方法:注册优化策略并指定优先级
        static void RegisterNetOptimizer(NetOptimizer *ptimizer, OptPriority prior);

    private:
  
        static std::map<std::string, std::shared_ptr<NetOptimizer>> &GetNetOptimizerMap();
        static std::vector<std::pair<OptPriority, std::string>> &GetNetOptimizerSeq();
    };

Optimize方法:

Optimize方法:定义了优化的执行逻辑框架(具体的优化由每个optimizer来执行)。

管理器需要依次执行每个优化器,这里是基于优先级来执行优化策略的。具体方法是通过对vector序列按优先级排序,然后遍历map这样可以按优先级从高到低依次执行优化。

支持的优先级,目前为P0到PLAST可以根据自己的需求进行扩展。

  typedef enum {
        // TOP
        P0 = 0,
        // MIDDLE
        P1 = 1,
        //
        P2 = 2,
        // LAST
        PLAST = 1000
    } OptPriority;
//基于优先级排序依次执行优化器。
    Status NetOptimizerManager::Optimize(NetStructure *structure, NetResource *resource, DeviceType device) {
        //获取当前优化器的map结构
        auto &optimizer_map = NetOptimizerManager::GetNetOptimizerMap();
        //按照strategy字符串对优化器进行排序
        std::sort(NetOptimizerManager::GetNetOptimizerSeq().begin(), NetOptimizerManager::GetNetOptimizerSeq().end());
        //循环优化器
        for (auto iter : NetOptimizerManager::GetNetOptimizerSeq()) {
            //
            auto optimizer = optimizer_map[iter.second];
            //确认当前设备是否支持该优化
            if (optimizer->SupportDevice(device)) {
                //执行优化逻辑
                auto status = optimizer->Optimize(structure, resource);
                if (status != TNN_OK) {  //执行出错会导致优化中断,返回对应状态值
                    return status;
                }
            }
        }

        return TNN_OK;
    }

RegisterNetOptimizer方法:

RegisterNetOptimizer方法定义了每个优化器optimizer的注册过程,其实就是map和vector插入对应数据就OK了。

  //参数:优化器和优先级
    void NetOptimizerManager::RegisterNetOptimizer(NetOptimizer *optimizer, OptPriority prior) {
        //优化器不为空
        if (optimizer && optimizer->Strategy().length() > 0) {
            auto &optimizer_map = NetOptimizerManager::GetNetOptimizerMap();
            //往map和vector中将消息注册进去
            optimizer_map[optimizer->Strategy()] = std::shared_ptr<NetOptimizer>(optimizer);
            NetOptimizerManager::GetNetOptimizerSeq().push_back(std::make_pair(prior, optimizer->Strategy()));
        }
    }

NetOptimizerRegister

在net_optimizer_manageer.h头文件中,通过模板类将RegisterNetOptimizer方法封装到NetOptimizerRegister类中,这样就可以做到定义即注册。

//通过模板指定具体的优化器。实现了定义即注册的功能。
    template <typename T>
    class NetOptimizerRegister {
    public:
        explicit NetOptimizerRegister(OptPriority p) {
            NetOptimizerManager::RegisterNetOptimizer(new T(), p);
        }
    };

 

声明:本站部分内容来自互联网,如有版权侵犯或其他问题请与我们联系,我们将立即删除或处理。
▍相关推荐
更多资讯 >>>