# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
class PositionalEncoder(nn.Module):
    """Additive sinusoidal position encoding as described by Vaswani et. al.

    https://arxiv.org/abs/1706.03762

    For position `p` and channel pair `i` the angle is
    `p / base ** (2 * i / head_dim)`. The cosine of the angle is added to the
    even feature channel `2 * i` and the sine to the odd channel `2 * i + 1`.
    """

    def __init__(self, head_dim: int, base: Optional[float] = 10e4):
        """
        Args:
            head_dim (int): dimensionality of the key/query vectors
            base (float, *optional*): base used to create the positional encodings.
                Note that the default literal `10e4` evaluates to 100000.

        Raises:
            ValueError: if `head_dim` is odd.
        """
        super().__init__()
        self.base = base
        self.head_dim = head_dim
        if self.head_dim % 2 != 0:
            raise ValueError(f"`head_dim`:{self.head_dim} must be even")
        # One divisor per (cos, sin) channel pair -> shape (head_dim // 2,).
        div_term = self.base ** (torch.arange(0, head_dim, 2) / head_dim)
        # Non-persistent: the buffer is rebuilt from `base`/`head_dim` and is not
        # written to the state dict.
        self.register_buffer("div_term", tensor=div_term, persistent=False)

    def forward(
        self, x: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x (torch.FloatTensor): query/key of shape `(bsz, num_heads, seq_len,
                head_dim)`
            position_ids (torch.tensor, *optional*): specify the position ids, by
                default constructs 0-sequence_length. The ids are flattened, so
                they are shared by every batch element and head.
        Returns:
            `(torch.Tensor)` query/key tensors with standard additive positional
            encoding
        """
        seq_length = x.shape[-2]
        if position_ids is None:
            position_ids = torch.arange(0, seq_length, 1).to(x)
        # (num_positions, head_dim // 2) table of angles.
        angles = position_ids.view(-1, 1) / self.div_term.to(x)
        # Interleave cos/sin along the feature axis: channel 2i <- cos, 2i+1 <- sin.
        # Only a (num_positions, head_dim) table is materialised; broadcasting in
        # the addition below covers the batch and head axes, so no scratch tensor
        # of the full size of `x` is needed.
        pe = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).flatten(-2)
        # Keep the result in the dtype of `x` even if the ids had higher precision.
        return x + pe.to(x.dtype)
class IdentityEncoder(nn.Module):
    """Pass-through encoder: leaves the query/key tensors untouched."""

    def __init__(self, **kwargs):
        """Accept arbitrary keyword arguments and discard them."""
        super().__init__()

    def forward(self, x: torch.FloatTensor, **kwargs):
        """
        Return the input as is; extra keyword arguments are not used.

        Args:
            x `(torch.FloatTensor)`
        Return
            `(torch.FloatTensor)` the very same tensor object that was passed in
        """
        return x
class RotaryEncoder(nn.Module):
    """Rotary position encoding as described by Su et. al.

    https://arxiv.org/abs/2104.09864

    Channel `i` of the first half of the feature axis is paired with channel
    `i + head_dim // 2`; each pair is rotated by `position / base ** (2i / head_dim)`.
    """

    def __init__(self, head_dim: int, base: Optional[float] = 10e4):
        """
        Args:
            head_dim (int): feature dimension of the key/query vector
            base (float): base to be used to create the positional encodings.
                The default literal `10e4` evaluates to 100000.

        Raises:
            ValueError: if `head_dim` is odd.
        """
        super().__init__()
        self.base = base
        self.head_dim = head_dim
        if self.head_dim % 2 != 0:
            raise ValueError(f"`head_dim`:{self.head_dim} must be even")
        pair_exponents = torch.arange(0, head_dim, 2) / head_dim
        # Tile the exponents so both members of a rotated pair share one divisor.
        exponents = pair_exponents.repeat(2)
        self.register_buffer(
            "div_term", tensor=self.base**exponents, persistent=False
        )

    def rotate_half(self, x: torch.FloatTensor):
        """
        Swaps the two halves of the last axis and negates the former second half.

        Args:
            x (torch.FloatTensor): query/key tensors of shape `(bsz, num_heads,
                seq_len, head_dim)`
        Returns
            `(torch.Tensor)` query/key rotated tensors
        """
        split_at = x.shape[-1] // 2
        front = x[..., :split_at]
        back = x[..., split_at:]
        return torch.cat((-back, front), dim=-1)

    def forward(
        self, x: torch.FloatTensor, position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Applies Rotary Position Encoding to the query and key tensors.

        Args:
            x (`torch.Tensor`): query/key tensor of shape `(bsz, num_heads, seq_len,
                head_dim)`
            position_ids (`torch.Tensor`, *optional*):
                specify the position ids, by default constructs 0-seq_len. The ids
                are flattened, so they are shared by every batch element and head.
        Returns:
            `(torch.Tensor)` comprising the query/key tensors rotated using the
            Rotary Position Encoding.
        """
        seq_length = x.shape[-2]
        if position_ids is None:
            position_ids = torch.arange(0, seq_length, 1).to(x)
        # (num_positions, head_dim) rotation angles.
        angles = position_ids.view(-1, 1) / self.div_term.to(x)
        cos_part = x * angles.cos()
        sin_part = self.rotate_half(x) * angles.sin()
        return cos_part + sin_part
class FullAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Supports grouped-query attention: `num_key_value_heads` key/value heads are
    shared by `num_attention_heads` query heads.
    """

    # Adapted from https://github.com/huggingface/transformers/main/src/transformers/
    # models/phi3/modeling_phi3.py
    def __init__(self, config):
        """
        Args:
            config: mapping that is read for the keys `feature_space_dim`,
                `head_dim` (may be `None`), `num_attention_heads`,
                `num_key_value_heads`, `attention_dropout`, `pos_emb`
                (one of `"positional"`, `"rotary"`, `"none"`) and `pos_emb_base`.

        Raises:
            ValueError: if `head_dim` is `None` and `feature_space_dim` is not
                divisible by `num_attention_heads`, or if `num_attention_heads`
                is not divisible by `num_key_value_heads`.
        """
        super().__init__()
        self.config = config
        self.feature_space_dim = config["feature_space_dim"]
        head_dim = config["head_dim"]
        if head_dim is None:
            if (config["feature_space_dim"] % config["num_attention_heads"]) != 0:
                raise ValueError(
                    "If not providing head_dim, ensure `feature_space_dim` is "
                    "divisible by `num_attention_heads`"
                )
            head_dim = config["feature_space_dim"] // config["num_attention_heads"]
        self.head_dim = head_dim
        self.num_heads = config["num_attention_heads"]
        if (config["num_attention_heads"] % config["num_key_value_heads"]) != 0:
            raise ValueError(
                "`num_attention_heads` must be divisible by `num_key_value_heads`"
            )
        # Number of query heads served by each key/value head.
        self.num_key_value_groups = (
            config["num_attention_heads"] // config["num_key_value_heads"]
        )
        self.num_key_value_heads = config["num_key_value_heads"]
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config["attention_dropout"]
        # Width of the fused projection: queries + keys + values.
        op_size = config["num_attention_heads"] * self.head_dim + 2 * (
            config["num_key_value_heads"] * self.head_dim
        )
        self.o_proj = nn.Linear(
            config["num_attention_heads"] * self.head_dim,
            config["feature_space_dim"],
            bias=False,
        )
        # This single layer performs the query, key and value projections
        # The output is then split into key_states, query_states, and value_states
        # with the corresponding dimensions
        self.qkv_proj = nn.Linear(config["feature_space_dim"], op_size, bias=False)
        pos_emb = {
            "positional": PositionalEncoder,
            "rotary": RotaryEncoder,
            "none": IdentityEncoder,
        }
        self.pos_emb = pos_emb[config["pos_emb"]](
            head_dim=self.head_dim, base=config["pos_emb_base"]
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Computes the attention

        Args:
            hidden_states (`torch.FloatTensor`): query/key tensor of shape `(bsz,
                seq_len, feature_space_dim)`
            attention_mask (`torch.Tensor`, *optional*): additive mask that is
                summed onto the raw attention scores of shape `(bsz, num_heads,
                seq_len, seq_len)` before the softmax, so it must broadcast
                against that shape. Only the first `seq_len` entries of its last
                axis are used.
            position_ids (`torch.Tensor`, *optional*): specify the position ids, by
                default constructs 0-sequence_length
            output_attentions (bool) : return the attention weights, cannot be used
                within the NPE/NRE/NLE pipelines,
                use it for analyzing the embedding modules
        Returns:
            `(torch.Tensor, Optional[torch.Tensor])` the attention output of shape
            `(bsz, seq_len, feature_space_dim)` and the attention weights, which
            are `None` unless `output_attentions` is set.

        Raises:
            ValueError: if the attention output does not have the expected shape.
        """
        bsz, q_len, _ = hidden_states.size()
        qkv = self.qkv_proj(hidden_states)

        # Slice the fused projection into its query / key / value parts.
        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_key_value_heads * self.head_dim
        query_states = qkv[..., :q_dim]
        key_states = qkv[..., q_dim : q_dim + kv_dim]
        value_states = qkv[..., q_dim + kv_dim :]

        # (bsz, seq_len, heads * head_dim) -> (bsz, heads, seq_len, head_dim)
        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        query_states = self.pos_emb(query_states, position_ids=position_ids)
        key_states = self.pos_emb(key_states, position_ids=position_ids)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = key_states.repeat_interleave(
            dim=1, repeats=self.num_key_value_groups
        )
        value_states = value_states.repeat_interleave(
            dim=1, repeats=self.num_key_value_groups
        )

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            causal_mask = attention_mask[..., : key_states.shape[-2]]
            attn_weights += causal_mask

        # Normalise over the key axis; the cast keeps the weights in the dtype of
        # the values (no explicit fp32 upcast is performed here).
        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(
            value_states.dtype
        )
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                "`attn_output` should be of size "
                + f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                + f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        # Merge the heads back. The merged width is num_heads * head_dim, which is
        # what `o_proj` expects; it only equals `feature_space_dim` when `head_dim`
        # was derived from it, so it must not be hard-coded to `feature_space_dim`.
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights
class MLP(nn.Module):
    """Gated feed-forward layer which can be replaced by a custom implementation.

    f(x):
        R^{feature_space_dim} -> R^{intermediate_size}
        R^{intermediate_size} -> R^{feature_space_dim}
    """

    # Adapted from https://github.com/huggingface/transformers/main/src/
    # transformers/models/phi3/modeling_phi3.py
    def __init__(self, config):
        """
        Args:
            config: mapping that is read for the keys `feature_space_dim`,
                `intermediate_size` and `mlp_activation` (`"gelu"` or `"relu"`).

        Raises:
            ValueError: if `mlp_activation` is neither `"gelu"` nor `"relu"`.
        """
        super().__init__()
        self.config = config
        in_features = config["feature_space_dim"]
        hidden_features = config["intermediate_size"]
        # A single matrix produces both the gate and the value branch.
        self.gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=False)
        self.down_proj = nn.Linear(hidden_features, in_features, bias=False)

        activation = config["mlp_activation"]
        if activation not in ("gelu", "relu"):
            raise ValueError(
                "Unsupported activation function, currently supported: `gelu, relu`"
            )
        self.activation_fn = F.gelu if activation == "gelu" else F.relu

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """
        Args:
            hidden_states (torch.FloatTensor): output from the attention layer of
                shape `(batch_size, sequence_length, feature_space_dim)`
        Returns:
            `(torch.FloatTensor)` of the same shape as the input
        """
        # Project to 2 * intermediate_size and cut the result into two equal parts.
        projected = self.gate_up_proj(hidden_states)
        gate, value = projected.chunk(2, dim=-1)
        # The activated first half scales the second half element-wise.
        gated = value * self.activation_fn(gate)
        return self.down_proj(gated)
