stable_diffusion代码运行过程

加载模型以及参数

加载参数

首先在Main函数的最开始,新建argparse对象parser,向parser中输入参数以及模型信息,再将这些信息转化为opt

	arser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    opt = parser.parse_args()
    config = OmegaConf.load(f"{opt.config}")

下面是debug后config的值:

{'model': {'base_learning_rate': 0.0001, 
'target': 'ldm.models.diffusion.ddpm.LatentDiffusion',
 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 200, 'timesteps': 1000, 'first_stage_key': 'jpg', 'cond_stage_key': 'txt', 'image_size': 64, 'channels': 4, 'cond_stage_trainable': False, 'conditioning_key': 'crossattn', 'monitor': 'val/loss_simple_ema', 'scale_factor': 0.18215, 'use_ema': False, 
 'personalization_config': {'target': 'ldm.modules.embedding_manager.EmbeddingManager', 'params': {'placeholder_strings': ['*'], 'initializer_words': ['sculpture'], 
 'per_image_tokens': False, 'num_vectors_per_token': 1, 'progressive_words': False}}, 
 'unet_config': {'target': 'ldm.modules.diffusionmodules.openaimodel.UNetModel', 
 'params': {'image_size': 32, 'in_channels': 4, 'out_channels': 4, 'model_channels': 320, 'attention_resolutions': [4, 2, 1], 'num_res_blocks': 2, 'channel_mult': [1, 2, 4, 4], 'num_heads': 8, 'use_spatial_transformer': True, 'transformer_depth': 1, 'context_dim': 768, 'use_checkpoint': True, 'legacy': False}}, 
 'first_stage_config': {'target': 'ldm.models.autoencoder.AutoencoderKL',
  'params': {'embed_dim': 4, 'monitor': 'val/rec_loss', 'ddconfig': {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}, '
  lossconfig': {'target': 'torch.nn.Identity'}}}, 
  'cond_stage_config': {'target': 'ldm.modules.encoders.modules.FrozenCLIPEmbedder'}}}}

我们可以看到,config主要记录了LatentDiffusion的超参数以及模型参数,其中具体包括unet参数,first_stage参数,cond_stage参数,这三者可以认为是LatentDiffusion三个阶段,而这些参数都是parser从yaml文件中(configs/stable-diffusion/v1-inference.yaml)读取的:

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ["sculpture"]
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False
        
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

加载模型

之后使用下面的代码加载模型,其中调用关系复杂:

model = load_model_from_config(config, f"{opt.ckpt}")

首先会调用load_model_from_config函数,该函数接受config和ckpt,从ckpt文件加载模型的状态字典,并将其加载到根据config文件创建的模型中,还打印了可能存在的缺失或意外的键

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    #从ckpt文件加载一个字典,其中包含模型的状态字典state_dict
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    #从配置文件config中实例化一个模型
    model = instantiate_from_config(config.model)
    #使用load_state_dict()方法将模型状态字典sd加载到模型中
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model

其中instantiate_from_config函数的具体作用是检查config字典中是否存在名为"target"的键,如果不存在,将检查config是否等于’is_first_stage’或’is_unconditional’,如果"target"键存在于config中,将使用get_obj_from_str()函数根据config[“target”]的值实例化一个对象:

def instantiate_from_config(config, **kwargs):
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)

其中涉及get_obj_from_str函数,它根据给定的字符串string实例化一个对象,在debug过程中,string输入为’ldm.models.diffusion.ddpm.LatentDiffusion’,也就是实例化ldm:

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    #获取实例属性
    return getattr(importlib.import_module(module, package=None), cls)

由于string输入为’ldm.models.diffusion.ddpm.LatentDiffusion’,module就是’ldm.models.diffusion.ddpm’,就是cls就是LatentDiffusion,也就是说要去’ldm.models.diffusion.ddpm这个类去使用getattr函数取出LatentDiffusion的类属性。那么接下来就是加载ddpm代码:
进入DDPM类后,先进行各种初始化,在其中self.model = DiffusionWrapper(unet_config, conditioning_key)这段代码调用DiffusionWrapper类初始化模型:

传入的参数是unet_config,和conditioning_key

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        #根据传入的unet参数进行实例化
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

进入openaimodel.py文件中的unet类去实例化unet对象,其中unet就是先计算time_embed,再下采样(由ResBlock、AttentionBlock和TimestepEmbedSequential组成),中间层,上采样
接下来ldm函数中的super().__init__(conditioning_key=conditioning_key, *args, **kwargs)会跳转到ddpm类,通过register_schedule计算所需参数

