博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
keras 迁移学习inception_v3,缺陷检测
阅读量:7241 次
发布时间:2019-06-29

本文共 3733 字,大约阅读时间需要 12 分钟。

from keras.models import Sequential from keras.layers.normalization import BatchNormalization from keras.layers.convolutional import Conv2D from keras.layers.convolutional import MaxPooling2D from keras.layers.core import Activation from keras.layers.core import Flatten from keras.layers.core import Dropout from keras.layers.core import Dense from keras import backend as K from keras.preprocessing.image import ImageDataGenerator from keras.optimizers import Adam from keras.preprocessing import image from keras.preprocessing.image import img_to_array from sklearn.preprocessing import MultiLabelBinarizer from sklearn.model_selection import train_test_split from sklearn.preprocessing import OneHotEncoder from keras.utils import to_categorical from keras.applications import inception_v3 from keras.layers import GlobalAveragePooling2D from keras.models import Model import matplotlib.pyplot as plt import imutils import numpy as np import argparse import random import pickle import cv2 import os from PIL import Image import matplotlib matplotlib.use("Agg") # 获取该路径下所有图片 path = list(imutils.paths.list_images(r'C:\Users\Desktop\guangdong\train')) imagePaths = sorted(path) random.shuffle(imagePaths) name_dic = {
'正常':'norm','不导电':'defect1','擦花':'defect2','横条压凹':'defect3','桔皮':'defect4','漏底':'defect5', '碰伤':'defect6','起坑':'defect7','凸粉':'defect8','涂层开裂':'defect9','脏点':'defect10','其他':'defect11'} # 将其他文件夹中,名称都改为其他 other_list_1 = os.listdir(r'C:\Users\Desktop\guangdong\train\guangdong_round1_train2_20180916\guangdong_round1_train2_20180916\瑕疵样本\其他') other_list = other_list_1[1:] other_dic = { '伤口':'其他', '划伤':'其他', '变形':'其他', '喷流':'其他', '喷涂碰伤':'其他', '打白点':'其他', '打磨印':'其他','拖烂':'其他', '杂色':'其他', '气泡':'其他', '油印':'其他', '油渣':'其他', '漆泡':'其他', '火山口':'其他', '碰凹':'其他', '粘接':'其他', '纹粗':'其他', '角位漏底':'其他', '返底':'其他', '铝屑':'其他', '驳口':'其他'} # 打印出name_dic里的英文部分,手动复制,再在每个后面添加‘:’及相应的数字 name_dic.values() digit_dir = {
'norm':0, 'defect1':1, 'defect2':2, 'defect3':3, 'defect4':4, 'defect5':5, 'defect6':6, 'defect7':7, 'defect8':8, 'defect9':9, 'defect10':10, 'defect11':11} # 将图片resize成inception_v3需要的(299,299)大小,并转化成array labels = [] data =[] for imagePath in imagePaths: img = Image.open(imagePath) img = img.resize((299,299)) img = img_to_array(img) data.append(img) label_gbk = imagePath.split('\\')[-1].split('2')[0] if label_gbk in other_list: label_gbk = other_dic[label_gbk] label_english = name_dic[label_gbk] label = digit_dir[label_english] print(label_gbk,':',label_english,':',label) labels.append(label) # 像素归一化(有利于加速收敛) labels = np.array(labels) data = np.array(data, dtype="float") / 255.0 # 标签one-hot labels = to_categorical(labels) x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42) # 数据增强 train_aug = ImageDataGenerator(rotation_range=25, width_shift_range=0.1,height_shift_range=0.1, shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest",preprocessing_function=inception_v3.preprocess_input) # inception_v3基础模型,include_top=False就是要修改原模型的最后一层 base_model = inception_v3.InceptionV3(weights='imagenet',include_top=False) x = base_model.output x = GlobalAveragePooling2D()(x) x = Dense(units=1024,activation='relu')(x) predictions = Dense(units=12,activation='softmax')(x) model = Model(inputs=base_model.input, output=predictions) base_model.summary() model.summary() # 不训练基础层 for layer in base_model.layers: layer.trainable = False model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy']) # batch_size最好选2的n次方,参考的是内存格式 history_tl = model.fit_generator(generator=train_aug.flow(x=x_train,y=y_train,batch_size=32),validation_data=(x_test, y_test), steps_per_epoch=len(x_train)//32,epochs=10,verbose=1) model.save()

转载地址:http://pvfbm.baihongyu.com/

你可能感兴趣的文章
取汉子拼音首字母的C#方法
查看>>
C语言 · 求先序遍历
查看>>
java oracle thin 和 oci 连接方式实现多数据库的故障切换
查看>>
使用spring利用HandlerExceptionResolver实现全局异常捕获
查看>>
字符串 上
查看>>
jmeter设置全局变量
查看>>
MySQLi基于面向对象的编程
查看>>
CAAnimation 动画支撑系统
查看>>
读vue-0.6-text-parser.js源码
查看>>
对map进行排序
查看>>
IntelliJ IDEA 13.1.3 SVN无法正常使用问题
查看>>
Element link doesn't have required attribute property
查看>>
linux ctags
查看>>
RMAN备份(转)
查看>>
Oracle 12c 多租户 手工创建 pdb 与 手工删除 pdb
查看>>
FlexPaper:使用flash在线展示pdf
查看>>
漫游Kafka设计篇之性能优化
查看>>
JConsole
查看>>
JavaScript初探之——图片移动
查看>>
ABI 管理
查看>>