《昇思25天学习打卡营第2天|快速入门》

文章目录

  • 前言:
  • 今日所学:
    • 1. 数据集处理
    • 2. 网络的构建
    • 3. 模型训练
    • 4. 保存模型
    • 5. 加载模型
  • 总体代码与运行结果:
    • 1. 总体代码
    • 2. 运行结果


前言:

今天是学习打卡的第2天,今天的内容是对MindSpore的一个快速入门,主要通过MindSpore的API来快速实现一个简单的深度学习模型,首先我们要确保我们的实验环境安装了mindspore,可以通过如下代码来安装所需版本的mindspore如下:

%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0rc1,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1

今日所学:

1. 数据集处理

首先我们要在代码中申明我们所需要用到的库:

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

然后对我们所需要使用的Mnist数据集来进行下载:

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

在这里插入图片描述
然后后去数据集对象后打印数据列名等用于dataset的预处理。

2. 网络的构建

通过mindspore.nn类来进行网络构建的过程,它是构建所有网络的基类,也是网络的基本单元,具体的构建如下:

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

model = Network()
print(model)

得到如下的结果:

在这里插入图片描述

3. 模型训练

在进行了以上两步后我们进一步来进行模型训练,一个完整的模型训练一般需要如下三步:

  • 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  • 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  • 参数优化:将梯度更新到参数上。
    针对以上模型训练的三个步骤我们通过如下代码来实现:
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 第一步:首先我们定义了正向计算函数
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 第二步:然后通过value_and_grad获得梯度计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 第三部:定义训练函数从而来执行正向计算、反向传播与参数优化
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

然后在此训练之外我们定义测试函数test来评估模型性能吗,并且设置多次迭代数据集来提高我们的预测准确率。

4. 保存模型

在模型训练完毕后我们通过如下的代码来对我们的模型保存:

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

5. 加载模型

通过一下代码加载模型:

model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

加载之后的模型即可进行预测推理:

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

在这里插入图片描述

总体代码与运行结果:

1. 总体代码

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from download import download
 
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')
print(train_dataset.get_col_names())
 
def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
 
    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset
 
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)
 
for image, label in test_dataset.create_tuple_iterator():
    print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")
    print(f"Shape of label: {label.shape} {label.dtype}")
    break
for data in test_dataset.create_dict_iterator():
    print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")
    print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")
    break
 
    # Define model
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )
 
    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits
 
model = Network()
print(model)
 
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)

# 第一步:首先我们定义了正向计算函数
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits

# 第二步:然后通过value_and_grad获得梯度计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# 第三部:定义训练函数从而来执行正向计算、反向传播与参数优化
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss

