RuCLIP

Zero-shot image classification model for the Russian language


RuCLIP (Russian Contrastive Language–Image Pretraining) is a multimodal model for computing similarities between images and texts and for matching captions to pictures (and vice versa). RuCLIP builds on a large body of work on zero-shot transfer, computer vision, natural language processing and multimodal learning. This repo contains prototype models of a Russian-language version of OpenAI's CLIP, following the original paper.

Models

The available checkpoints (see the Performance and speed tables below) are: ruclip-vit-base-patch32-224, ruclip-vit-base-patch16-224, ruclip-vit-large-patch14-224, ruclip-vit-base-patch32-384, ruclip-vit-large-patch14-336 and ruclip-vit-base-patch16-384.

Installing

pip install ruclip==0.0.2

Usage

- Standard RuCLIP API (Open In Colab)
- RuCLIP + SberVqgan (Open In Colab)
- ONNX example (Open In Colab)

Init models

import ruclip

device = 'cuda'
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=device)
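
Any of the checkpoint names listed above can be loaded the same way, e.g.:

# e.g. the large 336px model from the tables below
clip, processor = ruclip.load('ruclip-vit-large-patch14-336', device=device)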

Zero-Shot Classification [Minimal Example]

import torch
import base64
import requests
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO

# prepare images
bs4_urls = requests.get('https://raw.githubusercontent.com/ai-forever/ru-dolph/master/pics/pipelines/cats_vs_dogs_bs4.json').json()
images = [Image.open(BytesIO(base64.b64decode(bs4_url))) for bs4_url in bs4_urls]

# prepare classes and prompt templates
classes = ['кошка', 'собака']  # 'cat', 'dog'
templates = ['{}', 'это {}', 'на картинке {}', 'это {}, домашнее животное']
# templates: '{}', 'this is {}', 'the picture shows {}', 'this is {}, a pet'

# predict
predictor = ruclip.Predictor(clip, processor, device, bs=8, templates=templates)
with torch.no_grad():
    text_latents = predictor.get_text_latents(classes)
    pred_labels = predictor.run(images, text_latents)

# show results
f, ax = plt.subplots(2,4, figsize=(12,6))
for i, (pil_img, pred_label) in enumerate(zip(images, pred_labels)):
    ax[i//4, i%4].imshow(pil_img)
    ax[i//4, i%4].set_title(classes[pred_label])

Cosine similarity Visualization Example
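
The heatmap itself is not reproduced here; below is a minimal sketch of how the underlying cosine-similarity matrix can be computed from the minimal example above. It assumes predictor.get_image_latents accepts the same list of PIL images, as in the linear-probe example below.

with torch.no_grad():
    image_latents = predictor.get_image_latents(images)

# re-normalize to unit length (a no-op if the predictor already returns unit vectors)
il = image_latents / image_latents.norm(dim=-1, keepdim=True)
tl = text_latents / text_latents.norm(dim=-1, keepdim=True)
cosine_sim = (il @ tl.t()).cpu().numpy()  # shape: (n_images, n_classes)

plt.imshow(cosine_sim, vmin=0, vmax=1)
plt.xticks(range(len(classes)), classes)
plt.ylabel('image index')
plt.colorbar()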

Softmax Scores Visualization Example
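
Likewise, per-image softmax scores over the classes can be obtained by scaling the cosine similarities with the model's learned temperature, reusing il and tl from the sketch above. Note that clip.logit_scale is an assumption about the attribute name, mirroring OpenAI CLIP.

with torch.no_grad():
    logit_scale = clip.logit_scale.exp()  # assumed attribute, as in OpenAI CLIP
    probs = (logit_scale * il @ tl.t()).softmax(dim=-1).cpu().numpy()

# print per-class scores for each image
for i, row in enumerate(probs):
    print(i, {c: round(float(p), 3) for c, p in zip(classes, row)})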

Linear Probe and ZeroShot Correlation Results
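
The correlation plot is not reproduced here, but as a worked example, the correlation between the two evaluation modes can be recomputed from the two tables in the Performance section, e.g. for ruCLIP Large [vit-large-patch14-336] on the first five datasets:

import numpy as np

# scores copied from the zero-shot and linear-probe tables below
# (Food101, CIFAR10, CIFAR100, Birdsnap, SUN397)
zero_shot = np.array([0.712, 0.906, 0.591, 0.213, 0.523])
linear_probe = np.array([0.896, 0.943, 0.770, 0.609, 0.759])

print(f"Pearson r = {np.corrcoef(zero_shot, linear_probe)[0, 1]:.3f}")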

Linear Probe Example

import numpy as np
from torchvision.datasets import CIFAR100
from sklearn.linear_model import LogisticRegression

# prepare CIFAR100 train/test splits (returned as PIL images)
root = '.'  # download location
train = CIFAR100(root, download=True, train=True)
test = CIFAR100(root, download=True, train=False)

# extract frozen image latents with the predictor initialized above
with torch.no_grad():
    X_train = predictor.get_image_latents((pil_img for pil_img, _ in train)).cpu().numpy()
    X_test = predictor.get_image_latents((pil_img for pil_img, _ in test)).cpu().numpy()
    y_train, y_test = np.array(train.targets), np.array(test.targets)

# fit a logistic-regression probe on the frozen features
clf = LogisticRegression(solver='lbfgs', penalty='l2', max_iter=1000, verbose=1)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracy = np.mean((y_test == y_pred).astype(float)) * 100.  # np.float is removed in modern NumPy
print(f"Accuracy = {accuracy:.3f}")

>>> Accuracy = 75.680

Performance

We have evaluated zero-shot image classification performance on the following datasets. The "+ OPUS-MT" column applies the original English CLIP to class prompts machine-translated from Russian with OPUS-MT; 💥 marks the best score in a row excluding the last column (original English CLIP):

| Dataset | ruCLIP Base [vit-base-patch32-224] | ruCLIP Base [vit-base-patch16-224] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-336] | ruCLIP Base [vit-base-patch16-384] | CLIP [vit-base-patch16-224] original + OPUS-MT | CLIP [vit-base-patch16-224] original |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Food101, acc | 0.505 | 0.552 | 0.597 | 0.642 | 0.712 💥 | 0.689 | 0.664 | 0.883 |
| CIFAR10, acc | 0.818 | 0.810 | 0.878 | 0.862 | 0.906 💥 | 0.845 | 0.859 | 0.893 |
| CIFAR100, acc | 0.504 | 0.496 | 0.511 | 0.529 | 0.591 | 0.569 | 0.603 💥 | 0.647 |
| Birdsnap, acc | 0.115 | 0.117 | 0.172 | 0.161 | 0.213 💥 | 0.195 | 0.126 | 0.396 |
| SUN397, acc | 0.452 | 0.462 | 0.484 | 0.510 | 0.523 💥 | 0.521 | 0.447 | 0.631 |
| Stanford Cars, acc | 0.433 | 0.487 | 0.559 | 0.572 | 0.659 💥 | 0.626 | 0.567 | 0.638 |
| DTD, acc | 0.380 | 0.401 | 0.370 | 0.390 | 0.408 | 0.421 💥 | 0.243 | 0.432 |
| MNIST, acc | 0.447 | 0.464 | 0.337 | 0.404 | 0.242 | 0.478 | 0.559 💥 | 0.559 |
| STL10, acc | 0.932 | 0.932 | 0.934 | 0.946 | 0.956 | 0.964 | 0.967 💥 | 0.970 |
| PCam, acc | 0.501 | 0.505 | 0.520 | 0.506 | 0.554 | 0.501 | 0.603 💥 | 0.573 |
| CLEVR, acc | 0.148 | 0.128 | 0.152 | 0.188 | 0.142 | 0.132 | 0.240 💥 | 0.240 |
| Rendered SST2, acc | 0.489 | 0.527 | 0.529 | 0.508 | 0.539 💥 | 0.525 | 0.484 | 0.484 |
| ImageNet, acc | 0.375 | 0.401 | 0.426 | 0.451 | 0.488 💥 | 0.482 | 0.392 | 0.638 |
| FGVC Aircraft, mean-per-class | 0.033 | 0.043 | 0.046 | 0.053 | 0.075 | 0.046 | 0.220 💥 | 0.244 |
| Oxford Pets, mean-per-class | 0.560 | 0.595 | 0.604 | 0.587 | 0.546 | 0.635 💥 | 0.507 | 0.874 |
| Caltech101, mean-per-class | 0.786 | 0.775 | 0.777 | 0.834 | 0.835 💥 | 0.835 💥 | 0.792 | 0.883 |
| Flowers102, mean-per-class | 0.401 | 0.388 | 0.455 | 0.449 | 0.517 💥 | 0.452 | 0.357 | 0.697 |
| Hateful Memes, roc-auc | 0.564 | 0.516 | 0.530 | 0.537 | 0.519 | 0.543 | 0.579 💥 | 0.589 |

And for linear-probe evaluation:

| Dataset | ruCLIP Base [vit-base-patch32-224] | ruCLIP Base [vit-base-patch16-224] | ruCLIP Large [vit-large-patch14-224] | ruCLIP Base [vit-base-patch32-384] | ruCLIP Large [vit-large-patch14-336] | ruCLIP Base [vit-base-patch16-384] | CLIP [vit-base-patch16-224] original |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Food101 | 0.765 | 0.827 | 0.840 | 0.851 | 0.896 💥 | 0.890 | 0.901 |
| CIFAR10 | 0.917 | 0.922 | 0.927 | 0.934 | 0.943 💥 | 0.942 | 0.953 |
| CIFAR100 | 0.716 | 0.739 | 0.734 | 0.745 | 0.770 | 0.773 💥 | 0.808 |
| Birdsnap | 0.347 | 0.503 | 0.567 | 0.434 | 0.609 | 0.612 💥 | 0.664 |
| SUN397 | 0.683 | 0.721 | 0.731 | 0.721 | 0.759 💥 | 0.758 | 0.777 |
| Stanford Cars | 0.697 | 0.776 | 0.797 | 0.766 | 0.831 | 0.840 💥 | 0.866 |
| DTD | 0.690 | 0.734 | 0.711 | 0.703 | 0.731 | 0.749 💥 | 0.770 |
| MNIST | 0.963 | 0.974 💥 | 0.949 | 0.965 | 0.949 | 0.971 | 0.989 |
| STL10 | 0.957 | 0.962 | 0.973 | 0.968 | 0.981 💥 | 0.974 | 0.982 |
| PCam | 0.827 | 0.823 | 0.791 | 0.835 | 0.807 | 0.846 💥 | 0.830 |
| CLEVR | 0.356 | 0.360 | 0.358 | 0.308 | 0.318 | 0.378 💥 | 0.604 |
| Rendered SST2 | 0.603 | 0.655 | 0.651 | 0.651 | 0.637 | 0.661 💥 | 0.606 |
| FGVC Aircraft | 0.254 | 0.312 | 0.290 | 0.283 | 0.341 | 0.362 💥 | 0.604 |
| Oxford Pets | 0.774 | 0.820 | 0.819 | 0.730 | 0.753 | 0.856 💥 | 0.931 |
| Caltech101 | 0.904 | 0.917 | 0.914 | 0.922 | 0.937 💥 | 0.932 | 0.956 |
| Hateful Memes | 0.545 | 0.568 | 0.563 | 0.581 | 0.585 💥 | 0.578 | 0.645 |

We have also compared inference speed, measured on the CIFAR100 dataset with an Nvidia V100:

| Model | iter/sec |
| --- | --- |
| ruclip-vit-base-patch32-224 | 308.84 💥 |
| ruclip-vit-base-patch16-224 | 155.35 |
| ruclip-vit-large-patch14-224 | 49.95 |
| ruclip-vit-base-patch32-384 | 147.26 |
| ruclip-vit-large-patch14-336 | 22.11 |
| ruclip-vit-base-patch16-384 | 61.79 |
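
A rough sketch of how such a throughput figure can be reproduced; the batch size and timing loop are assumptions rather than the exact benchmark used here:

import time
from torchvision.datasets import CIFAR100

pil_images = [img for img, _ in CIFAR100('.', download=True, train=False)]

bs = 8  # batch size passed to ruclip.Predictor above
start = time.time()
with torch.no_grad():
    predictor.get_image_latents(pil_images)
elapsed = time.time() - start
print(f"{len(pil_images) / bs / elapsed:.2f} iter/sec")  # batches per second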

Authors

Supported by

<img src="https://raw.githubusercontent.com/ai-forever/ru-dolph/master/pics/logo/airi-logo.png" height="50"/>

Social Media