混合精度方法

本贴最后更新于 276 天前,其中的信息可能已经东海扬尘

1.main.py

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
    if opt.half:
        scaler = GradScaler()
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                all_pics = data  # allpics contains cover images and secret images
                this_batch_size = int(all_pics.size()[0])  # get true batch size of this step

                # first half of images will become cover images, the rest are treated as secret images
                cover_img = all_pics[0:this_batch_size, :, :, :]  # batchsize,3,256,256
                secret_img,target=sample_secret_img(this_batch_size)

                # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

                cover_img = cover_img.to(device)
                secret_img = secret_img.to(device)
              
                with torch.no_grad():
                    cover_img = Variable(cover_img)
                    secret_imgv = Variable(secret_img)
                cover_imgv = JPEG(cover_img)
                # print(cover_imgv.shape, secret_imgv.shape)
                container = Hnet(cover_imgv, secret_imgv)  # put concat_image into H-net and get container image


                errH = criterion(container, cover_imgv)  # loss between cover and container
                Hlosses.update(errH.item(), this_batch_size)

                compress_img=JPEG(container)

                rev_secret_img = Rnet(compress_img)  # put concatenated image into R-net and get revealed secret image
                # print(rev_secret_img.shape, secret_imgv.shape)
                errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
                Rlosses.update(errR.item(), this_batch_size)

                betaerrR_secret = opt.beta * errR
                err_sum = errH + betaerrR_secret
                SumLosses.update(err_sum.item(), this_batch_size)

            # err_sum.backward()
            scaler.scale(err_sum).backward()
            scaler.step(optimizerH)
            scaler.step(optimizerR)
            # optimizerH.step()
            # optimizerR.step()
            # Updates the scale for next iteration.
            scaler.update()
            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # genereate a picture every resultPicFrequency steps
            if epoch % 1 == 0 and i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data,compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)
            # break
  
    else:
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
            all_pics = data  # allpics contains cover images and secret images
            this_batch_size = int(all_pics.size()[0])  # get true batch size of this step

            # first half of images will become cover images, the rest are treated as secret images
            cover_img = all_pics[0:this_batch_size, :, :, :]  # batchsize,3,256,256
            secret_img,target=sample_secret_img(this_batch_size)

            # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

            cover_img = cover_img.to(device)
            secret_img = secret_img.to(device)
          
            with torch.no_grad():
                cover_img = Variable(cover_img)
                secret_imgv = Variable(secret_img)
            cover_imgv = JPEG(cover_img)
            # print(cover_imgv.shape, secret_imgv.shape)
            container = Hnet(cover_imgv, secret_imgv)  # put concat_image into H-net and get container image


            errH = criterion(container, cover_imgv)  # loss between cover and container
            Hlosses.update(errH.item(), this_batch_size)

            compress_img=JPEG(container)

            rev_secret_img = Rnet(compress_img)  # put concatenated image into R-net and get revealed secret image
            # print(rev_secret_img.shape, secret_imgv.shape)
            errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
            Rlosses.update(errR.item(), this_batch_size)

            betaerrR_secret = opt.beta * errR
            err_sum = errH + betaerrR_secret
            SumLosses.update(err_sum.item(), this_batch_size)

            err_sum.backward()
 
            optimizerH.step()
            optimizerR.step()
            # Updates the scale for next iteration.
            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # genereate a picture every resultPicFrequency steps
            if epoch % 1 == 0 and i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data,compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)

2.note