# Copied from https://github.com/huggingface/transformers/blob/main/src/
# transformers/models/phi3/modeling_phi3.py
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-feature scale.

    RMSNorm is equivalent to T5LayerNorm
    Variant of layer normalization https://arxiv.org/abs/1607.06450
    """

    def __init__(self, feature_space_dim, eps: float):
        """
        Args:
            feature_space_dim: size of the last axis of the inputs; one scale
                weight (initialised to 1) is learned per feature.
            eps (float): constant added to the mean square before the reciprocal
                square root is taken.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(feature_space_dim))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor):
        """
        RMS normalization with per dimension scaling (self.weight)

        Args:
            hidden_states (torch.FloatTensor): input of shape `(batch_size,
                sequence_length, feature_space_dim)`
        Returns:
            `(torch.FloatTensor)`
        """
        # Mean of the squared activations over the feature axis.
        mean_square = hidden_states.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inverse_rms
        return self.weight * normalized
class MoeBlock(nn.Module):
    """Mixture of experts implementation with full capacity (no dropped tokens).

    `num_local_experts` : specifies the total number of experts available
    `num_experts_per_tok`: number of experts each token is assigned to
    `router_jitter_noise` : noise to be added at training time before routing
    the tokens to experts
    """

    # Adapted from https://github.com/huggingface/transformers/blob/main/src/
    # transformers/models/mixtral/modeling_mixtral.py
    # https://arxiv.org/abs/2401.04088
    def __init__(self, config):
        """
        Args:
            config: mapping that is read for the keys `feature_space_dim`,
                `intermediate_size`, `num_local_experts`, `num_experts_per_tok`
                and `router_jitter_noise`; it is also handed to every expert.

        Raises:
            ValueError: if `num_experts_per_tok` exceeds `num_local_experts`.
        """
        super().__init__()
        self.hidden_dim = config["feature_space_dim"]
        self.ffn_dim = config["intermediate_size"]
        self.num_experts = config["num_local_experts"]
        # Each token is assigned independently to experts
        self.top_k = config["num_experts_per_tok"]
        if self.top_k > self.num_experts:
            raise ValueError(
                "Each token cannot be assigned to more that num_local_experts"
            )
        # gating function to determine the experts to be used for each token
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        # experts can be replaced by a custom implementation of feed forward layer
        self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        # Jitter parameters
        self.jitter_noise = config["router_jitter_noise"]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Route every token to its `top_k` experts and mix their outputs.

        The input tensor is never modified.

        Args:
            `hidden_states` : input of shape `(batch_size, sequence_length,
                hidden_dim)`
        Returns:
            `final_hidden_states` : output of shape `(batch_size, sequence_length,
                hidden_dim)`
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            jitter = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
            # Out-of-place on purpose: an in-place `*=` would overwrite the
            # caller's tensor and fails on leaf tensors that require grad.
            hidden_states = hidden_states * jitter
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        hidden_states = hidden_states.reshape(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # The routing probabilities are computed in float32.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        # Renormalise so the weights of the selected experts sum to one per token.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask of shape
        # (n_experts, top_k, n_tokens); it is used to look up which tokens were
        # routed to a given expert.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation
        # on each expert
        for expert_idx, expert_layer in enumerate(self.experts):
            # slot_idx: which of the top-k slots chose this expert,
            # token_idx: the row of the flattened token matrix.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.numel() == 0:
                # No token was routed to this expert.
                continue
            current_state = hidden_states[token_idx]
            # Scale the expert output by the routing weight of the matching slot.
            current_hidden_states = (
                expert_layer(current_state) * routing_weights[token_idx, slot_idx, None]
            )
            # Accumulate, since a token receives contributions from several experts.
            final_hidden_states.index_add_(
                0, token_idx, current_hidden_states.to(hidden_states.dtype)
            )

        return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
class TransformerBlock(nn.Module):
    """Transformer block containing an attention and feed-forward layer.

    It also contains 2 learnable layer-norms, each applied to the input of its
    sub-layer; both sub-layers are wrapped in a residual connection.
    """

    def __init__(self, config):
        """
        Args:
            config: mapping that is read for the keys `feature_space_dim`, `ffn`
                (`"mlp"` or `"moe"`) and `rms_norm_eps`; it is also handed to the
                attention and feed-forward sub-layers.
        """
        super().__init__()
        self.feature_space_dim = config["feature_space_dim"]
        self.self_attn = FullAttention(config=config)
        ffn = {
            "mlp": MLP,
            "moe": MoeBlock,
        }
        self.ffn = ffn[config["ffn"]](config)
        self.input_layernorm = RMSNorm(
            config["feature_space_dim"], eps=config["rms_norm_eps"]
        )
        self.post_attention_layernorm = RMSNorm(
            config["feature_space_dim"], eps=config["rms_norm_eps"]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch,
                seq_len, embed_dim)`
            attention_mask (`torch.Tensor`, *optional*):
                mask handed unchanged to the attention layer
            position_ids (`torch.Tensor`, *optional*):
                position ids handed unchanged to the attention layer
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention tensors
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored
        Returns:
            `(torch.Tensor, Optional[torch.Tensor])` the transformed hidden states
            and the attention weights (`None` unless `output_attentions` is set)
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        # `kwargs` are deliberately not passed on: the attention layer only takes
        # the four arguments below, so forwarding anything else would raise a
        # TypeError instead of being ignored.
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        if not output_attentions:
            self_attn_weights = None
        return hidden_states, self_attn_weights
class ViTEmbeddings(nn.Module):
    # Adapted from https://github.com/huggingface/transformers/blob/main/src/
    # transformers/models/vit/modeling_vit.py
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels,
    height, width)` into the initial `hidden_states` (patch embeddings)
    of shape `(batch_size, seq_length, feature_space_dim)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        """
        Args:
            config: mapping that is read for the keys `image_size`, `patch_size`,
                `num_channels`, `feature_space_dim` and `vit_dropout`.
                image_size (int) is the expected size of the square image
                patch_size (int) is the expected size of the square patch
        """
        super().__init__()
        image_size, patch_size = config["image_size"], config["patch_size"]
        num_channels, feature_space_dim = (
            config["num_channels"],
            config["feature_space_dim"],
        )
        num_patches = (image_size // patch_size) ** 2
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        # One learned position vector per patch plus one for the class token.
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, config["feature_space_dim"])
        )
        # Non-overlapping convolution: every patch becomes one feature vector.
        self.projection = nn.Conv2d(
            num_channels, feature_space_dim, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, config["feature_space_dim"]))
        self.dropout = nn.Dropout(config["vit_dropout"])

    def interpolate_pos_encoding(
        self, embeddings: torch.Tensor, height: int, width: int
    ) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings,
        to be able to use the model on higher resolution images.

        When the patch grid of the input equals the learned grid, the learned
        position embeddings are returned directly without resampling.

        Args:
            embeddings (torch.FloatTensor): embedding patches generated after
                applying the CNN filters on the input image; only the size of
                its last axis is used
            height (int): height of the input image
            width (int): width of the input image
        Returns:
            `(torch.Tensor)` position encodings of shape `(1, 1 + (height //
            patch_size) * (width // patch_size), feature_space_dim)`
        """
        # Adapted from:
        # https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac95
        # 2ab558447af1fa1365362a/vision_transformer.py
        # https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fb
        # adf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py
        num_positions = self.position_embeddings.shape[1] - 1
        # The learned grid is square by construction, see `__init__`.
        sqrt_num_positions = int(num_positions**0.5)
        new_height = height // self.patch_size
        new_width = width // self.patch_size

        # Fast path: nothing to resample at the native resolution.
        if new_height == sqrt_num_positions and new_width == sqrt_num_positions:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]

        # (1, N, dim) -> (1, dim, sqrt(N), sqrt(N)) so the grid can be resized
        # like an image.
        patch_pos_embed = patch_pos_embed.reshape(
            1, sqrt_num_positions, sqrt_num_positions, dim
        )
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )
        # Back to a flat sequence of patches; the class-token position is kept
        # as learned and is not interpolated.
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """
        Images of shape `(batch_size, num_channels, height, width)`
        Return input_embeddings of shape `(batch_size, seq_length, feature_space_dim)`
        to be consumed by a Transformer.

        Raises:
            ValueError: if the channel dimension of `pixel_values` differs from
                the configured `num_channels`.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values "
                "match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        # (bsz, dim, H', W') -> (bsz, H' * W', dim)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        # Prepend the class token to every sequence of the batch.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)
        embeddings = embeddings + self.interpolate_pos_encoding(
            embeddings, height, width
        )
        embeddings = self.dropout(embeddings)
        return embeddings