import torch
import torch.nn.functional as F
# Install needed packages iff running in Google Colab
import sys
if "google.colab" in sys.modules:
!pip install torchinfo
from torchinfo import summary
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

18 Working with Language Models
Using code to study, interact with, and customize LLMs
Open the live notebook in Google Colab.
This set of lecture notes involves random samples from a public language model, which means that there is some risk of the output containing offensive or inappropriate content.
In this set of notes, we’ll demonstrate several different ways to interact programmatically with large language models. We often hear about two extremes for models:
- Black-box (user experience): We can only interact with the model through a fixed interface, such as a web app or API. We have no access to the internal workings of the model, and we cannot modify it in any way.
- White-box (developer experience): We have full access to the model’s architecture, parameters, and training data. We can modify the model’s parameters, architecture, and training data as we see fit.
In this set of notes we’ll do something a bit in between; our goal is to flex our muscles as sophisticated, computationally literate, and creatively mischievous users of LLMs. We’ll see how to programmatically achieve three primary tasks:
- Next-token prediction.
- Text generation.
- Fine-tuning.
Like last time, we’ll use the Hugging Face Transformers library to load an open-source version of GPT2.
from transformers import AutoTokenizer, AutoModelForCausalLM
checkpoint = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Use slower eager attention to enable attention outputs
model = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="eager")
model.eval(); # Put model in evaluation mode rather than training mode
model.to(device)
summary(model)
===========================================================================
Layer (type:depth-idx) Param #
===========================================================================
GPT2LMHeadModel --
├─GPT2Model: 1-1 --
│ └─Embedding: 2-1 38,597,376
│ └─Embedding: 2-2 786,432
│ └─Dropout: 2-3 --
│ └─ModuleList: 2-4 --
│ │ └─GPT2Block: 3-1 7,087,872
│ │ └─GPT2Block: 3-2 7,087,872
│ │ └─GPT2Block: 3-3 7,087,872
│ │ └─GPT2Block: 3-4 7,087,872
│ │ └─GPT2Block: 3-5 7,087,872
│ │ └─GPT2Block: 3-6 7,087,872
│ │ └─GPT2Block: 3-7 7,087,872
│ │ └─GPT2Block: 3-8 7,087,872
│ │ └─GPT2Block: 3-9 7,087,872
│ │ └─GPT2Block: 3-10 7,087,872
│ │ └─GPT2Block: 3-11 7,087,872
│ │ └─GPT2Block: 3-12 7,087,872
│ └─LayerNorm: 2-5 1,536
├─Linear: 1-2 38,597,376
===========================================================================
Total params: 163,037,184
Trainable params: 163,037,184
Non-trainable params: 0
===========================================================================
Next-Token Prediction
As a warmup, let’s see how to use this model to predict the next token in a sequence.
prompt = "This machine learning class is so"We follow our usual workflow:
1. Tokenize the prompt and convert it to a PyTorch tensor.
2. Pass the tokenized input through the model to get the output logits.
3. Extract the logits for the next token (the last position in the sequence).
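A minimal sketch of these three steps (the same logic appears inside the next_token function defined later):
# Step 1: tokenize the prompt and move the resulting tensor to the model's device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
# Step 2: run the model without tracking gradients
with torch.no_grad():
    output = model(input_ids)
# Step 3: keep only the logits at the last position -- these score each candidate next token
next_token_logits = output.logits[:, -1, :]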
Let’s take a look at some of the top-scoring next tokens and their corresponding scores:
top_k_scores, top_k_indices = torch.topk(next_token_logits, k=10, dim=-1)
top_k_tokens = [tokenizer.decode([idx]) for idx in top_k_indices[0]]
print("Top-k next tokens and their scores:")
for token, score in zip(top_k_tokens, top_k_scores[0]):
print(f"{token:<10}: {score.item():.4f}")Top-k next tokens and their scores:
powerful : -112.4371
simple : -112.8910
good : -113.1194
much : -113.2813
easy : -113.3669
fast : -113.4712
popular : -113.7754
well : -113.9250
great : -113.9620
complex : -114.0690
To sample from the set of possible next tokens, we can again use the Boltzmann distribution.
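For a vector of token scores $s$ and a temperature $T > 0$, the Boltzmann distribution assigns token $i$ the probability
$$p_i = \frac{e^{s_i / T}}{\sum_j e^{s_j / T}},$$
which is simply a softmax applied to the rescaled scores $s / T$; smaller values of $T$ concentrate the probability mass on the highest-scoring tokens.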
def boltzmann_sample(preds, temperature = 1.0):
probabilities = torch.nn.Softmax(dim = 0)(preds / temperature)
    return torch.multinomial(probabilities, num_samples=1).item()

The following function wraps the complete logic from prompt to next token in a single function:
def next_token(prompt, temperature=1.0):
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
output = model(input_ids)
next_token_logits = output.logits[:, -1, :]
next_token_id = boltzmann_sample(next_token_logits.squeeze(), temperature)
next_token = tokenizer.decode([next_token_id])
    return next_token

Let’s take a look at some next tokens sampled from the model:
prompt = "This machine learning class is so"
temp = 1.0
sample = lambda: next_token(prompt, temperature=temp)
print(f"Prompt: {prompt}")
print(f"Next tokens:")
for _ in range(10):
print(f" {sample():<20} {sample():<20} {sample():<20}")Prompt: This machine learning class is so
Next tokens:
terrible cool interesting
naturally much advanced
good powerful original
popular clear advanced
powerful cool Effective
effective fun heavy
simple amazingly big
exciting massive complex
complex crazy much
cool massive central
We observe that the model produces a range of next tokens, many (but not all) of which are coherent continuations of the prompt. Turning down the temperature leads to more reliable but also more boring results:
print(f"Prompt: {prompt}")
print(f"Next tokens:")
sample = lambda: next_token(prompt, temperature=0.5)
for _ in range(10):
print(f" {sample():<20} {sample():<20} {sample():<20}")Prompt: This machine learning class is so
Next tokens:
simple well powerful
powerful powerful fast
simple much easy
easy good great
powerful powerful complex
powerful advanced powerful
powerful fast important
popular simple much
fast well good
powerful powerful simple
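We can quantify this effect by looking at how the temperature reshapes the probabilities of the top-scoring tokens we found earlier. The following is a quick sketch; it reuses top_k_scores and top_k_tokens from above, and normalizes only over those ten tokens rather than the full vocabulary:
# Compare Boltzmann probabilities of the top-10 tokens at several temperatures
for T in [0.5, 1.0, 2.0]:
    probs = F.softmax(top_k_scores[0] / T, dim=-1)
    pairs = ", ".join(f"{tok.strip()} = {p.item():.2f}" for tok, p in zip(top_k_tokens, probs))
    print(f"T = {T}: {pairs}")
Lower temperatures concentrate the probability mass on the highest-scoring tokens, while higher temperatures spread it out more evenly.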
Sensitivity to Prompt
Model tokenizers can be sensitive to the exact formatting of the prompt. For example, suppose we added a single space to the prompt above:
prompt += " "
sample = lambda: next_token(prompt, temperature=1.0)
print(f"Prompt: {prompt}")
print(f"Next tokens:")
for _ in range(10):
print(f" {sample():<20} {sample():<20} {sample():<20}")Prompt: This machine learning class is so
Next tokens:
ia ia
icky
izzy
ive
icky ive
iced
icky
icky
ical ________ ippy
The set of likely next tokens has changed dramatically: the output tends to be more negative in sentiment and much less likely to form a coherent continuation of the prompt. This happens because GPT2’s tokenizer folds leading spaces into its tokens; since the prompt now already ends with a space, the model favors tokens that do not begin with a space, which tend to be word fragments rather than complete words.
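We can see the difference by inspecting the tokenizations directly. The following quick sketch prints the token strings for both versions of the prompt; in GPT2’s vocabulary, the Ġ character marks a token that begins with a space, and a trailing space typically becomes a token of its own:
# Compare tokenizations with and without the trailing space
print(tokenizer.tokenize("This machine learning class is so"))
print(tokenizer.tokenize("This machine learning class is so "))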
Text Generation
Like last time, we can use next-token prediction in the context of a recurrent pipeline to generate sequences of synthetic text:
import textwrap
def generate_text(prompt, max_length=50, temperature=1.0, wrap = True):
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
generated_tokens = input_ids.squeeze().tolist()
for _ in range(max_length):
with torch.no_grad():
output = model(input_ids)
next_token_logits = output.logits[:, -1, :]
next_token_id = boltzmann_sample(next_token_logits.squeeze(), temperature)
generated_tokens.append(next_token_id)
input_ids = torch.tensor([generated_tokens]).to(device)
generated_text = tokenizer.decode(generated_tokens)
if wrap:
generated_text = textwrap.fill(generated_text, width=80)
    return generated_text

The generation loop works as follows:
- Tokenize the initial prompt and prepare it as input to the model.
- Pass the current sequence of tokens through the model to get the output logits.
- Extract the logits for the next token (the last position in the sequence).
- Sample the next token ID from the Boltzmann distribution.
- Append the new token ID to the sequence of generated tokens.
- Decode the complete sequence of tokens back into text.
Let’s try it out:
prompt = "Fox in socks"
generated_text = generate_text(prompt, max_length=50, temperature=1.0)
print(generated_text)
Fox in socks Jake Tapper on how frustrating and embarrassing it was to watch his
boss bring up RatPac's claim that he had to kick 9/11 and finish surfing. While
it's the common perception of mainstream media outlets in generalwn studio
habits to encourage
As usual, modulating the temperature can lead to more or less coherent results:
temperatures = [0.01, 0.1, 0.3, 0.8, 1.5]
for temp in temperatures:
generated_text = generate_text(prompt, max_length=50, temperature=temp)
print()
print(f"\nTemperature: {temp}")
print("-" * 40)
print(generated_text)
Temperature: 0.01
Fox in socks. The first thing I noticed was that the socks were not very
comfortable. I was wearing a pair of socks that were very comfortable. I was
wearing a pair of socks that were very comfortable. I was wearing a pair of
socks that were
Temperature: 0.1
Fox in socks. The first thing I noticed was that the socks were very soft. I
was surprised that they were so soft. I was surprised that they were so soft. I
was surprised that they were so soft. I was surprised that they
Temperature: 0.3
Fox in socks and a pair of socks. "I'm not going to say it's a bad thing, but I
think it's a good thing," he said. "It's a good thing that the kids are going to
be able to play with
Temperature: 0.8
Fox in socks, challenging the US Supreme Court decision to overturn the
Affordable Care Act's mandate that most people have insurance. After winning
the highest court battle since the 1980s, the conservative justices have hung on
to a large majority of their rulings, but the
Temperature: 1.5
Fox in socks Pull neck brands with designed lime doors NeonGear Madden BMDE
Sweats recalled Bellevueense SoulsCense Dissmusimages Spice test Quarter Magnum
G401 Velocity 455Ws Renegade emphasized Expressan billionaire Leonard Rus
Sunrise 1000 downloads Lightning Trucks 5202 Debor
For very low temperatures, the model rapidly gets “frozen” in a loop of text, while for higher temperatures the model produces apparently random text; for intermediate values the generated text appears coherent and in some sense interesting (or at least amusing).
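As an aside, the Transformers library also bundles this sampling loop into a built-in generate method, which is typically much faster than our hand-rolled version because it caches intermediate computations between steps. Here is a minimal sketch; the sampling settings below are illustrative:
# Sample 50 new tokens from GPT2 using the library's built-in generation utilities
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        do_sample=True,                        # sample rather than taking the argmax at each step
        temperature=0.8,                       # Boltzmann temperature, as above
        max_new_tokens=50,                     # number of tokens to generate beyond the prompt
        pad_token_id=tokenizer.eos_token_id,   # GPT2 has no pad token; reuse EOS to avoid a warning
    )
print(tokenizer.decode(output_ids[0]))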
Fine-Tuning
If, however, you are a die-hard Dr. Seuss fan, you’ll be disappointed by these results: considering that the prompt is “Fox in Socks,” the generated text is not at all Seussian and usually appears to forget both about the Fox and his socks. To address this fatal shortcoming, we’ll fine-tune the model on our Dr. Seuss text. This involves essentially the same process of training a language model as we saw when we built our own model from scratch, except that now we begin with a pre-trained model. It’s often sufficient in this kind of experiment to train for only a few batches.
First we’ll retrieve our training data, which is the text of Dr. Seuss’s “Fox in Socks.”
import urllib.request
url = "https://raw.githubusercontent.com/PhilChodrow/ml-notes-update/refs/heads/main/data/fox-in-socks.txt"
text = "\n".join([line.decode('utf-8').strip() for line in urllib.request.urlopen(url)])Next, as usual, we need a data set and data loader.
from torch.utils.data import Dataset, DataLoader
class NextTokenDataset(Dataset):
def __init__(self, tokens, context_length = 10):
self.context_length = context_length
self.tokens = tokens
self.vocab_length = len(set(tokens))
def __len__(self):
return len(self.tokens) - self.context_length
def __getitem__(self, key):
target_token = self.tokens[self.context_length + key]
target = torch.tensor(target_token)
feature_tokens = self.tokens[key:(self.context_length + key)]
feature_tensor = torch.tensor(feature_tokens, dtype=torch.long)
return feature_tensor.to(device), target.to(device)
data_set = NextTokenDataset(tokenizer.encode(text), context_length=10)
data_loader = DataLoader(data_set, batch_size=32, shuffle=True)
x, y = next(iter(data_loader))
print("Feature tensor shape:", x.shape)
print("Target tensor shape:", y.shape)Feature tensor shape: torch.Size([32, 10])
Target tensor shape: torch.Size([32])
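As a quick sanity check (an extra step, not part of the original pipeline), we can decode one feature/target pair from this batch back into text to confirm that the dataset is constructed as intended:
# Decode the first training pair: ten tokens of context and the token that follows them
print("Context:", repr(tokenizer.decode(x[0].tolist())))
print("Target: ", repr(tokenizer.decode([y[0].item()])))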
Now we’re ready! The following is a standard training loop for attention-based next-token prediction models; the only difference is that we also generate text at regular intervals to see how the model is doing as it trains.
prompt = "Fox in socks"
gen_dict = {}
gen_dict[0] = generate_text(prompt, max_length=50, temperature=1.0, wrap=True)
from torch import optim
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
batch = 0
for epoch in range(10):
for features, target in data_loader:
optimizer.zero_grad()
output = model(features).logits[:, -1, :]
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch % 10 == 0:
gen_text = generate_text(prompt, max_length=50, temperature=1.0, wrap=False)
gen_dict[batch] = gen_text
print(f"Batch {batch}, Loss: {loss.item():.4f}")
batch += 1
if batch >= 100: # Limit to 100 batches for demonstration purposes
            break
Batch 0, Loss: 5.3515
Batch 10, Loss: 3.5355
Batch 20, Loss: 3.2273
Batch 30, Loss: 4.4418
Batch 40, Loss: 3.6557
Batch 50, Loss: 1.6089
Batch 60, Loss: 1.4152
Batch 70, Loss: 1.3767
Batch 80, Loss: 1.5732
Batch 90, Loss: 1.4203
Batch 100, Loss: 0.7501
Let’s take a look:
for batch, gen_text in gen_dict.items():
print("\n\n")
print("-" * 40)
print(f"Batch {batch}")
print("-" * 40)
print(gen_text)
Batch 0
Fox in socks on Tumblr.
All through the 2011 season, Wilson found himself on many Top 200 lists, and fifth this season. He was on a national cast battle with Russell Wilson, Harbaugh and possibly Robert Griffin III...a year ago. Only the 2014
Batch 10
Fox in socks and socks.
I come from my socks!
Okay xo.<|endoftext|>She says nasty. Something is coming. "Do not move" She sa
!
.
To reeeeey !
Batch 20
Fox in socks.
On a horset wheel.
Major.
Volley.
Merlin in socks.
William reports to Mary that. Ooh! I'll go with some guitar.
Mary keeps the hamlet in a boat by
Batch 30
Fox in socks and slime on trees.
Licks pushed an in-between.<|endoftext|>Shame on you, sir.
You get this hurt tweetle on a whale's face fīm wan it's done, sir.
Batch 40
Fox in socks,
Martino's wicks in sack of bricks.
Then there's Jessie and Riley, too.
My mind's fully
sounds.
NO, I'll say it in Kyzy Joe
Batch 50
Fox in socks.
Hose goes.
Socks go.
Sloe...
Bum... wore socks on cats on battle paddocks.
Big blasts!!!!
Bim parks his nose on Knox's socks and when it's done,
Batch 60
Fox in socks.
Socks on fox in box.
Big Blue.
Socks that rolled in!
Big Blue in Knox and Knox in broom.
Knox on box, sir.
This box is protected.
Batch 70
Fox in socks.
Slow Joe Crow. poodle find it a new sweet noodle in the box.
Socks on Knox and Knox and Knox
Cross in to the sneeze, and find what thing makes an oo
Batch 80
Fox in socks, get's back on BROWN's broom up.
Closer now, Mr. Knox, sir.
Closer now, sir.
Closer now, sir.
Closer now. Slow, Broom, Mr.
Batch 90
Fox in socks!
Luke Luck sneezy...
Luke Luck is a fixit tick can chew... !
Luke Luck catches ticks tick and chew. Luke likes chewy cheese.
Luke's duck likes that.
Luke's duck likes
Batch 100
Fox in socks box.
My socks are sewspeeps and ticks.
My socks made these three socks.
Knox in socks. Bim in socks.
Sue sews poodle.
Fox in socks.
At first, the model essentially ignores the prompt and produces text in the style of the pre-trained GPT2. However, as we continue to train the model, the generated text more and more resembles Fox in Socks, at least as measured by the frequency of words which appear in the text. In some cases, we also see flashes of Seussian style, including rhymes and repetition.
© Phil Chodrow, 2025