how to deal the issue.

  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    545 引用 • 672 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • Quicker

    Quicker 您的指尖工具箱!操作更少,收获更多!

    34 引用 • 148 回帖
  • 锤子科技

    锤子科技(Smartisan)成立于 2012 年 5 月,是一家制造移动互联网终端设备的公司,公司的使命是用完美主义的工匠精神,打造用户体验一流的数码消费类产品(智能手机为主),改善人们的生活质量。

    4 引用 • 31 回帖
  • 反馈

    Communication channel for makers and users.

    123 引用 • 913 回帖 • 250 关注
  • CSS

    CSS(Cascading Style Sheet)“层叠样式表”是用于控制网页样式并允许将样式信息与网页内容分离的一种标记性语言。

    196 引用 • 540 回帖 • 1 关注
  • Hibernate

    Hibernate 是一个开放源代码的对象关系映射框架,它对 JDBC 进行了非常轻量级的对象封装,使得 Java 程序员可以随心所欲的使用对象编程思维来操纵数据库。

    39 引用 • 103 回帖 • 715 关注
  • MySQL

    MySQL 是一个关系型数据库管理系统,由瑞典 MySQL AB 公司开发,目前属于 Oracle 公司。MySQL 是最流行的关系型数据库管理系统之一。

    692 引用 • 535 回帖
  • SpaceVim

    SpaceVim 是一个社区驱动的模块化 vim/neovim 配置集合,以模块的方式组织管理插件以
    及相关配置,为不同的语言开发量身定制了相关的开发模块,该模块提供代码自动补全,
    语法检查、格式化、调试、REPL 等特性。用户仅需载入相关语言的模块即可得到一个开箱
    即用的 Vim-IDE。

    3 引用 • 31 回帖 • 104 关注
  • Git

    Git 是 Linux Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。

    209 引用 • 358 回帖
  • ZeroNet

    ZeroNet 是一个基于比特币加密技术和 BT 网络技术的去中心化的、开放开源的网络和交流系统。

    1 引用 • 21 回帖 • 632 关注
  • 游戏

    沉迷游戏伤身,强撸灰飞烟灭。

    177 引用 • 816 回帖
  • IBM

    IBM(国际商业机器公司)或万国商业机器公司,简称 IBM(International Business Machines Corporation),总公司在纽约州阿蒙克市。1911 年托马斯·沃森创立于美国,是全球最大的信息技术和业务解决方案公司,拥有全球雇员 30 多万人,业务遍及 160 多个国家和地区。

    17 引用 • 53 回帖 • 141 关注
  • 职场

    找到自己的位置,萌新烦恼少。

    127 引用 • 1706 回帖
  • Sillot

    Insights(注意当前设置 master 为默认分支)

    汐洛彖夲肜矩阵(Sillot T☳Converbenk Matrix),致力于服务智慧新彖乄,具有彖乄驱动、极致优雅、开发者友好的特点。其中汐洛绞架(Sillot-Gibbet)基于自思源笔记(siyuan-note),前身是思源笔记汐洛版(更早是思源笔记汐洛分支),是智慧新录乄终端(多端融合,移动端优先)。

    主仓库地址:Hi-Windom/Sillot

    文档地址:sillot.db.sc.cn

    注意事项:

    1. ⚠️ 汐洛仍在早期开发阶段,尚不稳定
    2. ⚠️ 汐洛并非面向普通用户设计,使用前请了解风险
    3. ⚠️ 汐洛绞架基于思源笔记,开发者尽最大努力与思源笔记保持兼容,但无法实现 100% 兼容
    29 引用 • 25 回帖 • 86 关注
  • 博客

    记录并分享人生的经历。

    273 引用 • 2388 回帖
  • Kafka

    Kafka 是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据。 这种动作(网页浏览,搜索和其他用户的行动)是现代系统中许多功能的基础。 这些数据通常是由于吞吐量的要求而通过处理日志和日志聚合来解决。

    36 引用 • 35 回帖
  • OpenResty

    OpenResty 是一个基于 NGINX 与 Lua 的高性能 Web 平台,其内部集成了大量精良的 Lua 库、第三方模块以及大多数的依赖项。用于方便地搭建能够处理超高并发、扩展性极高的动态 Web 应用、Web 服务和动态网关。

    17 引用 • 38 关注
  • 生活

    生活是指人类生存过程中的各项活动的总和,范畴较广,一般指为幸福的意义而存在。生活实际上是对人生的一种诠释。生活包括人类在社会中与自己息息相关的日常活动和心理影射。

    230 引用 • 1454 回帖
  • HHKB

    HHKB 是富士通的 Happy Hacking 系列电容键盘。电容键盘即无接点静电电容式键盘(Capacitive Keyboard)。

    5 引用 • 74 回帖 • 478 关注
  • 创造

    你创造的作品可能会帮助到很多人,如果是开源项目的话就更赞了!

    178 引用 • 997 回帖
  • Sphinx

    Sphinx 是一个基于 SQL 的全文检索引擎,可以结合 MySQL、PostgreSQL 做全文搜索,它可以提供比数据库本身更专业的搜索功能,使得应用程序更容易实现专业化的全文检索。

    1 引用 • 221 关注
  • 30Seconds

    📙 前端知识精选集,包含 HTML、CSS、JavaScript、React、Node、安全等方面,每天仅需 30 秒。

    • 精选常见面试题,帮助您准备下一次面试
    • 精选常见交互,帮助您拥有简洁酷炫的站点
    • 精选有用的 React 片段,帮助你获取最佳实践
    • 精选常见代码集,帮助您提高打码效率
    • 整理前端界的最新资讯,邀您一同探索新世界
    488 引用 • 384 回帖
  • JSON

    JSON (JavaScript Object Notation)是一种轻量级的数据交换格式。易于人类阅读和编写。同时也易于机器解析和生成。

    52 引用 • 190 回帖 • 1 关注
  • 倾城之链
    23 引用 • 66 回帖 • 138 关注
  • 单点登录

    单点登录(Single Sign On)是目前比较流行的企业业务整合的解决方案之一。SSO 的定义是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。

    9 引用 • 25 回帖
  • 创业

    你比 99% 的人都优秀么?

    85 引用 • 1399 回帖 • 1 关注
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 683 关注
  • jsDelivr

    jsDelivr 是一个开源的 CDN 服务,可为 npm 包、GitHub 仓库提供免费、快速并且可靠的全球 CDN 加速服务。

    5 引用 • 31 回帖 • 72 关注