接下来就是处理第一阶段模型,使用self.instantiate_first_stage(first_stage_config)初始化模型参数,其中方法和上述差不多,在进入autoencoder.py中的AutoencoderKL类后,首先初始化encoder和decoder,之后加载必要参数:

    def instantiate_first_stage(self, config):
    	#根据config字典来实例化一个模型(model)
        model = instantiate_from_config(config)
        #将实例化后的模型赋值给self.first_stage_model
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        #禁用参数的梯度计算
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

在之后进行初始化cond_stage模型,还是调用self.instantiate_cond_stage(cond_stage_config),这次到get_obj_from_str方法时的string=‘ldm.modules.encoders.modules.FrozenCLIPEmbedder’,也就是要去’FrozenCLIPEmbedder’中提取属性
在FrozenCLIPEmbedder类中我们可以设置预训练的参数,

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length

进行到这里以及完成了ldm模型的加载,下面进行加载采样器:

sampler = DDIMSampler(model)

class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

接下来将空提示词加载为条件向量:

if opt.scale != 1.0:
	uc = model.get_learned_conditioning(batch_size * [""])

这条命令执行会调用以下函数:

	#其中c=['', '', '']
    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
        	#有encode
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
            	#将c encode一下
                c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

接下来对于encode会调用

    def encode(self, text, **kwargs):
        return self(text, **kwargs)

转到clip的前向传播函数:

    def forward(self, text, **kwargs):
    	#按照self.max_length的大小对['', '', '']进行编码
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        #取出token
        tokens = batch_encoding["input_ids"].to(self.device)
        #调用transformer的前向函数        
        z = self.transformer(input_ids=tokens, **kwargs)

        return z

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/570126.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter插件技术:性能测试中服务端资源监控

性能测试过程中我们需要不断的监测服务端资源的使用情况,例如CPU、内存、I/O等。 Jmeter的插件技术可以很好的实时监控到服务器资源的运行情况,并以图形化的方式展示出来,非常方便我们性能测试分析。 操作步骤: 1、安装插件管理…

主打国产算力 广州市通用人工智能公共算力中心项目签约

4月9日,第十届广州国际投资年会期间,企商在线(北京)数据技术股份有限公司与广州市增城区政府就“广州市通用人工智能公共算力中心”项目进行签约。 该项目由广州市增城区人民政府发起,企商在线承建。项目拟建成中国最…

后端工程师——Java工程师如何准备面试

在国内,Java 程序员是后端开发工程师中最大的一部分群体,其市场需求量也是居高不下,C++ 程序员也是热门岗位之一,此二者的比较也常是热点话题,例如新学者常困惑的问题之一 —— 后端开发学 Java 好还是学 C++ 好。读完本文后,我们可以从自身情况、未来的发展,岗位需求量…

【JVM系列】关于静态块、静态属性、构造块、构造方法的执行顺序

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Java算法 空间换时间(找重复)

