Java自学者论坛

 找回密码
 立即注册

手机号码,快捷登录

恭喜Java自学者论坛(https://www.javazxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,会员资料板块,购买链接:点击进入购买VIP会员

JAVA高级面试进阶训练营视频教程

Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程Go语言视频零基础入门到精通Java架构师3期(课件+源码)
Java开发全终端实战租房项目视频教程SpringBoot2.X入门到高级使用教程大数据培训第六期全套视频教程深度学习(CNN RNN GAN)算法原理Java亿级流量电商系统视频教程
互联网架构师视频教程年薪50万Spark2.0从入门到精通年薪50万!人工智能学习路线教程年薪50万大数据入门到精通学习路线年薪50万机器学习入门到精通教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程MySQL入门到精通教程
查看: 676|回复: 0

解决TensorBoard训练集和测试集指标只能分开显示的问题(基于Keras)

[复制链接]
  • TA的每日心情
    奋斗
    6 天前
  • 签到天数: 745 天

    [LV.9]以坛为家II

    2041

    主题

    2099

    帖子

    70万

    积分

    管理员

    Rank: 9Rank: 9Rank: 9

    积分
    704660
    发表于 2021-7-10 07:55:28 | 显示全部楼层 |阅读模式

    参考https://stackoverflow.com/questions/47877475/keras-tensorboard-plot-train-and-validation-scalars-in-a-same-figure
    tensorflow版本:1.13.1
    keras版本:2.2.4
    重新写一个TrainValTensorBoard继承TensorBoard。

    import os
    import tensorflow as tf
    from keras.callbacks import TensorBoard
    
    class TrainValTensorBoard(TensorBoard):
        def __init__(self, log_dir='./logs', **kwargs):
            # Make the original `TensorBoard` log to a subdirectory 'training'
            training_log_dir = os.path.join(log_dir, 'training')
            super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)
    
            # Log the validation metrics to a separate subdirectory
            self.val_log_dir = os.path.join(log_dir, 'validation')
    
        def set_model(self, model):
            # Setup writer for validation metrics
            self.val_writer = tf.summary.FileWriter(self.val_log_dir)
            super(TrainValTensorBoard, self).set_model(model)
    
        def on_epoch_end(self, epoch, logs=None):
            # Pop the validation logs and handle them separately with
            # `self.val_writer`. Also rename the keys so that they can
            # be plotted on the same figure with the training metrics
            logs = logs or {}
            val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
            for name, value in val_logs.items():
                summary = tf.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = value.item()
                summary_value.tag = name
                self.val_writer.add_summary(summary, epoch)
            self.val_writer.flush()
    
            # Pass the remaining logs to `TensorBoard.on_epoch_end`
            logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
            super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)
    
        def on_train_end(self, logs=None):
            super(TrainValTensorBoard, self).on_train_end(logs)
            self.val_writer.close()
    

    使用新的TrainValTensorBoard。

    from keras.models import Sequential
    from keras.layers import Dense
    from keras.datasets import mnist
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(60000, 784)
    x_test = x_test.reshape(10000, 784)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    
    model = Sequential()
    model.add(Dense(64, activation='relu', input_shape=(784,)))
    model.add(Dense(10, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    model.fit(x_train, y_train, epochs=10,
              validation_data=(x_test, y_test),
              callbacks=[TrainValTensorBoard(write_graph=False)])
    

    哎...今天够累的,签到来了1...
    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|小黑屋|Java自学者论坛 ( 声明:本站文章及资料整理自互联网,用于Java自学者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2024-3-28 19:30 , Processed in 0.064034 second(s), 29 queries .

    Powered by Discuz! X3.4

    Copyright © 2001-2021, Tencent Cloud.

    快速回复 返回顶部 返回列表