最新公告
  • 欢迎您光临起源地模板网,本站秉承服务宗旨 履行“站长”责任,销售只是起点 服务永无止境!立即加入钻石VIP
  • 如何导出python中的模型参数

    正文概述    2020-01-28   239

    如何导出python中的模型参数

    模型的保存和读取

    1.tensorflow保存和读取模型:tf.train.Saver()     .save()

    #保存模型需要用到save函数
    save(
        sess,
        save_path,
        global_step=None,
        latest_filename=None,
        meta_graph_suffix='meta',
        write_meta_graph=True,
        write_state=True
    )
    '''
    sess: 保存模型要求必须有一个加载了计算图的会话,而且所有变量必须已被初始化。
    save_path: 模型保存路径及保存名称
    global_step: 如果提供的话,这个数字会添加到save_path后面,用于区分不同训练阶段的结果
    '''

    示例:

    #例子
    import tensorflow as tf  
    import numpy as np  
    import os  
      
    #用numpy产生数据  
    x_data = np.linspace(-1,1,300)[:, np.newaxis] #转置  
    noise = np.random.normal(0,0.05, x_data.shape)  
    y_data = np.square(x_data)-0.5+noise  
      
    #输入层  
    x_ph = tf.placeholder(tf.float32, [None, 1])  
    y_ph = tf.placeholder(tf.float32, [None, 1])  
      
    #隐藏层  
    w1 = tf.Variable(tf.random_normal([1,10]))  
    b1 = tf.Variable(tf.zeros([1,10])+0.1)  
    wx_plus_b1 = tf.matmul(x_ph, w1) + b1  
    hidden = tf.nn.relu(wx_plus_b1)  
      
    #输出层  
    w2 = tf.Variable(tf.random_normal([10,1]))  
    b2 = tf.Variable(tf.zeros([1,1])+0.1)  
    wx_plus_b2 = tf.matmul(hidden, w2) + b2  
    y = wx_plus_b2  
      
    #损失  
    loss = tf.reduce_mean(tf.reduce_sum(tf.square(y_ph-y),reduction_indices=[1]))  
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)  
      
    #保存模型对象saver  
    saver = tf.train.Saver()  
      
    #判断模型保存路径是否存在,不存在就创建  
    if not os.path.exists('tmp/'):  
        os.mkdir('tmp/')  
      
    #初始化  
    with tf.Session() as sess:  
        if os.path.exists('tmp/checkpoint'):         #判断模型是否存在  
            saver.restore(sess, 'tmp/model.ckpt')    #存在就从模型中恢复变量  
        else:  
            init = tf.global_variables_initializer() #不存在就初始化变量  
            sess.run(init)  
      
        for i in range(1000):  
            _,loss_value = sess.run([train_op,loss], feed_dict={x_ph:x_data, y_ph:y_data})  
            if(i%50==0):  
                save_path = saver.save(sess, 'tmp/model.ckpt')  
                print("迭代次数:%d , 训练损失:%s"%(i, loss_value))

    每调用一次保存操作会创建后3个数据文件并创建一个检查点(checkpoint)文件,简单理解就是权重等参数被保存到 .chkp.data 文件中,以字典的形式;图和元数据被保存到 .chkp.meta 文件中,可以被 tf.train.import_meta_graph 加载到当前默认的图。

    2.keras保存和读取模型

    model.save(filepath),同时保存model和权重的

    import numpy as np
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.optimizers import SGD
    # 载入数据
    (x_train,y_train),(x_test,y_test) = mnist.load_data()
    # (60000,28,28)
    print('x_shape:',x_train.shape)
    # (60000)
    print('y_shape:',y_train.shape)
    # (60000,28,28)->(60000,784)
    x_train = x_train.reshape(x_train.shape[0],-1)/255.0
    x_test = x_test.reshape(x_test.shape[0],-1)/255.0
    # 换one hot格式
    y_train = np_utils.to_categorical(y_train,num_classes=10)
    y_test = np_utils.to_categorical(y_test,num_classes=10)
    # 创建模型,输入784个神经元,输出10个神经元
    model = Sequential([
            Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
        ])
    # 定义优化器
    sgd = SGD(lr=0.2)
    # 定义优化器,loss function,训练过程中计算准确率
    model.compile(
        optimizer = sgd,
        loss = 'mse',
        metrics=['accuracy'],
    )
    # 训练模型
    model.fit(x_train,y_train,batch_size=64,epochs=5)
    # 评估模型
    loss,accuracy = model.evaluate(x_test,y_test)
    print('\ntest loss',loss)
    print('accuracy',accuracy)
    # 保存模型
    model.save('model.h5')

    推荐学习《Python教程》。


    起源地下载网 » 如何导出python中的模型参数

    常见问题FAQ

    免费下载或者VIP会员专享资源能否直接商用?
    本站所有资源版权均属于原作者所有,这里所提供资源均只能用于参考学习用,请勿直接商用。若由于商用引起版权纠纷,一切责任均由使用者承担。更多说明请参考 VIP介绍。
    提示下载完但解压或打开不了?
    最常见的情况是下载不完整: 可对比下载完压缩包的与网盘上的容量,若小于网盘提示的容量则是这个原因。这是浏览器下载的bug,建议用百度网盘软件或迅雷下载。若排除这种情况,可在对应资源底部留言,或 联络我们.。
    找不到素材资源介绍文章里的示例图片?
    对于PPT,KEY,Mockups,APP,网页模版等类型的素材,文章内用于介绍的图片通常并不包含在对应可供下载素材包内。这些相关商业图片需另外购买,且本站不负责(也没有办法)找到出处。 同样地一些字体文件也是这种情况,但部分素材会在素材包内有一份字体下载链接清单。
    模板不会安装或需要功能定制以及二次开发?
    请QQ联系我们

    发表评论

    还没有评论,快来抢沙发吧!

    如需帝国cms功能定制以及二次开发请联系我们

    联系作者

    请选择支付方式

    ×
    迅虎支付宝
    迅虎微信
    支付宝当面付
    余额支付
    ×
    微信扫码支付 0 元