V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
laziji
V2EX  ›  分享发现

TensorFlow.js 卷积神经网络手写数字识别

  •  1
     
  •   laziji · 2018-11-21 18:31:41 +08:00 · 1411 次点击
    这是一个创建于 1997 天前的主题,其中的信息可能已经有所发展或是发生改变。

    原博地址https://laboo.top/2018/11/21/tfjs-dr/

    源码

    digit-recognizer

    demo

    https://github-laziji.github.io/digit-recognizer/ 演示开始时需要加载大概100M的训练数据, 稍等片刻

    调整训练集的大小, 观察测试结果的准确性

    数据来源

    数据来源与 https://www.kaggle.com 中的一道题目 digit-recognizer 题目给出42000条训练数据(包含图片和标签)以及28000条测试数据(只包含图片) 要求给这些测试数据打上标签[0,1,2,3....,9] 要尽可能的准确

    网站中还有许多其他的机器学习的题目以及数据, 是个很好的练手的地方

    实现

    这里我们使用TensorFlow.js来实现这个项目

    创建模型

    卷积神经网络的第一层有两种作用, 它既是输入层也是执行层, 接收IMAGE_H * IMAGE_W大小的黑白像素 最后一层是输出层, 有 10 个输出单元, 代表着0-9这十个值的概率分布, 例如 Label=2 , 输出为[0.02,0.01,0.9,...,0.01]

    function createConvModel() {
      const model = tf.sequential();
    
      model.add(tf.layers.conv2d({
        inputShape: [IMAGE_H, IMAGE_W, 1],
        kernelSize: 3,
        filters: 16,
        activation: 'relu'
      }));
    
      model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
      model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
      model.add(tf.layers.maxPooling2d({ poolSize: 2, strides: 2 }));
      model.add(tf.layers.conv2d({ kernelSize: 3, filters: 32, activation: 'relu' }));
      model.add(tf.layers.flatten({}));
    
      model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
      model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
    
      return model;
    }
    

    训练模型

    我们选择适当的优化器和损失函数, 来编译模型

    async function train() {
    
      ui.trainLog('Create model...');
      model = createConvModel();
      
      ui.trainLog('Compile model...');
      const optimizer = 'rmsprop';
      model.compile({
        optimizer,
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy'],
      });
      const trainData = Data.getTrainData(ui.getTrainNum());
      
      ui.trainLog('Training model...');
      await model.fit(trainData.xs, trainData.labels, {});
    
      ui.trainLog('Completed!');
      ui.trainCompleted();
    }
    

    测试

    这里测试一组测试数据, 返回对应的标签, 即十个输出单元中概率最高的下标

    function testOne(xs){
      if(!model){
        ui.viewLog('Need to train the model first');
        return;
      }
      ui.viewLog('Testing...');
      let output = model.predict(xs);
      ui.viewLog('Completed!');
      output.print();
      const axis = 1;
      const predictions = output.argMax(axis).dataSync();
      return predictions[0];
    }
    

    欢迎关注我的博客公众号 2018_11_16_0048241709.png

    目前尚无回复
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   1654 人在线   最高记录 6543   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 25ms · UTC 16:54 · PVG 00:54 · LAX 09:54 · JFK 12:54
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.