def train(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
    
def test(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
 
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(model, train_dataset)
    test(model, test_dataset, loss_fn)
print("Done!")
# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")
 
# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)
 
 
model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

2. 运行结果

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/754971.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Selenium IDE 的使用指南

Selenium IDE 的使用指南 在自动化测试的领域中,Selenium 是一个广为人知且强大的工具集。而 Selenium IDE 作为其中的一个组件,为测试人员提供了一种便捷且直观的方式来创建和执行自动化测试脚本。 一、Selenium IDE 简介 Selenium IDE 是一个用于录…

第十三章 常用类

一、包装类 1. 包装类的分类 (1)针对八种基本数据类型相应的引用类型—包装类 (2)有了类的特点,就可以调用类中的方法。 2. 包装类和基本数据的转换 jdk5 前的手动装箱和拆箱方式,装箱:基本…

【Qt】信号和槽机制

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

操作系统之《PV操作》【知识点+详细解题过程】

1、并发进程 &#xff1a; 并发的实质是一个处理器在几个进程之间的多路复用&#xff0c;并发是对有限的物理资源强制行使多用户共享&#xff0c;消除计算机部件之间的互等现象&#xff0c;以提高系统资源利用率。 &#xff08;1&#xff09;并发进程——互斥性&#xff1a; 进…

使用Jetpack Compose实现具有多选功能的图片网格

使用Jetpack Compose实现具有多选功能的图片网格 在现代应用中,多选功能是一项常见且重要的需求。例如,Google Photos允许用户轻松选择多个照片进行分享、添加到相册或删除。在本文中,我们将展示如何使用Jetpack Compose实现类似的多选行为,最终效果如下: 主要步骤 实现…

【redis】Redis AOF

1、AOF的基本概念 AOF持久化方式是通过保存Redis所执行的写命令来记录数据库状态的。AOF以日志的形式来记录每个写操作&#xff08;增量保存&#xff09;&#xff0c;将Redis执行过的所有写指令记录下来&#xff08;读操作不记录&#xff09;。AOF文件是一个只追加的文件&…

Redis 高级数据结构业务实践

0、前言 本文所有代码可见 > 【gitee code demo】 本文会涉及 hyperloglog 、GEO、bitmap、布隆过滤器的介绍和业务实践 1、HyperLogLog 1.1、功能 基数统计&#xff08;去重&#xff09; 1.2、redis api 命令作用案例PFADD key element [element ...]添加元素到keyPF…

PortSip测试

安装PBX 下载 免费下载 PortSIP PBX 安装PBX&#xff0c;安装后&#xff0c;运行 &#xff0c;默认用户是admin 密码是admin&#xff0c;然后配置IP 为192.168.0.189 设置域名为192.168.0.189 配置分机 添加分机&#xff0c;添加了10001、10002、9999 三个分机&#xff0c…

深度学习实验第T2周:彩色图片分类

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊](https://mtyjkh.blog.csdn.net/)** 目录 一、前言 目标 二、我的环境&#…

【Linux进程通信】进程间通信介绍、匿名管道原理分析

目录 进程通信是什么&#xff1f; 进程通信的目的 进程通信的本质 匿名管道&#xff1a;基于文件级别的通信方式 站在文件描述符角度-深度理解管道原理 进程通信是什么&#xff1f; 进程通信就是两个或多个进程之间进行数据层面的交互。 进程通信的目的 1.数据传输&#x…

已解决java.security.acl.LastOwnerException:无法移除最后一个所有者的正确解决方法,亲测有效!!!

已解决java.security.acl.LastOwnerException&#xff1a;无法移除最后一个所有者的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 出现问题的场景 报错原因 解决思路 解决方法 1. 检查当前所有者数量 2. 添加新的所有者 3. 维…

mac Canon打印机连接教程

官网下载安装驱动&#xff1a; 选择打印机类型和mac系统型号下载即可 Mac PS 打印机驱动程序 双击安装 系统偏好设置 点击“”添加&#xff1a; OK可打印玩耍&#xff01;&#xff01; 备注&#xff1a; 若需扫描&#xff0c;下载扫描程序&#xff1a; 备注&#xff1a;…

设置小蓝熊的CPU亲和性、CPU优先级再设置法环的CPU亲和性

# 适用于Windows系统 # 时间 : 2024-06-28 # 作者 : 三巧(https://blog.csdn.net/qq_39124701) # 文件名 : 设置小蓝熊的CPU亲和性、CPU优先级再设置法环的CPU亲和性.ps1 # 使用方法: 打开记事本&#xff0c;将所有代码复制到记事本中&#xff0c;保存文件时候修改文件后…

Hugging Face Accelerate 两个后端的故事:FSDP 与 DeepSpeed

社区中有两个流行的零冗余优化器 (Zero Redundancy Optimizer&#xff0c;ZeRO)算法实现&#xff0c;一个来自DeepSpeed&#xff0c;另一个来自PyTorch。Hugging FaceAccelerate对这两者都进行了集成并通过接口暴露出来&#xff0c;以供最终用户在训练/微调模型时自主选择其中之…

zabbix-server的搭建

zabbix-server的搭建 部署 zabbix 服务端(192.168.99.180) rpm -ivh https://mirrors.aliyun.com/zabbix/zabbix/5.0/rhel/7/x86_64/zabbix-release-5.0-1.el7.noarch.rpm cd /etc/yum.repos.d sed -i s#http://repo.zabbix.com#https://mirrors.aliyun.com/zabbix# zabbix.r…

关于FPGA对 DDR4 (MT40A256M16)的读写控制 4

关于FPGA对 DDR4 &#xff08;MT40A256M16&#xff09;的读写控制 4 语言 &#xff1a;Verilg HDL 、VHDL EDA工具&#xff1a;ISE、Vivado、Quartus II 关于FPGA对 DDR4 &#xff08;MT40A256M16&#xff09;的读写控制 4一、引言二、DDR4 SDRAM设备中模式寄存器重要的模式寄存…

Arduino - LED 矩阵

Arduino - LED 矩阵 Arduino - LED Matrix LED matrix display, also known as LED display, or dot matrix display, are wide-used. In this tutorial, we are going to learn: LED矩阵显示器&#xff0c;也称为LED显示器&#xff0c;或点阵显示器&#xff0c;应用广泛。在…

“Hello, World!“ 历史由来

布莱恩W.克尼汉&#xff08;Brian W. Kernighan&#xff09;—— Unix 和 C 语言背后的巨人 布莱恩W.克尼汉在 1942 年出生在加拿大多伦多&#xff0c;他在普林斯顿大学取得了电气工程的博士学位&#xff0c;2000 年之后取得普林斯顿大学计算机科学的教授教职。 1973 年&#…

C++ | Leetcode C++题解之第203题移除链表元素

题目&#xff1a; 题解&#xff1a; class Solution { public:ListNode* removeElements(ListNode* head, int val) {struct ListNode* dummyHead new ListNode(0, head);struct ListNode* temp dummyHead;while (temp->next ! NULL) {if (temp->next->val val) {…

小柴冲刺软考中级嵌入式系统设计师系列一、计算机系统基础知识(1)嵌入式计算机系统概述

flechazohttps://www.zhihu.com/people/jiu_sheng 小柴冲刺嵌入式系统设计师系列总目录https://blog.csdn.net/qianshang52013/article/details/139975720?spm1001.2014.3001.5501 根据IEEE&#xff08;国际电气电子工程师协会&#xff09;的定义&#xff0c;嵌入式系统是&q…