16. Large Language Models: An Introduction#

This chapter introduces large language models (LLMs). We will discuss tokenization strategies, model architecture, the attention mechanism, and dynamic embeddings.

Learning objectives

By the end of this chapter, you should be able to:

  • Explain the difference between pretraining and fine-tuning

  • Describe, in general terms, how LLMs work

  • Know where to find pretrained LLMs for specific tasks

  • Explain what subwords are and tokenize texts with a subword tokenizer

  • Generate embeddings for texts from an LLM

  • Explain embedding pooling and retrieve different embedding representations from a model

16.1. Preliminaries#

We need the following libraries:

import torch
from transformers import AutoTokenizer, AutoModel

16.1.1. Using a pretrained model#

Training LLMs requires vast amounts of data and computational resources. While these resources are expensive, the very scale of these models contributes to their ability to generalize. Practitioners will therefore use the same model for a variety of tasks. They do this by pretraining a general model to perform a foundational task, usually next-token prediction. Then, once that model is trained, practitioners fine-tune that model for other tasks. The fine-tuned variants benefit from the generalized language representations learned during pretraining, but they adapt those representations to more specific contexts and tasks.

The best place to find these pretrained models is Hugging Face. The company hosts thousands of them on its platform, and it also develops various machine learning tools for working with these models. Hugging Face also features fine-tuned models for various tasks, which may work out of the box for your needs. Take a look at the model listing to see all models on the platform. At the left, you’ll see categories for model types, task types, and more.

16.1.2. Loading a model#

To load a model from Hugging Face, first specify the checkpoint you’d like to use. Typically this is the name of the model, prefixed by the organization that released it.

checkpoint = "google-bert/bert-base-uncased"

The transformers library has different tokenizer and model classes for different models/architectures and tasks. You can write these out directly, or use the Auto classes, which dynamically determine what class you’ll need for a model and task. Below, we load the base BERT model without specifying a task.

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
bert = AutoModel.from_pretrained(checkpoint)

If you don’t have this model stored on your own computer, it will download directly from Hugging Face. The default directory for storing Hugging Face data is ~/.cache/huggingface. Set an HF_HOME environment variable from the command line if you want Hugging Face downloads to default to a different location on your computer.

export HF_HOME=/path/to/another/directory
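
You can also set this variable from Python, as long as you do so before importing transformers; a minimal sketch (the path is a placeholder):

import os

# Must be set before transformers is imported, or the default cache location is used
os.environ["HF_HOME"] = "/path/to/another/directory"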

16.2. Subword Tokenization#

Note that we have initialized a tokenizer and model from the same checkpoint. This is important: LLMs depend on specific tokenizers, which are themselves trained on corpus data before their corresponding models even see that data. But why do tokenizers need to be trained in the first place?

The answer has to do with the highly general nature of LLMs. These models are trained on huge corpora, which means they must represent millions of different pieces of text. But model vocabularies would quickly balloon to a huge size if they stored every one of these tokens, and doing so would be both inefficient and a waste of resources, since some tokens are extremely rare. In traditional tokenization and model building, you’d set a cutoff below which rare tokens could be ignored, but LLMs need all text. That means they need to represent every token in a corpus, without storing representations for every token in a corpus.

Model developers square this circle by using pieces of words, or subwords, to represent other tokens. That way, a model can literally spell out any text sequence it needs to build without having representations for every unique token in its training corpus. (This also means LLMs can handle text they’ve never seen before.) Deciding which tokens should be represented in full and which are best built from subwords requires training a tokenizer: it learns the token distribution of a corpus, builds subwords, and sets the cutoff accordingly.

With subword tokenization, the following phrase:

large language models use subword tokenization

…becomes:

large language models use sub ##word token ##ization

See the hashes? This tokenizer prepends them to subwords that continue a preceding token.
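
You can reproduce this with the tokenizer’s .tokenize() method, which returns subword strings without converting them to IDs:

tokenizer.tokenize("large language models use subword tokenization")
['large', 'language', 'models', 'use', 'sub', '##word', 'token', '##ization']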

16.2.1. Input IDs#

The actual output of a transformers tokenizer has a few parts. We use the following sentence as an example:

sentence = "Then I tried to find some way of embracing my mother's ghost."

Send this to the tokenizer, setting the return type to PyTorch tensors. We also return the attention mask.

inputs = tokenizer(
    sentence, return_tensors = "pt", return_attention_mask = True
)

Input IDs are the unique identifiers for every token in the input text. These are what the model actually looks at.

inputs["input_ids"]
tensor([[  101,  2059,  1045,  2699,  2000,  2424,  2070,  2126,  1997, 23581,
          2026,  2388,  1005,  1055,  5745,  1012,   102]])
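
To see which token each ID stands for, use .convert_ids_to_tokens():

tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
['[CLS]', 'then', 'i', 'tried', 'to', 'find', 'some', 'way', 'of', 'embracing', 'my', 'mother', "'", 's', 'ghost', '.', '[SEP]']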

Use the .decode() method to transform an ID (or sequence of IDs) back into text.

tokenizer.decode(5745)
'ghost'

The tokenizer has entries for punctuation:

tokenizer.decode([1005, 1012])
"'."

Whitespace tokens are often removed, however:

ws = tokenizer(" \t\n")
ws["input_ids"]
[101, 102]

But if that’s the case, what are those two IDs? These are two special tokens that BERT uses for a few different tasks.

tokenizer.decode([101, 102])
'[CLS] [SEP]'

[CLS] is prepended to every input sequence. It marks the start of a sequence, and it also serves as a “summarization” token for sequences, a kind of aggregate representation of model outputs. When you train a model for classification tasks, the model uses [CLS] to decide how to categorize a sequence.

[SEP] is appended to every input sequence. It marks the end of a sequence, and it is used to separate input pairs for tasks like sentence similarity, question answering, and summarization. When training, a model looks to [SEP] to distinguish which parts of the input correspond to task components.
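
The tokenizer stores these special tokens as attributes, so you don’t need to hard-code them:

print(tokenizer.cls_token, tokenizer.sep_token)
[CLS] [SEP]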

16.2.2. Token type IDs#

Some models don’t need anything more than [CLS] and [SEP] to make the above distinctions. But other models also incorporate token type IDs to further distinguish individual pieces of input. These IDs are binary values that tell the model which parts of the input belong to what components in the task.

Our sentence makes no distinction:

inputs["token_type_ids"]
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

But a pair of sentences would:

question = "What did I do then?"
with_token_types = tokenizer(question, sentence)
with_token_types["token_type_ids"]
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
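
Decoding the paired input shows how the tokenizer joins the two segments, with a [SEP] between them:

tokenizer.decode(with_token_types["input_ids"])
"[CLS] what did i do then? [SEP] then i tried to find some way of embracing my mother's ghost. [SEP]"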

16.2.3. Attention mask#

A final output, the attention mask, tells the model which parts of the input it should use when it processes the sequence.

inputs["attention_mask"]
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

16.2.4. Padding and truncation#

It may seem redundant to add an attention mask, but tokenizers often pad input sequences. Transformer models process batches of sequences in parallel, which massively speeds up their run time, but this requires every sequence in a batch to be the same length. Texts, however, are rarely the same length, hence the padding.

two_sequence_inputs = tokenizer(
    [question, sentence],
    return_tensors = "pt",
    return_attention_mask = True,
    padding = "longest"
)
two_sequence_inputs["input_ids"]
tensor([[  101,  2054,  2106,  1045,  2079,  2059,  1029,   102,     0,     0,
             0,     0,     0,     0,     0,     0,     0],
        [  101,  2059,  1045,  2699,  2000,  2424,  2070,  2126,  1997, 23581,
          2026,  2388,  1005,  1055,  5745,  1012,   102]])

Token ID 0 is the [PAD] token.

tokenizer.decode(0)
'[PAD]'

Note the attention masks:

two_sequence_inputs["attention_mask"]
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

There are a few different strategies for padding. Above, we had the tokenizer pad to the longest sequence in the batch. But it’s often best to set padding to "max_length":

two_sequence_inputs = tokenizer(
    [question, sentence],
    return_tensors = "pt",
    return_attention_mask = True,
    padding = "max_length"
)
two_sequence_inputs["input_ids"][0]
tensor([ 101, 2054, 2106, 1045, 2079, 2059, 1029,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
         ...,
           0,    0,    0,    0,    0,    0,    0,    0])

This will pad the text out to the maximum number of tokens the model can process at once. This number is known as the context window.

print("Context window size:", tokenizer.model_max_length)
Context window size: 512

Warning

Not all tokenizers have this information stored in their configuration. You should always check that a tokenizer reports a meaningful value before relying on it. If it doesn’t, take a look at the model documentation.
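
When the configuration lacks a real limit, transformers typically reports a very large sentinel value (on the order of 1e30) instead. A minimal sketch of a sanity check:

# A heuristic: a value this large means no real context window was configured
if tokenizer.model_max_length > 1_000_000:
    print("No context window size set; check the model documentation")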

If your input exceeds the number above, you will need to truncate it; otherwise, the model will not process the input properly.

too_long = "a " * 10_000
too_long_inputs = tokenizer(
    too_long, return_tensors = "pt", return_attention_mask = True
)
Token indices sequence length is longer than the specified maximum sequence length for this model (10002 > 512). Running this sequence through the model will result in indexing errors

Set truncation to True to avoid this problem.

too_long_inputs = tokenizer(
    too_long,
    return_tensors = "pt",
    return_attention_mask = True,
    padding = "max_length",
    truncation = True
)

What if you have long texts, like novels? You’ll need to make some decisions. You could, for example, look for a model with a bigger context window; several of the newest LLMs can process novel-length documents now. Or, you might strategically chunk your text. Perhaps you’re only interested in dialogue, or maybe paragraph-length descriptions. You could preprocess your texts to create chunks of this kind, ensure they do not exceed your context window size, and then send them to the model.
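
If you go the chunking route, fast tokenizers offer some built-in help: setting return_overflowing_tokens splits an over-long input into context-window-sized pieces, and stride keeps a small overlap between consecutive chunks. A minimal sketch:

chunked = tokenizer(
    too_long,
    return_tensors = "pt",
    max_length = 512,
    padding = "max_length",
    truncation = True,
    stride = 50,
    return_overflowing_tokens = True
)

# One row of input IDs per chunk
chunked["input_ids"].shape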

Regardless of the strategy you use, it will likely take a few iterations to settle on a final tokenization workflow.

16.3. Running the Model#

Now that we’ve tokenized our text, we can run our model. First, we move the model to a device, such as a GPU (represented by the index 0 below). The transformers library is pretty good at doing this for us, but we can always set the device explicitly.

device = 0 if torch.cuda.is_available() else "cpu"
bert.to(device)

print("Moved model to", device)
Moved model to cpu

Tip

You can also set the model device when initializing it, with the device_map parameter (this requires the accelerate package).

model = AutoModel.from_pretrained(checkpoint, device_map = device)

Time to process the inputs. First, put the model in evaluation mode. This disables dropout, which would otherwise make outputs non-deterministic.

bert.eval()
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

Then, wrap the process in a context manager. This context manager will keep the model from computing gradients when it processes the input. Unless you are training a model or trying to understand model internals, there’s no need for gradients. With the context manager built, send the inputs to the model.

with torch.no_grad():
    outputs = bert(**inputs, output_hidden_states = True)

16.3.1. Model outputs#

There are several components in this output:

outputs
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.1813, -0.1627, -0.2402,  ..., -0.1174,  0.2389,  0.5933],
         ...,
         [ 0.4306,  0.1996, -0.0055,  ...,  0.1924, -0.5685, -0.3189]]]), pooler_output=tensor([[-9.0215e-01, -3.6355e-01, -8.6364e-01,  ...,  8.9680e-01]]), hidden_states=(tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         ...]]), ...), past_key_values=None, attentions=None, cross_attentions=None)

The last_hidden_state tensor contains the hidden states for each token after the final layer of the model. Every vector is a contextualized representation of a token. The shape of this tensor is (batch size, sequence length, hidden state size).

outputs.last_hidden_state
tensor([[[-0.1813, -0.1627, -0.2402,  ..., -0.1174,  0.2389,  0.5933],
         [-0.1219,  0.2374, -0.8745,  ...,  0.3379,  0.4232, -0.2547],
         [ 0.3440,  0.2197, -0.0133,  ..., -0.1566,  0.2564,  0.2016],
         ...,
         [ 0.5548, -0.4396,  0.7075,  ...,  0.1718, -0.1337,  0.4442],
         [ 0.5042,  0.1461, -0.2642,  ...,  0.0728, -0.4193, -0.3139],
         [ 0.4306,  0.1996, -0.0055,  ...,  0.1924, -0.5685, -0.3189]]])
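
For our single-sentence input, that works out to one batch, 17 tokens, and BERT’s 768-dimensional hidden states:

outputs.last_hidden_state.shape
torch.Size([1, 17, 768])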

The pooler_output tensor is usually the one you want to use if you are embedding text for some other purpose. It is derived from the hidden state of the [CLS] token; remember that models use this token as a summary representation of the entire sequence. The shape of this tensor is (batch size, hidden state size).

outputs.pooler_output
tensor([[-9.0215e-01, -3.6355e-01, -8.6364e-01,  6.3150e-01,  5.2959e-01,
         -2.3988e-01,  7.8954e-01,  1.9048e-01, -7.7929e-01, -9.9996e-01,
         -1.1218e-01,  9.3984e-01,  9.7191e-01,  2.6334e-01,  8.9467e-01,
         -6.9052e-01, -3.4147e-01, -5.5349e-01,  2.6398e-01, -4.4079e-01,
          6.8211e-01,  9.9988e-01,  6.3717e-02,  1.4847e-01,  4.0085e-01,
          9.8006e-01, -5.9317e-01,  8.9493e-01,  9.5012e-01,  6.7092e-01,
         -5.0414e-01,  1.7098e-01, -9.8875e-01,  9.4528e-03, -7.9112e-01,
         -9.8928e-01,  1.5902e-01, -6.4964e-01,  1.5020e-01,  1.5166e-01,
         -9.0359e-01,  2.1367e-01,  9.9995e-01, -4.1530e-01,  4.3635e-01,
         -1.1558e-01, -1.0000e+00,  2.2613e-01, -8.9463e-01,  8.5809e-01,
          8.1223e-01,  9.0436e-01,  8.7882e-02,  4.7012e-01,  4.0199e-01,
         -2.3712e-01, -1.9334e-01,  4.6894e-04, -2.2115e-01, -5.5566e-01,
         -5.7362e-01,  3.4552e-01, -7.4646e-01, -8.7389e-01,  8.9060e-01,
          6.6513e-01, -2.7116e-02, -1.4315e-01, -4.6456e-02, -1.1431e-01,
          8.5752e-01,  4.0294e-02,  1.1696e-02, -8.3643e-01,  4.3631e-01,
          1.6847e-01, -5.5569e-01,  1.0000e+00, -4.2450e-01, -9.7485e-01,
          7.4691e-01,  6.4409e-01,  4.5656e-01,  7.3656e-02,  1.3891e-01,
         -1.0000e+00,  6.1069e-01,  8.5232e-02, -9.8454e-01,  5.9957e-02,
          5.2395e-01, -2.2201e-01,  1.2185e-01,  5.0690e-01, -3.0269e-01,
         -4.4167e-01, -1.1198e-01, -7.8108e-01, -9.9268e-02, -3.5597e-01,
         -1.5475e-02,  6.0668e-02, -2.7171e-01, -1.8928e-01,  2.8009e-01,
         -4.4744e-01, -6.4973e-01,  2.8635e-01,  3.2424e-03,  5.7549e-01,
          3.8756e-01, -2.1564e-01,  2.9823e-01, -9.4773e-01,  5.1873e-01,
         -2.5598e-01, -9.8738e-01, -5.5347e-01, -9.8613e-01,  6.5802e-01,
         -1.9799e-01, -2.3325e-01,  9.4350e-01,  7.4329e-02,  2.8480e-01,
          1.1823e-01, -9.1459e-01, -1.0000e+00, -7.1362e-01, -3.3832e-01,
         -1.5427e-02, -2.2573e-01, -9.7104e-01, -9.5031e-01,  5.0464e-01,
          9.3296e-01,  1.1413e-01,  9.9982e-01, -1.4094e-01,  9.2251e-01,
         -1.1937e-02, -5.9505e-01,  6.0277e-01, -3.5788e-01,  6.4097e-01,
         -5.7842e-02, -2.5286e-01,  2.0187e-01, -2.5656e-01,  4.5353e-01,
         -6.9229e-01, -3.8542e-03, -6.3896e-01, -8.9819e-01, -2.3500e-01,
          9.3484e-01, -4.6395e-01, -8.6960e-01, -3.4996e-02, -7.0628e-02,
         -3.8399e-01,  7.3087e-01,  7.5051e-01,  2.9726e-01, -3.3074e-01,
          3.7799e-01, -1.4909e-01,  4.0023e-01, -7.6468e-01, -4.4100e-02,
          4.0490e-01, -2.2321e-01, -6.1502e-01, -9.8328e-01, -2.5075e-01,
          5.6157e-01,  9.8075e-01,  6.5358e-01,  5.6203e-02,  7.9783e-01,
         -1.8338e-01,  7.3604e-01, -9.2419e-01,  9.7970e-01,  2.9938e-02,
          5.1264e-02, -8.8996e-04,  5.2048e-01, -8.5517e-01, -2.0053e-01,
          7.7441e-01, -6.9283e-01, -8.3609e-01, -9.7112e-04, -3.6640e-01,
         -2.3179e-01, -6.8952e-01,  5.6256e-01, -1.9093e-01, -1.7197e-01,
          1.5008e-01,  9.0820e-01,  9.2894e-01,  7.8880e-01,  1.3647e-01,
          6.8747e-01, -8.5234e-01, -3.4889e-01,  1.4948e-02,  7.0052e-02,
          8.3297e-02,  9.9249e-01, -5.5713e-01,  1.2017e-02, -9.3871e-01,
         -9.8170e-01, -2.3915e-01, -8.9892e-01, -4.3212e-02, -5.4737e-01,
          5.8934e-01, -2.9702e-01,  3.1357e-01,  3.1863e-01, -9.4778e-01,
         -6.9678e-01,  3.0899e-01, -4.9348e-01,  3.4331e-01, -3.1521e-01,
          9.4813e-01,  8.7397e-01, -4.9162e-01,  4.0090e-01,  9.2305e-01,
         -8.6857e-01, -7.5446e-01,  6.5415e-01, -2.4450e-01,  8.2804e-01,
         -5.3416e-01,  9.8093e-01,  8.6154e-01,  8.7709e-01, -8.8134e-01,
         -5.8875e-01, -7.7638e-01, -4.9128e-01,  6.3842e-02, -3.1175e-01,
          8.1792e-01,  5.7528e-01,  3.3582e-01,  6.7806e-01, -5.1993e-01,
          9.9162e-01, -9.6824e-01, -9.4774e-01, -5.3149e-01,  2.6096e-02,
         -9.8804e-01,  8.2734e-01,  1.6417e-01,  3.2679e-01, -4.0681e-01,
         -4.9522e-01, -9.5393e-01,  7.7904e-01, -1.4957e-03,  9.6117e-01,
         -2.7639e-01, -8.6071e-01, -6.4372e-01, -9.0329e-01, -2.5810e-01,
         -1.1203e-01, -1.1593e-01, -2.3476e-01, -9.4845e-01,  3.6636e-01,
          5.2653e-01,  5.2235e-01, -6.2701e-01,  9.9611e-01,  1.0000e+00,
          9.7235e-01,  8.7527e-01,  8.2946e-01, -9.9941e-01, -6.6086e-01,
          9.9998e-01, -9.8423e-01, -1.0000e+00, -9.1387e-01, -6.1058e-01,
          9.1625e-02, -1.0000e+00, -1.6898e-01,  1.8335e-01, -9.1385e-01,
          6.8013e-01,  9.7530e-01,  9.8308e-01, -1.0000e+00,  8.5240e-01,
          9.2787e-01, -5.7181e-01,  9.1319e-01, -3.5524e-01,  9.6975e-01,
          3.2755e-01,  5.3928e-01, -2.5199e-02,  2.4171e-01, -8.6792e-01,
         -7.3762e-01, -3.5711e-01, -7.3350e-01,  9.9496e-01,  9.6279e-02,
         -7.7655e-01, -8.4988e-01,  6.1927e-01,  2.1021e-03, -3.2598e-01,
         -9.5913e-01, -1.0941e-01,  5.3905e-01,  7.7692e-01,  2.0210e-01,
          1.3288e-01, -5.2597e-01,  1.3350e-01, -3.0387e-01,  3.1106e-02,
          6.0131e-01, -9.2876e-01, -4.0126e-01,  9.3837e-02, -9.1770e-02,
         -2.1734e-01, -9.5718e-01,  9.5094e-01, -2.4999e-01,  8.7258e-01,
          1.0000e+00,  4.4012e-01, -8.2275e-01,  5.4966e-01,  1.5370e-01,
          2.0310e-01,  1.0000e+00,  7.8576e-01, -9.7345e-01, -5.6520e-01,
          5.5103e-01, -4.8463e-01, -6.1582e-01,  9.9890e-01, -1.2441e-01,
         -5.9766e-01, -4.3516e-01,  9.7372e-01, -9.8673e-01,  9.8594e-01,
         -8.5766e-01, -9.6840e-01,  9.5956e-01,  9.2108e-01, -6.8813e-01,
         -7.0525e-01,  3.6422e-02, -4.2375e-01,  1.7284e-01, -9.3253e-01,
          7.1364e-01,  3.9647e-01, -9.4511e-02,  8.9084e-01, -5.6835e-01,
         -5.2339e-01,  1.4913e-01, -6.5024e-01, -1.9193e-01,  9.0409e-01,
          4.0446e-01, -7.4188e-02, -5.9329e-02, -1.0553e-01, -8.4495e-01,
         -9.6772e-01,  6.0419e-01,  1.0000e+00,  1.2257e-02,  7.9661e-01,
         -2.5697e-01,  7.9121e-02, -2.7145e-01,  3.9955e-01,  3.5015e-01,
         -1.9779e-01, -8.1081e-01,  6.1581e-01, -9.3205e-01, -9.8435e-01,
          5.6242e-01,  2.5414e-02, -1.9855e-01,  9.9998e-01,  3.9147e-01,
          3.8831e-02,  3.6355e-01,  9.7193e-01, -1.5554e-01,  3.0005e-01,
          7.3116e-01,  9.7749e-01, -1.4626e-01,  5.5644e-01,  7.9268e-01,
         -8.0457e-01, -2.1986e-01, -5.8049e-01, -1.1498e-01, -9.2331e-01,
          2.5465e-01, -9.5982e-01,  9.4562e-01,  9.3056e-01,  2.6739e-01,
         -4.9374e-04,  5.2062e-01,  1.0000e+00, -8.0677e-01,  3.9905e-01,
          2.6592e-01,  5.3715e-01, -9.9927e-01, -7.9586e-01, -3.2750e-01,
         -5.8726e-02, -6.6198e-01, -2.9297e-01,  1.0346e-01, -9.6175e-01,
          5.6368e-01,  5.5213e-01, -9.7025e-01, -9.8716e-01, -3.4926e-01,
          7.4946e-01,  6.1641e-02, -9.7373e-01, -7.1220e-01, -3.7798e-01,
          5.8977e-01, -1.0241e-01, -9.3295e-01,  2.2246e-02, -1.3604e-01,
          5.2007e-01, -8.4998e-02,  5.1492e-01,  7.3342e-01,  8.4501e-01,
         -5.2785e-01, -2.8822e-01,  4.6259e-02, -6.9614e-01,  8.7093e-01,
         -7.8254e-01, -8.6091e-01, -4.9411e-03,  1.0000e+00, -4.8026e-01,
          8.4091e-01,  6.7065e-01,  7.7482e-01, -1.2159e-01,  1.1097e-01,
          7.9307e-01,  2.5259e-01, -4.3484e-01, -7.9768e-01, -5.9230e-03,
         -2.7596e-01,  6.4743e-01,  4.9924e-01,  4.2030e-01,  7.4892e-01,
          7.1720e-01,  2.1605e-01,  1.7675e-01, -7.7313e-02,  9.9814e-01,
         -1.3775e-01, -1.5530e-01, -3.0964e-01,  4.3301e-02, -2.4627e-01,
          2.7069e-01,  1.0000e+00,  2.0974e-01,  4.2502e-01, -9.8813e-01,
         -7.9993e-01, -7.9667e-01,  1.0000e+00,  8.3059e-01, -8.1765e-01,
          7.4333e-01,  6.1189e-01,  6.8243e-02,  7.5832e-01, -1.0380e-02,
          1.1043e-03,  1.9780e-01, -1.8198e-02,  9.3912e-01, -5.1335e-01,
         -9.6651e-01, -5.2125e-01,  3.9677e-01, -9.5898e-01,  9.9963e-01,
         -5.3292e-01, -2.3007e-01, -4.3810e-01, -7.4668e-02, -4.8650e-01,
         -1.8025e-01, -9.8233e-01, -1.9585e-01,  1.0636e-01,  9.5299e-01,
          1.4254e-01, -5.2442e-01, -8.6130e-01,  6.9175e-01,  7.5675e-01,
         -9.0013e-01, -9.0459e-01,  9.4746e-01, -9.7303e-01,  6.2423e-01,
          1.0000e+00,  3.3112e-01, -9.2328e-02,  1.7466e-01, -4.8845e-01,
          3.1759e-01, -3.8244e-01,  6.9155e-01, -9.5553e-01, -2.3247e-01,
         -1.3807e-01,  2.2340e-01,  3.9980e-02, -6.1394e-01,  5.1713e-01,
          6.2565e-02, -4.7686e-01, -5.7325e-01,  3.2512e-02,  3.9323e-01,
          8.0339e-01, -3.3119e-02, -1.3022e-01,  2.2383e-01, -2.0161e-02,
         -8.4427e-01, -4.3153e-01, -4.6155e-01, -9.9996e-01,  4.6373e-01,
         -1.0000e+00,  3.2713e-01, -1.4122e-01, -2.0265e-01,  8.0648e-01,
          7.0225e-01,  6.5395e-01, -5.9549e-01, -7.6726e-01,  6.8309e-01,
          6.5727e-01, -2.7250e-01, -5.2437e-01, -5.6740e-01,  2.4622e-01,
          1.5573e-01,  1.4188e-01, -5.6907e-01,  6.8004e-01, -2.5054e-01,
          1.0000e+00,  2.2881e-02, -7.2110e-01, -9.5469e-01,  4.3091e-02,
         -2.4881e-01,  1.0000e+00, -8.4721e-01, -9.5246e-01,  1.7552e-01,
         -6.6395e-01, -7.8544e-01,  3.4256e-01, -8.3179e-02, -7.1474e-01,
         -8.8283e-01,  9.0191e-01,  7.6542e-01, -5.9564e-01,  5.3151e-01,
         -1.7647e-01, -5.0729e-01, -1.2652e-01,  7.7420e-01,  9.8289e-01,
          1.4510e-01,  8.0867e-01, -1.6427e-01, -3.5557e-01,  9.6864e-01,
          2.1303e-01, -3.8066e-03, -9.2923e-02,  1.0000e+00,  1.9231e-01,
         -8.8242e-01,  7.1092e-02, -9.8139e-01, -2.3762e-02, -9.4682e-01,
          2.8557e-01,  5.2677e-02,  8.9981e-01, -1.9667e-01,  9.5172e-01,
         -6.6690e-01,  6.6651e-04, -6.0149e-01, -2.4802e-01,  3.8019e-01,
         -8.9955e-01, -9.7870e-01, -9.8522e-01,  5.4832e-01, -4.1989e-01,
         -2.7977e-02,  1.0936e-01, -1.4552e-01,  2.4088e-01,  2.6210e-01,
         -1.0000e+00,  9.2344e-01,  3.4815e-01,  7.4503e-01,  9.6188e-01,
          6.9634e-01,  6.4311e-01,  6.2565e-02, -9.7912e-01, -9.6776e-01,
         -1.6542e-01, -1.7331e-01,  4.8728e-01,  5.5929e-01,  8.0049e-01,
          3.6028e-01, -2.9847e-01, -5.4233e-01, -3.9048e-01, -9.2027e-01,
         -9.9083e-01,  2.3925e-01, -5.0385e-01, -9.1661e-01,  9.3613e-01,
         -6.3137e-01, -2.6656e-02, -7.6423e-03, -6.3848e-01,  8.7719e-01,
          8.0430e-01,  1.3750e-01, -4.6426e-02,  3.8011e-01,  8.8338e-01,
          8.8459e-01,  9.7661e-01, -7.2297e-01,  6.2925e-01, -7.2241e-01,
          3.1132e-01,  8.6522e-01, -9.2078e-01,  4.3722e-02,  2.3552e-01,
         -9.4707e-02,  1.9644e-01, -1.0795e-01, -9.3186e-01,  7.4395e-01,
         -2.4304e-01,  2.8119e-01, -2.1058e-01,  2.3263e-01, -3.1718e-01,
         -2.5258e-02, -7.1409e-01, -6.0906e-01,  6.1541e-01,  1.9725e-01,
          8.6647e-01,  7.9473e-01,  1.4623e-01, -7.4865e-01,  4.9832e-02,
         -6.5079e-01, -8.6864e-01,  8.7466e-01,  9.9015e-02,  3.3872e-02,
          5.8198e-01, -3.2675e-01,  7.9461e-01,  8.8223e-02, -3.0361e-01,
         -2.4622e-01, -6.1891e-01,  8.8182e-01, -7.5603e-01, -3.8631e-01,
         -3.5338e-01,  6.3820e-01,  2.1275e-01,  9.9991e-01, -6.3115e-01,
         -8.5991e-01, -6.4168e-01, -1.8362e-01,  2.6631e-01, -4.0186e-01,
         -1.0000e+00,  3.6668e-01, -5.3633e-01,  5.8175e-01, -5.9615e-01,
          7.6011e-01, -6.7364e-01, -9.5899e-01, -4.1586e-02,  6.7226e-01,
          6.5868e-01, -4.6331e-01, -7.2867e-01,  4.7537e-01, -4.6924e-01,
          9.4955e-01,  8.0151e-01,  6.2790e-02,  4.0838e-01,  6.3502e-01,
         -6.3827e-01, -6.5047e-01,  8.9680e-01]])
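
Strictly speaking, pooler_output is not the raw [CLS] hidden state: as the model printout above shows, BertPooler passes that hidden state through a dense layer and a tanh activation. A quick check reproduces it:

cls_state = outputs.last_hidden_state[:, 0]
torch.allclose(torch.tanh(bert.pooler.dense(cls_state)), outputs.pooler_output)
True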

Setting output_hidden_states = True had the model return all of the hidden states, from the first embedding layer to the very last layer. These are accessible from hidden_states. This is a tuple of tensors. Every tensor has the shape (batch size, sequence length, hidden state size).

outputs.hidden_states
(tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
          [-0.2008,  0.1479,  0.1878,  ...,  0.9505,  0.9427,  0.1835],
          [-0.3319,  0.4860, -0.1578,  ...,  0.5669,  0.7301,  0.1399],
          ...,
          [-0.1509,  0.1222,  0.4894,  ...,  0.0128, -0.1437, -0.0780],
          [-0.3884,  0.6414,  0.0598,  ...,  0.6821,  0.3488,  0.7101],
          [-0.5870,  0.2658,  0.0439,  ..., -0.1067, -0.0729, -0.0851]]]),
 tensor([[[-0.0422,  0.0229, -0.2086,  ...,  0.1785, -0.0790, -0.0525],
          [-0.5901,  0.1755, -0.0278,  ...,  1.0815,  1.6212,  0.1523],
          [ 0.0323,  0.8927, -0.2348,  ...,  0.0032,  1.3259,  0.2274],
          ...,
          [ 0.6683,  0.2020, -0.0523,  ...,  0.0027, -0.2793,  0.1329],
          [-0.1310,  0.5102, -0.1028,  ...,  0.3445,  0.0718,  0.6305],
          [-0.3432,  0.2476, -0.0468,  ..., -0.1301,  0.1246,  0.0411]]]),
 tensor([[[-0.1382, -0.2264, -0.4627,  ...,  0.3514,  0.0516, -0.0463],
          [-0.8300,  0.4672, -0.2483,  ...,  1.2602,  1.2012, -0.1328],
          [ 0.7289,  0.6790, -0.3091,  ..., -0.1309,  0.9835, -0.2290],
          ...,
          [ 0.8956,  0.3428,  0.0079,  ...,  0.2997, -0.3415,  0.7970],
          [-0.1553,  0.2835,  0.2071,  ...,  0.0758, -0.0326,  0.6186],
          [-0.3426,  0.0535,  0.0638,  ...,  0.0197,  0.1122, -0.1884]]]),
 tensor([[[-0.0770, -0.3675, -0.2666,  ...,  0.3117,  0.2467,  0.1323],
          [-0.3731, -0.0286, -0.1670,  ...,  0.6970,  1.5362, -0.3529],
          [ 0.7061,  0.4618, -0.2415,  ..., -0.0807,  0.8768, -0.2854],
          ...,
          [ 1.3325,  0.1663, -0.0099,  ...,  0.1685, -0.1381,  0.6110],
          [-0.3374,  0.1269,  0.1817,  ..., -0.0198, -0.0905,  0.3292],
          [-0.0850, -0.0934,  0.1007,  ...,  0.0459,  0.0579, -0.0371]]]),
 tensor([[[ 0.0599, -0.7039, -0.8094,  ...,  0.4053,  0.2542,  0.5017],
          [-0.7397, -0.5218, -0.1666,  ...,  0.6768,  1.5843, -0.2920],
          [ 0.8869,  0.5469, -0.3197,  ..., -0.0870,  0.5288,  0.1315],
          ...,
          [ 1.5591,  0.2863,  0.2924,  ...,  0.4971, -0.0800,  0.7023],
          [-0.3145,  0.1553, -0.0974,  ..., -0.1852, -0.3847,  0.5292],
          [-0.0261, -0.0488,  0.0042,  ...,  0.0081,  0.0475, -0.0346]]]),
 tensor([[[-0.0289, -0.7001, -0.6573,  ..., -0.0254,  0.2115,  0.5060],
          [-0.9080, -0.4675, -0.2327,  ...,  0.2051,  1.5554, -0.3402],
          [ 1.0436,  0.5098, -0.4004,  ..., -0.4537,  0.3073,  0.5464],
          ...,
          [ 1.8741,  0.1041, -0.1578,  ...,  0.5090,  0.0933,  0.9344],
          [ 0.2248,  0.2398, -0.3275,  ..., -0.2687, -0.5662,  0.7646],
          [-0.0183, -0.0432,  0.0123,  ...,  0.0138,  0.0110, -0.0385]]]),
 tensor([[[ 0.1700, -0.9118, -0.5099,  ..., -0.2153,  0.4185,  0.3388],
          [-0.5750, -0.5454, -0.3029,  ..., -0.1316,  1.3756, -0.3223],
          [ 0.8847,  0.6076, -0.5053,  ..., -0.5245,  0.0685,  0.3392],
          ...,
          [ 1.8617, -0.1778,  0.0593,  ..., -0.1164,  0.1354,  1.5028],
          [ 0.3238,  0.6568, -0.6567,  ..., -0.6430, -0.4393,  0.4841],
          [ 0.0172, -0.0527, -0.0179,  ..., -0.0102, -0.0174, -0.0409]]]),
 tensor([[[ 0.3411, -0.8139, -0.7188,  ..., -0.6404,  0.2390,  0.1338],
          [-0.6435, -0.1589, -0.1621,  ..., -0.0504,  0.9217, -0.4096],
          [ 0.7229,  0.5266, -0.7379,  ..., -0.5187,  0.0021,  0.3104],
          ...,
          [ 1.7987,  0.0404,  0.1860,  ..., -0.3626,  0.4451,  1.3464],
          [ 0.1577, -0.0492, -1.1795,  ..., -0.8191, -0.4314,  0.3754],
          [ 0.0079, -0.0187, -0.0308,  ..., -0.0261,  0.0054, -0.0522]]]),
 tensor([[[ 0.2597, -0.5194, -0.8438,  ..., -0.6873, -0.1183,  0.4508],
          [-0.5360,  0.0884, -0.3540,  ..., -0.2608,  0.5271, -0.4311],
          [ 0.3990,  0.4642, -0.6246,  ..., -0.5714,  0.1685,  0.5618],
          ...,
          [ 1.3260, -0.1660,  0.4866,  ...,  0.1439,  0.5888,  0.9798],
          [-0.2248, -0.3549, -1.2145,  ..., -0.7236, -0.3995,  0.3148],
          [ 0.0038, -0.0030,  0.0181,  ..., -0.0527, -0.0362, -0.0885]]]),
 tensor([[[ 0.2711, -0.3491, -0.6618,  ..., -0.1569,  0.0043,  0.3841],
          [-0.4096,  0.3449, -0.8822,  ...,  0.2367,  0.2244, -0.4131],
          [ 0.4250,  0.4963, -0.3541,  ..., -0.4456,  0.2106,  0.3286],
          ...,
          [ 1.1249, -0.2633,  0.2771,  ...,  0.2688,  0.2323,  0.7970],
          [ 0.1102,  0.2645, -0.9370,  ..., -0.3904, -0.3523,  0.1010],
          [-0.0321, -0.0416,  0.0300,  ..., -0.0738, -0.0530, -0.0741]]]),
 tensor([[[-0.0167, -0.2538, -0.4799,  ..., -0.0870, -0.4391,  0.3460],
          [-0.2158,  0.3668, -0.8787,  ...,  0.1046, -0.1264, -0.5901],
          [ 0.4833,  0.1214,  0.0037,  ..., -0.4762,  0.0543,  0.2185],
          ...,
          [ 0.8555, -0.2857,  0.6263,  ...,  0.5248,  0.1679,  0.6346],
          [ 0.0267,  0.0116, -0.0948,  ..., -0.0126, -0.0193,  0.0141],
          [-0.0377, -0.0243,  0.1689,  ...,  0.2037, -0.1910, -0.1169]]]),
 tensor([[[ 0.0439, -0.2886, -0.5210,  ..., -0.0585,  0.0057,  0.3484],
          [ 0.2003,  0.1950, -0.8941,  ...,  0.2855,  0.3792, -0.4433],
          [ 0.6422,  0.2077, -0.0531,  ..., -0.2940,  0.1614,  0.3406],
          ...,
          [ 0.8555, -0.3486,  0.6021,  ...,  0.2175,  0.1230,  0.5547],
          [ 0.0507,  0.0111, -0.0194,  ...,  0.0255, -0.0229,  0.0141],
          [ 0.0348, -0.0095, -0.0098,  ...,  0.0583, -0.0379, -0.0241]]]),
 tensor([[[-0.1813, -0.1627, -0.2402,  ..., -0.1174,  0.2389,  0.5933],
          [-0.1219,  0.2374, -0.8745,  ...,  0.3379,  0.4232, -0.2547],
          [ 0.3440,  0.2197, -0.0133,  ..., -0.1566,  0.2564,  0.2016],
          ...,
          [ 0.5548, -0.4396,  0.7075,  ...,  0.1718, -0.1337,  0.4442],
          [ 0.5042,  0.1461, -0.2642,  ...,  0.0728, -0.4193, -0.3139],
          [ 0.4306,  0.1996, -0.0055,  ...,  0.1924, -0.5685, -0.3189]]]))
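
For bert-base, this tuple holds thirteen tensors: the output of the initial embedding layer plus the output of each of the twelve encoder layers. The final element is the same tensor as last_hidden_state:

len(outputs.hidden_states)
13

torch.equal(outputs.hidden_states[-1], outputs.last_hidden_state)
True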

Other optional outputs, which we don’t have here, include the following:

  • past_key_values: previously computed key and value matrices, which generative models can draw on to speed up computation

  • attentions: attention weights for every layer in the model

  • cross_attentions: layer-by-layer attention weights for models that work by attending to tokens across input pairs

16.3.2. Which layer? Which token?#

The next chapter will demonstrate a classification task with BERT. This involves modifying the network layers to output one of a set of labels for input. All this will happen inside the model itself, but it’s also perfectly fine to generate embeddings with a model and use those embeddings for some other task that has nothing to do with an LLM.

People often use the last hidden state embeddings for other tasks, though there’s no hard and fast rule that you must. The BERTology paper tells us that different layers in BERT do different things: earlier ones capture syntactic features, while later ones capture more semantic features. If you’re studying syntax, you might choose an earlier layer, or a set of earlier layers.

For general document embeddings, there are a number of options, sketched in code after this list:

  • Instead of using the [CLS] token, mean pooling involves computing the mean of all tokens in the last hidden layer. This can potentially smooth out noise

  • Max pooling takes the max of all tokens’ last hidden layer embeddings. This boosts salient features in a sequence

  • Other people compute the mean of the last four layers and select [CLS] from that (though you could use all tokens, too); others take the sum of the last four layers. Both strategies combine information from a greater portion of the model

  • A concatenation of the last four layers (like appending layers in a list) is yet another option. This can potentially combine different levels of abstraction
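
Below is a minimal sketch of these pooling strategies, assuming the inputs and outputs objects from earlier in the chapter. Note that mean and max pooling should respect the attention mask so that [PAD] tokens don’t dilute the result (our single sentence has no padding, but batched input usually will):

last = outputs.last_hidden_state                   # (batch, seq, hidden)
mask = inputs["attention_mask"].unsqueeze(-1)      # (batch, seq, 1)

# Mean pooling: average over real (non-pad) tokens only
mean_pooled = (last * mask).sum(dim = 1) / mask.sum(dim = 1)

# Max pooling: mask out padding with -inf before taking the max
max_pooled = last.masked_fill(mask == 0, float("-inf")).max(dim = 1).values

# Mean of the last four layers, then select [CLS] (position 0)
stacked = torch.stack(outputs.hidden_states[-4:])  # (4, batch, seq, hidden)
mean_last_four_cls = stacked.mean(dim = 0)[:, 0]

# Concatenation of the last four layers: (batch, seq, 4 * hidden)
concat_last_four = torch.cat(outputs.hidden_states[-4:], dim = -1)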

Tip

The SentenceTransformers package will handle many of these pooling operations for you.

Finally, while using [CLS] is customary, it isn’t required for every purpose, and you can select another token if you feel it would work better. You can even train a classification model to learn from a different token, but be warned: one reason [CLS] is customary is that it appears in every input sequence. The same cannot always be said of other tokens.