!pip install -q datasets

Tutorial:
Fine-tune BLIP using Hugging Face transformers and datasets 🤗
https://colab.research.google.com/drive/1lbqiSiA0sDF7JDWPeS0tccrM85LloVha?usp=sharing

Discussion:
https://github.com/salesforce/BLIP/issues/37

How the dataset was made:
https://colab.research.google.com/drive/1HLxgrG7xZJ9FvXckNG61J72FkyrbqKAA?usp=sharing

import json

root = "image_in_rainly_river/"  # directory containing the images and their captions

from datasets import load_dataset
dataset = load_dataset("imagefolder", data_dir=root, split="train")
Resolving data files: 100%|████████████████████████████████████████████████████████████████| 171/171 [00:00<00:00, 201524.58it/s]
dataset
Dataset({
    features: ['image', 'caption'],
    num_rows: 170
})
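The imagefolder builder reads the captions from a metadata.jsonl file stored alongside the images: each line maps a file_name to the extra columns (here, caption). That also explains the counts above: 171 files were resolved but the dataset has 170 rows, since the metadata file itself is not an image. A minimal sketch of the file's contents (these file names are made up for illustration):

{"file_name": "0001.png", "caption": "a narrow river in the city with buildings on both sides"}
{"file_name": "0002.png", "caption": "a river running between rows of houses on a rainy day"}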
example = dataset[0]
image = example["image"]
width, height = image.size
display(image.resize((int(0.3*width), int(0.3*height))))

(output: the example image, displayed at 30% of its original size)

example["caption"]
'a narrow river in the city with buildings on both sides'
from torch.utils.data import Dataset
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["caption"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}

        return encoding
from transformers import Blip2Processor, Blip2ForConditionalGeneration
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
Loading checkpoint shards: 100%|██████████████████████████████████████████████████| 2/2 [00:14<00:00,  7.09s/it]
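The float32 checkpoint is roughly 15 GB, since on top of the 2.7B-parameter OPT decoder it also carries the vision encoder and Q-Former. If GPU memory is tight, from_pretrained accepts a torch_dtype argument; half precision is a common choice for inference, though for training you would pair it with mixed-precision tooling rather than running AdamW directly on fp16 weights. A minimal sketch:

import torch

# Assumption: half precision is used here only to cut memory for inference
# experiments; the training below keeps the full-precision model loaded above.
model_fp16 = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16
)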
train_dataset = ImageCaptioningDataset(dataset, processor)
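Before wiring up the DataLoader, it is worth sanity-checking one processed item; the exact sequence length depends on the processor configuration, but the keys should be pixel_values plus the tokenized caption:

item = train_dataset[0]
for k, v in item.items():
    print(k, v.shape)  # expect pixel_values, input_ids, attention_mask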
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)
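Because padding="max_length" pads every caption to the same fixed length, the default collate function can stack samples directly, so raising the batch size is a drop-in change when memory allows:

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)  # hypothetical larger batch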
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.train()

for epoch in range(50):
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)

        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
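Once training finishes, a quick generation pass on a training image shows whether the model has picked up the captioning style. A minimal sketch using greedy decoding (max_new_tokens is an arbitrary cap):

model.eval()
example = dataset[0]
inputs = processor(images=example["image"], return_tensors="pt").to(device)
generated_ids = model.generate(pixel_values=inputs["pixel_values"], max_new_tokens=30)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])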