diff --git a/bytelatent/base_transformer.py b/bytelatent/base_transformer.py index d947040..60604b5 100644 --- a/bytelatent/base_transformer.py +++ b/bytelatent/base_transformer.py @@ -6,18 +6,19 @@ from typing import Optional, Tuple, Union import torch + +from bytelatent.model.utils import DTYPE_MAP +from bytelatent.tokenizers.constants import EOS_ID from pydantic import BaseModel, ConfigDict from torch import nn from torch.nn import functional as F from torch.nn.attention.flex_attention import ( - BlockMask, _mask_mod_signature, + BlockMask, flex_attention, ) from xformers.ops import AttentionBias, fmha -from bytelatent.tokenizers.constants import EOS_ID - logger = logging.getLogger() try: @@ -68,6 +69,9 @@ class BaseTransformerArgs(BaseModel): # Special token config eos_id: int | None = EOS_ID + init_device: str = "cpu" + init_dtype: str = "fp32" + def cross_entropy(pred, target, **kwargs): return F.nll_loss( @@ -95,6 +99,7 @@ def precompute_freqs_cis( end: int, theta: float = 10000.0, rope_use_fp32_in_outer_product: bool = False, + device: str | torch.device = torch.device("cpu"), ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -111,7 +116,9 @@ def precompute_freqs_cis( Returns: torch.Tensor: Precomputed frequency tensor with complex exponentials. """ - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].float() / dim) + ) t = torch.arange(end, device=freqs.device) if rope_use_fp32_in_outer_product: t = t.to(torch.float32) @@ -258,6 +265,8 @@ def __init__( head_dim: int, max_seqlen: int = 1024, rope_use_fp32_in_outer_product: bool = False, + device: str | torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, ): super().__init__() @@ -273,7 +282,8 @@ def __init__( end=max_seqlen, theta=theta, rope_use_fp32_in_outer_product=self.rope_use_fp32_in_outer_product, - ), + device=device, + ).to(dtype=dtype), persistent=False, ) @@ -325,6 +335,8 @@ def __init__( n_heads: int, n_kv_heads: int, rope_theta: float, + device: str | torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, ): super().__init__() @@ -340,22 +352,30 @@ def __init__( dim, n_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wk = nn.Linear( dim, n_kv_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wv = nn.Linear( dim, n_kv_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wo = nn.Linear( n_heads * head_dim, dim, bias=False, + device=device, + dtype=dtype, ) def forward( @@ -368,6 +388,7 @@ def forward( ) -> torch.Tensor: # B S D bsz, seq_len, dim = x.shape + xq = self.wq(x.view_as(x)) xk = self.wk(x.view_as(x)) xv = self.wv(x.view_as(x)) @@ -453,6 +474,8 @@ def __init__( multiple_of: int, ffn_dim_multiplier: Optional[float], mp_size: int = 1, + device: str | torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, ): super().__init__() @@ -469,16 +492,22 @@ def __init__( dim, hidden_dim, bias=False, + device=device, + dtype=dtype, ) self.w3 = nn.Linear( dim, hidden_dim, bias=False, + device=device, + dtype=dtype, ) self.w2 = nn.Linear( hidden_dim, dim, bias=False, + device=device, + dtype=dtype, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -535,15 +564,30 @@ def __init__(self, args: BaseTransformerArgs): n_heads=self.n_heads, n_kv_heads=self.n_kv_heads, rope_theta=args.rope_theta, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of, ffn_dim_multiplier=args.ffn_dim_multiplier, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) + # Norms stay in full precision + self.attention_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward( self, @@ -593,6 +637,8 @@ def __init__(self, args: BaseTransformerArgs): head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) self.eos_id = args.eos_id diff --git a/bytelatent/entropy_model.py b/bytelatent/entropy_model.py index 51973e2..e8a12b3 100644 --- a/bytelatent/entropy_model.py +++ b/bytelatent/entropy_model.py @@ -10,11 +10,12 @@ logger = logging.getLogger() -def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"): +def load_entropy_model( + entropy_model_checkpoint_dir, state_dict_path, device="cpu", dtype="bf16" +): with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr: reloaded = json.loads(fr.read()) - torch.set_default_dtype(torch.bfloat16) model_params = reloaded["entropy_model"] logger.warning( "Update checkpoint to load attn and sliding window args from checkpoint" @@ -29,6 +30,8 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp attn_bias_type="local_block_causal", attn_impl="xformers", sliding_window=512, + init_device=device, + init_dtype=dtype, ) entropy_model = LMTransformer(entropy_model_args) @@ -38,6 +41,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp entropy_model.to(device) entropy_model = entropy_model.eval() # no grads for the model: - for param in entropy_model.parameters(): + for n, param in entropy_model.named_parameters(): param.requires_grad = False + return entropy_model, entropy_model_args diff --git a/bytelatent/generate.py b/bytelatent/generate.py index 97434dc..71ddc52 100644 --- a/bytelatent/generate.py +++ b/bytelatent/generate.py @@ -4,11 +4,6 @@ import time import torch -from omegaconf import OmegaConf -from torch import nn -from torch.nn import functional as F -from torch.nn.attention.flex_attention import create_block_mask -from tqdm import tqdm from bytelatent.args import EvalArgs, PackedCausalTransformerGeneratorArgs, TrainArgs from bytelatent.base_transformer import ( @@ -19,9 +14,9 @@ lengths_to_start_ids, ) from bytelatent.checkpoint import ( + consolidate_checkpoints, CONSOLIDATE_FOLDER, CONSOLIDATE_NAME, - consolidate_checkpoints, ) from bytelatent.config_parser import parse_args_to_pydantic_model from bytelatent.data.file_util import get_fs @@ -33,6 +28,11 @@ from bytelatent.model.blt import ByteLatentTransformer from bytelatent.tokenizers.abstract_tokenizer import Tokenizer from bytelatent.transformer import LMTransformer +from omegaconf import OmegaConf +from torch import nn +from torch.nn import functional as F +from torch.nn.attention.flex_attention import create_block_mask +from tqdm import tqdm def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: @@ -400,25 +400,29 @@ def load_consolidated_model_and_tokenizer(consolidated_path, init_distributed=Fa setup_torch_distributed(distributed_args) train_args_path = os.path.join(consolidated_path, "params.json") fs = get_fs(train_args_path) + train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path)) if train_args.train_entropy_model: model_args = train_args.entropy_model + model_args.init_device = "cuda" + model_args.init_dtype = train_args.distributed.model_dtype model = LMTransformer(model_args) else: model_args = train_args.model - model = ByteLatentTransformer(model_args) + model_args.init_device = "cuda" + model_args.init_dtype = train_args.distributed.model_dtype + model = ByteLatentTransformer(args=model_args) + + model = model.eval() - param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[ - train_args.distributed.model_dtype - ] tokenizer = train_args.data.tokenizer_args.build() - with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f: - st_dict = torch.load(f, weights_only=True) + + with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as fp: + st_dict = torch.load(fp, weights_only=True) + model.load_state_dict(st_dict["model"]) - model = model.cuda().eval() - for param in model.parameters(): - param.data = param.data.to(dtype=param_dtype) + return model, tokenizer, train_args diff --git a/bytelatent/model/blt.py b/bytelatent/model/blt.py index 26934bb..ede27e3 100644 --- a/bytelatent/model/blt.py +++ b/bytelatent/model/blt.py @@ -1,14 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -from enum import Enum, auto +from enum import auto, Enum from typing import Any, Optional import torch -from huggingface_hub import PyTorchModelHubMixin -from pydantic import model_validator -from torch import nn -from torch.nn.attention.flex_attention import create_block_mask -from typing_extensions import Self from bytelatent.base_transformer import ( BaseTransformerArgs, @@ -18,8 +13,15 @@ from bytelatent.data.patcher import Patcher, PatcherArgs from bytelatent.model.latent_transformer import GlobalTransformer from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs -from bytelatent.model.utils import downsample +from bytelatent.model.utils import check_param_device, downsample, DTYPE_MAP from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID +from huggingface_hub import PyTorchModelHubMixin + +from numpy.random import f +from pydantic import model_validator +from torch import nn +from torch.nn.attention.flex_attention import create_block_mask +from typing_extensions import Self def attention_flops_per_token(n_layers, seq_len, dim, causal): @@ -155,6 +157,9 @@ def decoder_patch_ids_from_lengths(patch_lengths, nb_boe, seq_len): def rolling_polynomial_hash(t, hash_func_nb: int = 0): + if hash_func_nb >= len(primes): + print(f"len(primes): {len(primes)}, hash_func_nb: {hash_func_nb}") + prime = torch.tensor(primes[hash_func_nb], dtype=torch.int64, device=t.device) prime_powers = torch.stack([prime**i for i in range(t.shape[-1])]) return torch.sum(t * prime_powers, dim=-1) @@ -239,6 +244,9 @@ def create_patch_mask_from_ids( return mask +GLOBAL = set() + + def cross_attn_mask( patch_ids, patch_lengths, @@ -265,9 +273,12 @@ def cross_attn_mask( kv_len, ), f"{cross_mask.shape} != {(bs, q_len, kv_len)}" if block_mask: + # This appears to resolve occasional nondeterministic RuntimeErrors + # in the create_block_mask call. I have no idea why. + cross_mask_copy = cross_mask.clone() def patch_mask(b, h, q_idx, kv_idx): - return cross_mask[b, q_idx, kv_idx] + return cross_mask_copy[b, q_idx, kv_idx] block_mask = create_block_mask( patch_mask, @@ -277,6 +288,7 @@ def patch_mask(b, h, q_idx, kv_idx): KV_LEN=kv_len, _compile=True, ) + return block_mask else: return torch.where( @@ -632,6 +644,8 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder: cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, cross_attn_nheads=args.cross_attn_nheads, eos_id=args.eos_id, + init_device=args.init_device, + init_dtype=args.init_dtype, ) return LocalEncoder(local_encoder_args) @@ -675,6 +689,8 @@ def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder: cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder, cross_attn_nheads=args.cross_attn_nheads, eos_id=args.eos_id, + init_device=args.init_device, + init_dtype=args.init_dtype, ) return LocalDecoder(local_decoder_args) @@ -710,6 +726,8 @@ def init_embeddings( nn.Embedding( encoder_hash_byte_group_vocab, emb_dim, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) ) @@ -718,7 +736,14 @@ def init_embeddings( emb_dim = local_encoder_dim OFFSET = 4 # This should be passed as parameter if it's variable for ngram_vocab_size in encoder_ngram_to_size.values(): - embeddings.append(nn.Embedding(ngram_vocab_size + OFFSET, emb_dim)) + embeddings.append( + nn.Embedding( + ngram_vocab_size + OFFSET, + emb_dim, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) + ) return nn.ModuleList(embeddings) @@ -792,7 +817,7 @@ class ByteLatentTransformer( """ def __init__(self, args: ByteLatentTransformerArgs): - super().__init__() + super(ByteLatentTransformer, self).__init__() # General configuration self.weight_tying = args.weight_tying @@ -854,7 +879,12 @@ def __init__(self, args: ByteLatentTransformerArgs): ngram_emb_dim = self.local_encoder.dim for ngram_vocab_size in self.encoder_ngram_to_size.values(): self.encoder_ngram_embedding.append( - nn.Embedding(ngram_vocab_size + OFFSET, ngram_emb_dim) + nn.Embedding( + ngram_vocab_size + OFFSET, + ngram_emb_dim, + device=args.init_device, + dtype=dtype_map[args.init_dtype], + ) ) # Output layer @@ -873,6 +903,9 @@ def __init__(self, args: ByteLatentTransformerArgs): ) ) + # Sanity check + check_param_device(self, args.init_device) + def push_to_hub(self, *args, **kwargs): raise ValueError( "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct." diff --git a/bytelatent/model/latent_transformer.py b/bytelatent/model/latent_transformer.py index a6cabdc..55507c1 100644 --- a/bytelatent/model/latent_transformer.py +++ b/bytelatent/model/latent_transformer.py @@ -5,9 +5,6 @@ import torch import torch.nn import torch.nn as nn -from torch.nn import functional as F -from torch.nn.attention.flex_attention import BlockMask -from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformer, @@ -15,7 +12,10 @@ flex_attention_comp, repeat_kv, ) -from bytelatent.model.utils import create_causal_mask +from bytelatent.model.utils import create_causal_mask, DTYPE_MAP +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias logger = logging.getLogger() try: @@ -40,6 +40,8 @@ def __init__( n_heads: int, n_kv_heads: int, norm_eps: float, + device: str | torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, ): super().__init__() @@ -50,29 +52,39 @@ def __init__( self.n_kv_heads = n_kv_heads self.heads_per_group = self.n_heads // self.n_kv_heads - self.cross_attn_norm_q = nn.RMSNorm(dim, eps=norm_eps) - self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps) + self.cross_attn_norm_q = nn.RMSNorm( + dim, eps=norm_eps, device=device, dtype=dtype + ) + self.cross_attn_norm_kv = RMSNorm(dim, eps=norm_eps, device=device, dtype=dtype) self.wq = nn.Linear( dim, n_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wk = nn.Linear( dim, n_kv_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wv = nn.Linear( dim, n_kv_heads * head_dim, bias=False, + device=device, + dtype=dtype, ) self.wo = nn.Linear( n_heads * head_dim, dim, bias=False, + device=device, + dtype=dtype, ) def forward( @@ -160,6 +172,8 @@ def __init__(self, args: BaseTransformerArgs): args.dim_token_emb, args.dim, bias=False, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) def forward( diff --git a/bytelatent/model/local_models.py b/bytelatent/model/local_models.py index 7083ac4..5b81bdc 100644 --- a/bytelatent/model/local_models.py +++ b/bytelatent/model/local_models.py @@ -6,10 +6,6 @@ import torch import torch.nn import torch.nn as nn -from pydantic import ConfigDict -from torch.nn import functional as F -from torch.nn.attention.flex_attention import BlockMask -from xformers.ops import AttentionBias from bytelatent.base_transformer import ( BaseTransformerArgs, @@ -18,8 +14,12 @@ TransformerBlock, ) from bytelatent.model.latent_transformer import CrossAttention -from bytelatent.model.utils import create_causal_mask, downsample +from bytelatent.model.utils import create_causal_mask, downsample, DTYPE_MAP from bytelatent.tokenizers.blt_tokenizer import BOE_ID +from pydantic import ConfigDict +from torch.nn import functional as F +from torch.nn.attention.flex_attention import BlockMask +from xformers.ops import AttentionBias logger = logging.getLogger() try: @@ -85,18 +85,31 @@ def __init__(self, args: LocalModelArgs): ) if not self.use_rope: - self.pos_embeddings = nn.Embedding(args.max_length, args.dim) + self.pos_embeddings = nn.Embedding( + args.max_length, + args.dim, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) else: self.rope = RotaryEmbedding( theta=args.rope_theta, head_dim=args.head_dim or args.dim // args.n_heads, max_seqlen=args.max_seqlen, rope_use_fp32_in_outer_product=args.rope_use_fp32_in_outer_product, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) self.pos_embeddings = None self.token_embedding_projection = ( - nn.Linear(args.dim_token_emb, args.dim, bias=False) + nn.Linear( + args.dim_token_emb, + args.dim, + bias=False, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) if hasattr(args, "dim_token_emb") and args.dim_token_emb != self.dim else None ) @@ -125,6 +138,8 @@ def _create_patch_projection(self, args): in_features=args.dim_patch_emb, out_features=output_dim, bias=False, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) def apply_embedding(self, tokens, embeds): @@ -218,7 +233,12 @@ def __init__(self, args: LocalModelArgs): self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling self.cross_attn_nheads = args.cross_attn_nheads - self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) + self.tok_embeddings = nn.Embedding( + self.vocab_size, + args.dim, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) if self.cross_attn_encoder: self.cross_attn_layers = torch.nn.ModuleList() @@ -231,6 +251,8 @@ def __init__(self, args: LocalModelArgs): n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, norm_eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) ) @@ -321,7 +343,12 @@ def __init__(self, args: LocalModelArgs): self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling self.cross_attn_nheads = args.cross_attn_nheads - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = RMSNorm( + args.dim, + eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) if self.cross_attn_decoder: self.cross_attn_layers = torch.nn.ModuleList() @@ -334,6 +361,8 @@ def __init__(self, args: LocalModelArgs): n_heads=self.cross_attn_nheads, n_kv_heads=self.cross_attn_nheads, norm_eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) ) @@ -341,6 +370,8 @@ def __init__(self, args: LocalModelArgs): self.dim, args.vocab_size, bias=False, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) def forward( diff --git a/bytelatent/model/utils.py b/bytelatent/model/utils.py index e01672e..ebb4edf 100644 --- a/bytelatent/model/utils.py +++ b/bytelatent/model/utils.py @@ -8,6 +8,13 @@ logger = logging.getLogger() +DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + "fp64": torch.float64, +} + def patch_reduce(h, max_num_patches, reduction, patch_ids): """ @@ -175,3 +182,10 @@ def create_causal_mask( raise NotImplementedError( f"Attention {attn_impl} with {sliding_window} sliding window not implemented" ) + + +def check_param_device(model, device_type: str = "cpu"): + for name, param in model.named_parameters(): + assert ( + param.device.type == device_type + ), f"Parameter {name} is on {param.device.type}, not on {device_type}" diff --git a/bytelatent/transformer.py b/bytelatent/transformer.py index 32d63be..86ec17c 100644 --- a/bytelatent/transformer.py +++ b/bytelatent/transformer.py @@ -4,26 +4,26 @@ from typing import Optional, Tuple, Union import torch + +from bytelatent.base_transformer import ( + BaseTransformer, + BaseTransformerArgs, + cross_entropy, +) +from bytelatent.model.utils import check_param_device, create_causal_mask, DTYPE_MAP from huggingface_hub import PyTorchModelHubMixin from torch import nn from torch.distributed._tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, + parallelize_module, PrepareModuleInput, RowwiseParallel, SequenceParallel, - parallelize_module, ) from torch.nn.attention.flex_attention import BlockMask, create_block_mask from xformers.ops import AttentionBias -from bytelatent.base_transformer import ( - BaseTransformer, - BaseTransformerArgs, - cross_entropy, -) -from bytelatent.model.utils import create_causal_mask - logger = logging.getLogger() try: @@ -84,19 +84,34 @@ def __init__(self, args: LMTransformerArgs): assert args.vocab_size > 0 - self.tok_embeddings = torch.nn.Embedding(args.vocab_size, args.dim) + self.tok_embeddings = torch.nn.Embedding( + args.vocab_size, + args.dim, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) - self.norm = RMSNorm(args.dim, eps=args.norm_eps) + self.norm = RMSNorm( + args.dim, + eps=args.norm_eps, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], + ) self.output = nn.Linear( args.dim, args.vocab_size, bias=False, + device=args.init_device, + dtype=DTYPE_MAP[args.init_dtype], ) if args.weight_tying: self.output.weight = self.embeddings.tok_embeddings.weight + # Sanity check + check_param_device(self, args.init_device) + def push_to_hub(self, *args, **kwargs): raise ValueError( "For meta authors: Do not push BLT weights with this, save weights with save_pretrained() then push them manually to HF hub to ensure the repository metadata is correct."