博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras 多任务实现,Multi Loss #########Keras Xception Multi loss 细粒度图像分类
阅读量:4261 次
发布时间:2019-05-26

本文共 2320 字,大约阅读时间需要 7 分钟。

这里只摘取关键代码:

# create the base pre-trained modelinput_tensor = Input(shape=(299, 299, 3))base_model = Xception(include_top=True, weights='imagenet', input_tensor=None, input_shape=None)plot_model(base_model, to_file='xception_model.png')base_model.layers.pop()base_model.outputs = [base_model.layers[-1].output]base_model.layers[-1].outbound_nodes = []base_model.output_layers = [base_model.layers[-1]]feature = base_modelimg1 = Input(shape=(299, 299, 3), name='img_1')img2 = Input(shape=(299, 299, 3), name='img_2')feature1 = feature(img1)feature2 = feature(img2)# Three loss functionscategory_predict1 = Dense(100, activation='softmax', name='ctg_out_1')(    Dropout(0.5)(feature1))category_predict2 = Dense(100, activation='softmax', name='ctg_out_2')(    Dropout(0.5)(feature2))dis = Lambda(eucl_dist, name='square')([feature1, feature2])model = Model(inputs=[img1, img2], outputs=[category_predict1, category_predict2, judge])model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),              loss={                  'ctg_out_1': 'categorical_crossentropy',                  'ctg_out_2': 'categorical_crossentropy',                  'bin_out': 'categorical_crossentropy'},              loss_weights={                  'ctg_out_1': 1.,                  'ctg_out_2': 1.,                  'bin_out': 0.5              },              metrics=['accuracy'])

如果觉得我的工作对你有帮助,就点个吧

关于

这是百度举办的一个关于狗的细粒度分类比赛,比赛链接: 

框架

硬件

  • Geforce GTX 1060 6G
  • Intel® Core™ i7-6700 CPU
  • Memory 8G

模型

  • 提取深度特征

  • 受的启发,在多分类基础上增加一个样本是否相同判断的二分类loss,增加类间距离,减小类内距离

Keras实现

  • 去掉Xception最后用于imagenet分类的全连接层,获取图像深度特征
  • 输入两张图片,可能属于相同类也可能属于不同类
  • 根据特征和标签进行多分类训练
  • 同时以两图是否属于同一类作为二分类标签训练

数据预处理

  • 从Baidu云下载数据
    • 训练集:  Key: 5axb
    • 测试集:  Key:fl5n
  • 按类别把图片放在不同的目录下,方便ImageDataGenerator的使用
  • 因为先前我把图片命名为这种格式"typeid_randhash.jpg"了, 所以我写了这段代码来做图片移动的工作
  • 数据预处理还有许多细节要处理,遇到问题的话可以先查看keras的文档,如果还有问题,可以提.

训练

  • 使用Keras的ImageDataGenerator接口进行数据增广
  • 同时使用ImageDataGenerator做数据增广并进行正负样本对采样是一个难点.因为从ImageDataGenerator获得的图片被打乱了.
    遍历数据集找同类样本作为正样本效率很低,幸运的是,在每个batch中,存在同类的样本,所以我们可以通过在同一个batch中交换同类样本的位置,构造出包含正样本对的另一个输入.
  • 冻结Xception的卷积层,采用ADMM训练多分类和二分类模型.
  • 解冻Xception卷积层的最后两个block(总共有12个block,最后两个block从Xception的105层开始)继续使用SGD训练
  • 去掉数据增广,再训练直至收敛

代码

  • 单一Xception模型
    • 训练: 
    • 测试: 
  • Multi loss模型
    • 冻结训练全连接层+微调卷积层: 
    • Trick微调: 
    • 测试: 

一些测试结果

  • InceptionV3,多分类模型: 0.2502
  • Xception,多分类模型: 0.2235
  • Xception, 混合模型: 0.211
  • Xception, 混合模型,最后去掉数据增广再训练: 0.2045

转载地址:http://wqlei.baihongyu.com/

你可能感兴趣的文章
使用Github进行合作开发
查看>>
Impala入门笔记(转载)
查看>>
cloudera Manager中监控数据的存储
查看>>
Kafka简要介绍
查看>>
Maven环境的搭建
查看>>
hbase 学习梳理
查看>>
浅谈医学大数据(中)
查看>>
阿里巴巴数据产品经理工作总结
查看>>
大数据的特点及作用
查看>>
IBM朱辉:大数据分析的5个高复制使用场景及案例分享(含PPT)
查看>>
Java返回对象快捷键
查看>>
STL中的Iterator
查看>>
C语言拾遗
查看>>
数据库查询语句拾遗
查看>>
STL中的Vector
查看>>
C++中的trivial、standard layout、POD
查看>>
阿里中间件三大存储系统
查看>>
Tair源码阅读1---ConfigServer
查看>>
STL中的RB-tree
查看>>
STL中的Sort
查看>>