一、算法示例 1、题目:题目:0-999的数组中,添加一个重复的元素,打乱后,找出这个重复元素 代码示例: package com.zw.study.algorithm; import java.util.*; public class XorTest {public static void mai…

Vue报错 Cannot read properties of undefined (reading ‘websiteDomains‘) 解决办法

浏览器控制台如下报错: Unchecked runtime.lastError: The message port closed before a response was received. Uncaught (in promise) TypeError: Cannot read properties of undefined (reading websiteDomains) at xl-content.js:1:100558 此问题困扰了…

可持续发展:制造铝制饮料罐要消耗多少资源?

铝制饮料罐是人们经常使用的日常用品,无论是在购物、午休还是在自动售货机前选择喝什么的时候,很少有人会想知道装他们喝的饮料的罐子到底是如何制成的,或者这些铝罐的原材料是如何进出的。 虽然有化学品和一些合金进入铝饮料罐制造过程或成为…

【VSCode调试技巧】Pytorch分布式训练调试

最近遇到个头疼的问题,对于单机多卡的训练脚本,不知道如何使用VSCode进行Debug。 解决方案: 1、找到控制分布式训练的启动脚本,在自己的虚拟环境的/lib/python3.9/site-packages/torch/distributed/launch.py中 2、配置launch.…

【Qt常用控件】—— 输入类控件

目录 1.1 Line Edit 1.2 Text Edit 1.3 Combo Box 1.4 Spin Box 1.5 Date Edit & Time Edit 1.6 Dial 1.7 Slider 1.1 Line Edit QLineEdit是Qt中的一个控件,用于 接收和显示单行文本输入。 核心属性 属性 说明 text 输⼊框中的⽂本 inputMask 输⼊…

Science Robotics 美国斯坦福大学研制了外行星洞穴探测机器人

月球和火星上的悬崖、洞穴和熔岩管已被确定为具有地质和天体生物学研究理想地点。由于其隔绝特性,这些洞穴提供了相对稳定的条件,可以促进矿物质沉淀和微生物生长。在火星上,这些古老的地下环境与火星表面可能适合居住时几乎没有变化&#xf…

JavaEE 初阶篇-深入了解网络通信相关的基本概念(三次握手建立连接、四次挥手断开连接)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 网络通信概述 1.1 基本的通信架构 2.0 网络通信三要素 3.0 网络通信三要素 - IP 地址 3.1 查询 IP 地址 3.2 IP 地址由谁供应? 3.3 IP 域名 3.4 IP 分…

大模型接口管理和分发系统One API

老苏就职于一家专注于音视频实时交互技术和智能算法的创新企业。公司通过提供全面的 SDK 和解决方案,助力用户轻松实现实时音视频通话和消息传递等功能。尽管公司网站上有详细的文档中心,但在实际开发中,仍面临大量咨询工作。 鉴于此&#x…

知识图谱嵌入领域的重要研究:编辑基于语言模型的知识图谱嵌入

今天,向大家介绍一篇在知识图谱嵌入领域具有重要意义的研究论文——Editing Language Model-based Knowledge Graph Embeddings。这项工作由浙江大学和腾讯公司的研究人员联合完成,为我们在动态更新知识图谱嵌入方面提供了新的视角和方法。 研究背景 在…

Linux安装MongoDB超详细

Linux端安装 我们从MonDB官网下载Linux端的安装包,建议下载4.0版本 打开虚拟机,在虚拟机上安装传输工具lrzsz,将下载好的.tgz包拖到虚拟机当中,拖到/usr/local/mongoDB目录下, [rootserver ~]# yum install -y lrzsz [rootser…

如何使用 Vercel 托管静态网站

今天向大家介绍 Vercel 托管静态网站的几种方式,不熟悉 Vercel 的伙伴可以看一下之前的文章:Vercel: 开发者免费的网站托管平台 Github 部署 打开 Vercel 登录界面,推荐使用 GitHub账号 授权登录。 来到控制台界面,点击 Add New …

四川古力未来科技抖音小店:科技新宠,购物新体验

在当下数字化、智能化的时代,电商平台如雨后春笋般涌现,其中不乏一些富有创新精神和实力雄厚的科技企业。四川古力未来科技有限公司就是其中的佼佼者,其抖音小店更是凭借其独特的魅力和优质的服务,赢得了广大消费者的青睐。 一、科…

6步教你APP广告高效变现,收益翻倍秘诀大揭秘!

移动应用广告变现最佳实践与策略指南 在移动应用市场中,广告变现已成为开发者和公司获取收益的重要途径。然而,如何在保证用户体验的同时,实现广告收入的最大化,成为了众多开发者和公司面临的挑战。本文将为您介绍一些最佳的实践…

Seal^_^【送书活动第2期】——《Flink入门与实战》

Seal^_^【送书活动第2期】——《Flink入门与实战》 一、参与方式二、本期推荐图书2.1 作者简介2.2 编辑推荐2.3 前 言2.4 本书特点2.5 内容简介2.6 本书适用读者2.7 书籍目录 三、正版购买 一、参与方式 评论:"掌握Flink,驭大数据,实战…

nginx配置https及wss

环境说明 服务器的是centos7 nginx版本nginx/1.20.1 springboot2.7.12 nginx安装教程点击这里 微信小程序wss配置 如果您的业务是开发微信小程序&#xff0c; 请先进行如下配置。 boot集成websocket maven <dependency><groupId>org.springframework.boot<…

APP UI自动化测试,思路全总结在这里了

首先想要说明一下&#xff0c;APP自动化测试可能很多公司不用&#xff0c;但也是大部分自动化测试工程师、高级测试工程师岗位招聘信息上要求的&#xff0c;所以为了更好的待遇&#xff0c;我们还是需要花时间去掌握的&#xff0c;毕竟谁也不会跟钱过不去。 接下来&#xff0c…
最新文章