混合精度方法

1.main.py

from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler
    if opt.half:
        scaler = GradScaler()
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
            with torch.autocast(device_type='cuda', dtype=torch.float16):
                all_pics = data  # allpics contains cover images and secret images
                this_batch_size = int(all_pics.size()[0])  # get true batch size of this step

                # first half of images will become cover images, the rest are treated as secret images
                cover_img = all_pics[0:this_batch_size, :, :, :]  # batchsize,3,256,256
                secret_img,target=sample_secret_img(this_batch_size)

                # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

                cover_img = cover_img.to(device)
                secret_img = secret_img.to(device)
              
                with torch.no_grad():
                    cover_img = Variable(cover_img)
                    secret_imgv = Variable(secret_img)
                cover_imgv = JPEG(cover_img)
                # print(cover_imgv.shape, secret_imgv.shape)
                container = Hnet(cover_imgv, secret_imgv)  # put concat_image into H-net and get container image


                errH = criterion(container, cover_imgv)  # loss between cover and container
                Hlosses.update(errH.item(), this_batch_size)

                compress_img=JPEG(container)

                rev_secret_img = Rnet(compress_img)  # put concatenated image into R-net and get revealed secret image
                # print(rev_secret_img.shape, secret_imgv.shape)
                errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
                Rlosses.update(errR.item(), this_batch_size)

                betaerrR_secret = opt.beta * errR
                err_sum = errH + betaerrR_secret
                SumLosses.update(err_sum.item(), this_batch_size)

            # err_sum.backward()
            scaler.scale(err_sum).backward()
            scaler.step(optimizerH)
            scaler.step(optimizerR)
            # optimizerH.step()
            # optimizerR.step()
            # Updates the scale for next iteration.
            scaler.update()
            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # genereate a picture every resultPicFrequency steps
            if epoch % 1 == 0 and i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data,compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)
            # break
  
    else:
        for i, data in enumerate(train_loader, 0):
            data_time.update(time.time() - start_time)
            # Hnet.change(False)
            Hnet.zero_grad()
            Rnet.zero_grad()
            all_pics = data  # allpics contains cover images and secret images
            this_batch_size = int(all_pics.size()[0])  # get true batch size of this step

            # first half of images will become cover images, the rest are treated as secret images
            cover_img = all_pics[0:this_batch_size, :, :, :]  # batchsize,3,256,256
            secret_img,target=sample_secret_img(this_batch_size)

            # secret_img = all_pics[this_batch_size:this_batch_size * 2, :, :, :]

            cover_img = cover_img.to(device)
            secret_img = secret_img.to(device)
          
            with torch.no_grad():
                cover_img = Variable(cover_img)
                secret_imgv = Variable(secret_img)
            cover_imgv = JPEG(cover_img)
            # print(cover_imgv.shape, secret_imgv.shape)
            container = Hnet(cover_imgv, secret_imgv)  # put concat_image into H-net and get container image


            errH = criterion(container, cover_imgv)  # loss between cover and container
            Hlosses.update(errH.item(), this_batch_size)

            compress_img=JPEG(container)

            rev_secret_img = Rnet(compress_img)  # put concatenated image into R-net and get revealed secret image
            # print(rev_secret_img.shape, secret_imgv.shape)
            errR = criterion(rev_secret_img, secret_imgv)  # loss between secret image and revealed secret image
            Rlosses.update(errR.item(), this_batch_size)

            betaerrR_secret = opt.beta * errR
            err_sum = errH + betaerrR_secret
            SumLosses.update(err_sum.item(), this_batch_size)

            err_sum.backward()
 
            optimizerH.step()
            optimizerR.step()
            # Updates the scale for next iteration.
            batch_time.update(time.time() - start_time)
            start_time = time.time()

            log = '[%d/%d][%d/%d]\tLoss_H: %.4f Loss_R: %.4f Loss_sum: %.4f \tdatatime: %.4f \tbatchtime: %.4f' % (
                epoch, opt.niter, i, len(train_loader),
                Hlosses.val, Rlosses.val, SumLosses.val, data_time.val, batch_time.val)

            if i % opt.logFrequency == 0:
                print_log(log, logPath)
            else:
                print_log(log, logPath, console=False)

            # genereate a picture every resultPicFrequency steps
            if epoch % 1 == 0 and i % opt.resultPicFrequency == 0:
                save_result_pic(this_batch_size,
                                cover_img, container.data,compress_img.data,
                                secret_img, rev_secret_img.data,
                                epoch, i, opt.trainpics)

2.note

how to deal the issue.

  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    536 引用 • 672 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • Kotlin

    Kotlin 是一种在 Java 虚拟机上运行的静态类型编程语言,由 JetBrains 设计开发并开源。Kotlin 可以编译成 Java 字节码,也可以编译成 JavaScript,方便在没有 JVM 的设备上运行。在 Google I/O 2017 中,Google 宣布 Kotlin 成为 Android 官方开发语言。

    19 引用 • 33 回帖 • 27 关注
  • Webswing

    Webswing 是一个能将任何 Swing 应用通过纯 HTML5 运行在浏览器中的 Web 服务器,详细介绍请看 将 Java Swing 应用变成 Web 应用

    1 引用 • 15 回帖 • 635 关注
  • 微服务

    微服务架构是一种架构模式,它提倡将单一应用划分成一组小的服务。服务之间互相协调,互相配合,为用户提供最终价值。每个服务运行在独立的进程中。服务于服务之间才用轻量级的通信机制互相沟通。每个服务都围绕着具体业务构建,能够被独立的部署。

    96 引用 • 155 回帖
  • React

    React 是 Facebook 开源的一个用于构建 UI 的 JavaScript 库。

    192 引用 • 291 回帖 • 443 关注
  • Android

    Android 是一种以 Linux 为基础的开放源码操作系统,主要使用于便携设备。2005 年由 Google 收购注资,并拉拢多家制造商组成开放手机联盟开发改良,逐渐扩展到到平板电脑及其他领域上。

    333 引用 • 323 回帖 • 66 关注
  • 电影

    这是一个不能说的秘密。

    120 引用 • 597 回帖 • 2 关注
  • webpack

    webpack 是一个用于前端开发的模块加载器和打包工具,它能把各种资源,例如 JS、CSS(less/sass)、图片等都作为模块来使用和处理。

    41 引用 • 130 回帖 • 295 关注
  • 支付宝

    支付宝是全球领先的独立第三方支付平台,致力于为广大用户提供安全快速的电子支付/网上支付/安全支付/手机支付体验,及转账收款/水电煤缴费/信用卡还款/AA 收款等生活服务应用。

    29 引用 • 347 回帖
  • Typecho

    Typecho 是一款博客程序,它在 GPLv2 许可证下发行,基于 PHP 构建,可以运行在各种平台上,支持多种数据库(MySQL、PostgreSQL、SQLite)。

    12 引用 • 60 回帖 • 464 关注
  • Wide

    Wide 是一款基于 Web 的 Go 语言 IDE。通过浏览器就可以进行 Go 开发,并有代码自动完成、查看表达式、编译反馈、Lint、实时结果输出等功能。

    欢迎访问我们运维的实例: https://wide.b3log.org

    30 引用 • 218 回帖 • 605 关注
  • GitHub

    GitHub 于 2008 年上线,目前,除了 Git 代码仓库托管及基本的 Web 管理界面以外,还提供了订阅、讨论组、文本渲染、在线文件编辑器、协作图谱(报表)、代码片段分享(Gist)等功能。正因为这些功能所提供的便利,又经过长期的积累,GitHub 的用户活跃度很高,在开源世界里享有深远的声望,并形成了社交化编程文化(Social Coding)。

    207 引用 • 2031 回帖
  • SMTP

    SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。SMTP 协议属于 TCP/IP 协议簇,它帮助每台计算机在发送或中转信件时找到下一个目的地。

    4 引用 • 18 回帖 • 589 关注
  • SendCloud

    SendCloud 由搜狐武汉研发中心孵化的项目,是致力于为开发者提供高质量的触发邮件服务的云端邮件发送平台,为开发者提供便利的 API 接口来调用服务,让邮件准确迅速到达用户收件箱并获得强大的追踪数据。

    2 引用 • 8 回帖 • 439 关注
  • 百度

    百度(Nasdaq:BIDU)是全球最大的中文搜索引擎、最大的中文网站。2000 年 1 月由李彦宏创立于北京中关村,致力于向人们提供“简单,可依赖”的信息获取方式。“百度”二字源于中国宋朝词人辛弃疾的《青玉案·元夕》词句“众里寻他千百度”,象征着百度对中文信息检索技术的执著追求。

    63 引用 • 785 回帖 • 251 关注
  • BAE

    百度应用引擎(Baidu App Engine)提供了 PHP、Java、Python 的执行环境,以及云存储、消息服务、云数据库等全面的云服务。它可以让开发者实现自动地部署和管理应用,并且提供动态扩容和负载均衡的运行环境,让开发者不用考虑高成本的运维工作,只需专注于业务逻辑,大大降低了开发者学习和迁移的成本。

    19 引用 • 75 回帖 • 619 关注
  • DNSPod

    DNSPod 建立于 2006 年 3 月份,是一款免费智能 DNS 产品。 DNSPod 可以为同时有电信、网通、教育网服务器的网站提供智能的解析,让电信用户访问电信的服务器,网通的用户访问网通的服务器,教育网的用户访问教育网的服务器,达到互联互通的效果。

    6 引用 • 26 回帖 • 521 关注
  • 微软

    微软是一家美国跨国科技公司,也是世界 PC 软件开发的先导,由比尔·盖茨与保罗·艾伦创办于 1975 年,公司总部设立在华盛顿州的雷德蒙德(Redmond,邻近西雅图)。以研发、制造、授权和提供广泛的电脑软件服务业务为主。

    8 引用 • 44 回帖
  • Ruby

    Ruby 是一种开源的面向对象程序设计的服务器端脚本语言,在 20 世纪 90 年代中期由日本的松本行弘(まつもとゆきひろ/Yukihiro Matsumoto)设计并开发。在 Ruby 社区,松本也被称为马茨(Matz)。

    7 引用 • 31 回帖 • 175 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    4 引用 • 55 回帖 • 9 关注
  • 程序员

    程序员是从事程序开发、程序维护的专业人员。

    533 引用 • 3528 回帖
  • JWT

    JWT(JSON Web Token)是一种用于双方之间传递信息的简洁的、安全的表述性声明规范。JWT 作为一个开放的标准(RFC 7519),定义了一种简洁的,自包含的方法用于通信双方之间以 JSON 的形式安全的传递信息。

    20 引用 • 15 回帖 • 20 关注
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 511 关注
  • wolai

    我来 wolai:不仅仅是未来的云端笔记!

    1 引用 • 11 回帖 • 2 关注
  • FlowUs

    FlowUs.息流 个人及团队的新一代生产力工具。

    让复杂的信息管理更轻松、自由、充满创意。

    1 引用
  • Flume

    Flume 是一套分布式的、可靠的,可用于有效地收集、聚合和搬运大量日志数据的服务架构。

    9 引用 • 6 回帖 • 594 关注
  • 宕机

    宕机,多指一些网站、游戏、网络应用等服务器一种区别于正常运行的状态,也叫“Down 机”、“当机”或“死机”。宕机状态不仅仅是指服务器“挂掉了”、“死机了”状态,也包括服务器假死、停用、关闭等一些原因而导致出现的不能够正常运行的状态。

    13 引用 • 82 回帖 • 38 关注
  • SSL

    SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网络通信提供安全及数据完整性的一种安全协议。TLS 与 SSL 在传输层对网络连接进行加密。

    69 引用 • 190 回帖 • 496 关注