<!-- Autogenerated by `scripts/make_examples.py` -->
<table align="left">
    <td>
        <a target="_blank" href="https://colab.research.google.com/github/voxel51/fiftyone-examples/blob/master/examples/ai_telephone.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791629-6e618700-5769-11eb-857f-d176b37d2496.png" height="32" width="32">
            Try in Google Colab
        </a>
    </td>
    <td>
        <a target="_blank" href="https://nbviewer.jupyter.org/github/voxel51/fiftyone-examples/blob/master/examples/ai_telephone.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791634-6efa1d80-5769-11eb-8a4c-71d6cb53ccf0.png" height="32" width="32">
            Share via nbviewer
        </a>
    </td>
    <td>
        <a target="_blank" href="https://github.com/voxel51/fiftyone-examples/blob/master/examples/ai_telephone.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791633-6efa1d80-5769-11eb-8ee3-4b2123fe4b66.png" height="32" width="32">
            View on GitHub
        </a>
    </td>
    <td>
        <a href="https://github.com/voxel51/fiftyone-examples/raw/master/examples/ai_telephone.ipynb" download>
            <img src="https://user-images.githubusercontent.com/25985824/104792428-60f9cc00-576c-11eb-95a4-5709d803023a.png" height="32" width="32">
            Download notebook
        </a>
    </td>
</table>


# AI Telephone

This notebook walks you through how to play a game of multimodal AI Telephone!

Here’s how the game of AI Telephone works:

1. Each “game” will pair up an image-to-text (I2T) model with a text-to-image (T2I) model
2. Given an initial prompt, we use the T2I model to generate an image.
3. We then pass this image into the I2T model to generate a description.
4. We repeat steps 2 and 3 a fixed number of times `n` (in our case `n=10`).
5. Finally, we quantify how far the descriptions drift from the original prompt by measuring the distance between the prompt and the description produced at each step.

## Setup

To run this code, you will need to install the [FiftyOne open source library](https://github.com/voxel51/fiftyone) for dataset curation, the [OpenAI Python Library](https://github.com/openai/openai-python), and the [Replicate Python client](https://replicate.com/docs/get-started/python).

In [None]:
!pip install fiftyone openai replicate

We will import all of the necessary modules:

In [None]:
import hashlib
import os
import requests

import openai
import replicate

import fiftyone as fo
from fiftyone import ViewField as F

## Text-to-Image Models

First we define the base class:

In [1]:
class Text2Image(object):
    """Base wrapper for a Replicate-hosted text-to-image (T2I) model.

    Subclasses set ``name`` (a short label stored in the dataset) and
    ``model_name`` (the Replicate model identifier passed to ``replicate.run``),
    or override ``generate_image`` entirely.
    """
    def __init__(self):
        self.name = None
        self.model_name = None

    @staticmethod
    def _first_output(response):
        """Return the first element of a list/tuple model output, else the output itself.

        Raises:
            ValueError: if the model returned an empty list/tuple.
        """
        if isinstance(response, (list, tuple)):
            if not response:
                raise ValueError("Text-to-image model returned no outputs")
            return response[0]
        return response

    def generate_image(self, text):
        """Generate an image for ``text`` and return its URL.

        Some Replicate models return a list of outputs; only the first is kept.
        """
        response = replicate.run(self.model_name, input={"prompt": text})
        return self._first_output(response)

Then we create a class for each model:

In [2]:
class StableDiffusion(Text2Image):
    """Wrapper for a StableDiffusion model (run via Replicate)."""
    def __init__(self):
        # Initialise base attributes, as the Image2Text subclasses do.
        super().__init__()
        self.name = "stable-diffusion"
        self.model_name = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"


class VQGANCLIP(Text2Image):
    """Wrapper for a VQGAN-CLIP model (run via Replicate)."""
    def __init__(self):
        super().__init__()
        self.name = "vqgan-clip"
        self.model_name = "mehdidc/feed_forward_vqgan_clip:28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8"


class DALLE2(Text2Image):
    """Wrapper for OpenAI's DALL-E 2 model.

    Unlike the Replicate-hosted models, DALL-E 2 is called through the OpenAI
    API, so ``model_name`` stays ``None`` and ``generate_image`` is overridden.

    NOTE(review): OpenAI image URLs are short-lived; confirm they are
    downloaded before they expire.
    """
    def __init__(self, size="512x512"):
        super().__init__()
        self.name = "dalle-2"
        self.size = size

    def generate_image(self, text):
        """Generate one ``self.size`` image for ``text`` and return its URL.

        Works with both the ``openai>=1.0`` API (``openai.images.generate``)
        and the legacy ``openai.Image.create`` API, which v1 removed.
        """
        if hasattr(openai, "OpenAI"):
            # openai>=1.0: module-level client; honours OPENAI_API_KEY / openai.api_key.
            response = openai.images.generate(
                model="dall-e-2",
                prompt=text,
                n=1,
                size=self.size,
            )
            return response.data[0].url
        response = openai.Image.create(prompt=text, n=1, size=self.size)
        return response['data'][0]['url']

## Image-to-Text Models

Once again, we define the base class:

In [3]:
class Image2Text(object):
    """Base wrapper for a Replicate-hosted image-to-text (I2T) model.

    Subclasses set ``name`` (a short label stored in the dataset) and
    ``model_name`` (the Replicate model identifier passed to ``replicate.run``),
    and may override ``_generate_text`` when the model's input schema differs.
    """
    def __init__(self):
        self.name = None
        self.model_name = None
        self.task_description = "Write a detailed description of this image."

    def _clean_response(self, response):
        """Lower-case a model response and strip boilerplate lead-ins.

        ``response`` may be a string or an iterable of string chunks, which
        are joined. For each boilerplate phrase, in order, everything up to and
        including its FIRST occurrence is removed. Any text after later
        occurrences is kept, not truncated.
        """
        if not isinstance(response, str):
            response = "".join(response)
        response = response.lower()
        phrases = ["caption: ", "the image shows ", "the image features"]
        for phrase in phrases:
            _, found, rest = response.partition(phrase)
            if found:
                response = rest.strip()
        return response

    def _generate_text(self, image_url):
        """Run the model on ``image_url`` with ``task_description`` as the prompt."""
        response = replicate.run(
            self.model_name,
            input={
                "image": image_url,
                "prompt": self.task_description,
            },
        )
        return response

    def generate_text(self, image_url):
        """Return a cleaned, lower-cased description of the image at ``image_url``."""
        response = self._generate_text(image_url)
        return self._clean_response(response)

Then we create a class for each model:

In [4]:
class BLIP(Image2Text):
    """Salesforce BLIP captioner, served through Replicate."""
    def __init__(self):
        super(BLIP, self).__init__()
        self.name, self.model_name = (
            "blip",
            "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
        )


class CLIPPrefix(Image2Text):
    """CLIP-prefix captioning model, served through Replicate."""
    def __init__(self):
        super(CLIPPrefix, self).__init__()
        self.name, self.model_name = (
            "clip-prefix",
            "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
        )


class MiniGPT4(Image2Text):
    """MiniGPT-4 vision-language model, served through Replicate."""
    def __init__(self):
        super(MiniGPT4, self).__init__()
        self.name, self.model_name = (
            "minigpt-4",
            "daanelson/minigpt-4:b96a2f33cc8e4b0aa23eacfce731b9c41a7d9466d9ed4e167375587b54db9423",
        )

class MPLUGOwl(Image2Text):
    """Wrapper for a mPLUG-Owl model.

    This model's Replicate input names the image ``img`` (not ``image``), and
    its output is an iterable of text chunks, so ``_generate_text`` is
    overridden to use that key and join the chunks into one string.
    """
    def __init__(self):
        super().__init__()
        self.name = "mplug-owl"
        self.model_name = "joehoover/mplug-owl:51a43c9d00dfd92276b2511b509fcb3ad82e221f6a9e5806c54e69803e291d6b"

    def _generate_text(self, image_url):
        """Run mPLUG-Owl on ``image_url`` and return the concatenated output."""
        response = replicate.run(
            self.model_name,
            input={
                "img": image_url,
                "prompt": self.task_description,
            },
        )
        # str.join is linear; repeated += can be quadratic in the chunk count.
        return "".join(response)

## Prompts

These specific prompts were generated using GPT-4. Feel free to generate prompts however you'd like!

In [5]:
easy_texts = [
    "A red apple sitting on a wooden table with sunlight streaming in from a window.",
    "A small white dog is playing in a lush green park, chasing a yellow frisbee.",
    "A bluebird is perched on a blooming cherry blossom branch on a clear spring day.",
]

medium_texts = [
    "A busy city street with neon signs in the evening, people walking with umbrellas, a vendor selling hot dogs, and a red double-decker bus passing by.",
    "A quaint cobblestone alleyway in a European town during a bright day. There are colorful flowers in the window boxes, a bicycle leaning against the wall, and a cat lounging near a doorway.",
    "An astronaut floating in the International Space Station, looking out at Earth through the window, with a space capsule docked in the background.",
]

hard_texts = [
    "A grand medieval banquet hall filled with elegantly dressed lords and ladies feasting on a spread of exotic dishes, a minstrel playing a lute, and a knight narrating his adventures.",
    "A bustling marketplace in an ancient Middle Eastern city. Traders haggling over spices and silks, camels carrying goods, the sun setting behind a mosque with a crescent moon visible.",
    "A complex network of futuristic machines in a high-tech lab. Scientists are observing data on holographic screens, while autonomous robots are assembling nanobots.",
]

impossible_texts = [
    "A panoramic scene of an advanced alien civilization on a distant exoplanet. Interstellar vehicles flying in an indigo sky above towering crystalline structures. Aliens with varying physical features are interacting, engaging in activities like exchanging energy orbs, communicating through light patterns, and tending to exotic, bio-luminescent flora. The planet’s twin moons are visible in the horizon over a glistening alien ocean."
]

levels = ["easy", "medium", "hard", "impossible"]
level_prompts = [easy_texts, medium_texts, hard_texts, impossible_texts]


class Prompt(object):
    """An initial telephone prompt: its ``text`` and difficulty ``level``."""
    def __init__(self, text, level):
        self.text = text
        self.level = level


def get_prompts(level_names=tuple(levels), texts_by_level=tuple(level_prompts)):
    """Return a ``Prompt`` for every text, labelled with its difficulty level.

    The defaults are captured when this cell runs. A later cell that rebinds
    the global ``levels`` (e.g. to ``dataset.distinct("level")``, which is
    sorted alphabetically) therefore cannot mislabel the prompts.

    Args:
        level_names: level label for each group of texts.
        texts_by_level: one list of prompt texts per entry of ``level_names``.
    """
    return [
        Prompt(text, level)
        for level, texts in zip(level_names, texts_by_level)
        for text in texts
    ]

## Telephone Lines

In [6]:
def download_image(image_url, filename, timeout=60):
    """Download ``image_url`` and write the raw bytes to ``filename``.

    Args:
        image_url: URL of the generated image.
        filename: destination path; the file is overwritten.
        timeout: seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: on a non-2xx response, so that an error page
            (e.g. from an expired URL) is never saved as an image.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()
    with open(filename, 'wb') as handler:
        handler.write(response.content)

In [7]:
class TelephoneLine(object):
    """Plays AI Telephone between one text-to-image and one image-to-text model.

    ``t2i`` must provide ``name`` and ``generate_image(text) -> image_url``.
    ``i2t`` must provide ``name`` and ``generate_text(image_url) -> text``.
    Finished games accumulate in ``self.conversations``, keyed by a short hash
    of the line name and prompt text, until ``save_conversations_to_dataset``
    writes them out.
    """
    def __init__(self, t2i, i2t):
        self.t2i = t2i
        self.i2t = i2t
        self.name = f"{t2i.name}_{i2t.name}"
        self.conversations = {}

    def get_conversation_name(self, text):
        """Return a 6-hex-character key derived from this line's name and ``text``."""
        full_name = f"{self.name}{text}"
        return hashlib.md5(full_name.encode()).hexdigest()[:6]

    def play(self, prompt, nturns=10):
        """Play one game of telephone starting from ``prompt``.

        Args:
            prompt: object with ``text`` and ``level`` attributes.
            nturns: number of text -> image -> text round trips.

        Stores ``{"texts", "image_urls", "level"}`` in ``self.conversations``.
        ``texts`` holds the prompt plus one description per turn
        (``nturns + 1`` items). ``image_urls`` holds ``nturns`` items.
        """
        print(f"Connecting {self.t2i.name} <-> {self.i2t.name} with prompt: {prompt.text[:20]}...")
        texts = [prompt.text]
        image_urls = []

        for _ in range(nturns):
            image_url = self.t2i.generate_image(texts[-1])
            image_urls.append(image_url)
            texts.append(self.i2t.generate_text(image_url))

        conversation_name = self.get_conversation_name(prompt.text)
        self.conversations[conversation_name] = {
            "texts": texts,
            "image_urls": image_urls,
            "level": prompt.level,
        }

    def save_conversations_to_dataset(self, dataset, images_dir="telephone_images"):
        """Download every generated image and add one sample per turn to ``dataset``.

        Args:
            dataset: a FiftyOne dataset to add samples to.
            images_dir: directory the images are written to (created if missing).

        Each sample records the image of step ``i`` together with the text that
        produced it (``text_before``) and the description of it (``text_after``).
        Samples are added in one batch after all downloads succeed.
        """
        os.makedirs(images_dir, exist_ok=True)
        samples = []
        for conversation_name, conversation in self.conversations.items():
            texts = conversation["texts"]
            prompt = texts[0]
            level = conversation["level"]

            for i, image_url in enumerate(conversation["image_urls"]):
                filepath = os.path.join(images_dir, f"{conversation_name}_{i}.jpg")
                download_image(image_url, filepath)

                samples.append(
                    fo.Sample(
                        filepath=filepath,
                        conversation_name=conversation_name,
                        prompt=prompt,
                        level=level,
                        t2i_model=self.t2i.name,
                        i2t_model=self.i2t.name,
                        step_number=i,
                        text_before=texts[i],
                        text_after=texts[i + 1],
                    )
                )

        if samples:
            dataset.add_samples(samples)

Set the directory where you'd like images downloaded/stored:

In [None]:
# Keep this in sync with the directory the TelephoneLine saver writes images to.
IMAGES_DIR = "telephone_images"
# exist_ok=True replaces the separate exists() check and avoids a check-then-create race.
os.makedirs(IMAGES_DIR, exist_ok=True)

## Carry out the conversations

In [None]:
# Instantiate every image-to-text (captioning) model under test.
mplug_owl, blip, clip_prefix, mini_gpt4 = MPLUGOwl(), BLIP(), CLIPPrefix(), MiniGPT4()
image2text_models = [mplug_owl, blip, clip_prefix, mini_gpt4]

In [None]:
# Instantiate every text-to-image model under test.
vqgan_clip, sd, dalle2 = VQGANCLIP(), StableDiffusion(), DALLE2()
text2image_models = [sd, dalle2, vqgan_clip]

In [None]:
# Every (T2I, I2T) pairing becomes its own telephone line.
combos = [
    (t2i, i2t)
    for t2i in text2image_models
    for i2t in image2text_models
]
lines = [TelephoneLine(t2i, i2t) for t2i, i2t in combos]

In [None]:
# All Prompt objects (text + difficulty level) that every line will play.
prompts = list(get_prompts())

Create the dataset where we will store the results:

In [None]:
DATASET_NAME = "telephone"

# Reuse the persistent dataset if it already exists (e.g. after a kernel
# restart) instead of raising on the duplicate name. To start from scratch,
# call fo.delete_dataset(DATASET_NAME) first.
if fo.dataset_exists(DATASET_NAME):
    dataset = fo.load_dataset(DATASET_NAME)
else:
    dataset = fo.Dataset(name=DATASET_NAME, persistent=True)

telephone_fields = {
    "conversation_name": fo.StringField,
    "prompt": fo.StringField,
    "level": fo.StringField,
    "t2i_model": fo.StringField,
    "i2t_model": fo.StringField,
    "step_number": fo.IntField,
    "text_before": fo.StringField,
    "text_after": fo.StringField,
}
for field_name, field_type in telephone_fields.items():
    if not dataset.has_sample_field(field_name):
        dataset.add_sample_field(field_name, field_type)

Play all of the games:

In [None]:
# Progress is printed directly (tqdm was used here but never imported).
for line_number, line in enumerate(lines, start=1):
    print(f"=== Line {line_number}/{len(lines)}: {line.name} ===")
    for prompt in prompts:
        line.play(prompt, nturns=10)
    # Save after each line so its images are downloaded soon after generation.
    line.save_conversations_to_dataset(dataset)

Check out the results in the FiftyOne App:

In [None]:
# auto=False keeps the App from opening inside the notebook;
# open http://localhost:5151 in a new browser tab instead.
session = fo.launch_app(
    dataset,
    auto=False,
)

Use the dynamic groups functionality in the FiftyOne App: click on the splitting icon in the menu bar to group images by conversation, select `conversation_name` from the dropdown, then toggle the selector to `ordered` and select `step_number`.

## Analysis

In [8]:
import numpy as np
from scipy.spatial.distance import cosine as cosine_distance

Create a short hash key for each prompt, embed each prompt, and store the embeddings in a dictionary keyed by that hash:

In [None]:
def hash_prompt(prompt):
    """Return the first 6 hex characters of the MD5 digest of ``prompt``."""
    digest = hashlib.md5(prompt.encode()).hexdigest()
    return digest[:6]

# Text is embedded with ImageBind (hosted on Replicate); any text-embedding
# model would work. The generated images could be embedded as well with
# corresponding changes to the code below.
MODEL_NAME = "daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304"
def embed_text(text):
    """Return the ImageBind embedding of ``text`` as a NumPy array."""
    payload = {"text_input": text, "modality": "text"}
    embedding = replicate.run(MODEL_NAME, input=payload)
    return np.array(embedding)

### Embed initial prompts
# Deliberately does NOT rebind the notebook-level `prompts` (the Prompt objects
# played above); overwriting it with strings would break re-running that cell.
prompt_embeddings = {}
if not dataset.has_sample_field("prompt_hash"):
    dataset.add_sample_field("prompt_hash", fo.StringField)

for prompt_text in dataset.distinct("prompt"):
    prompt_hash = hash_prompt(prompt_text)
    prompt_embeddings[prompt_hash] = embed_text(prompt_text)
    # Tag every sample of this prompt with its hash so it can find the embedding.
    (
        dataset.match(F("prompt") == prompt_text)
        .set_field("prompt_hash", prompt_hash)
        .save("prompt_hash")
    )

Compute the cosine distance between each generated text description and the embedding of its original prompt:

In [None]:
if not dataset.has_sample_field("text_after_dist"):
    dataset.add_sample_field("text_after_dist", fo.FloatField)

# One dynamic group per conversation; every sample in a group shares a prompt.
conversation_groups = dataset.group_by("conversation_name")
for conversation in conversation_groups.iter_dynamic_groups(progress=True):
    prompt_embedding = prompt_embeddings[conversation.first().prompt_hash]

    ordered_samples = conversation.sort_by("step_number")
    for sample in ordered_samples.iter_samples(autosave=True):
        text_embedding = embed_text(sample.text_after)
        sample["text_embedding"] = text_embedding
        # Cast to a plain float for storage in the FloatField.
        sample.text_after_dist = float(
            cosine_distance(prompt_embedding, text_embedding)
        )

Aggregate the results by level of prompt difficulty, T2I model, and I2T model:

In [None]:
### Aggregate performance by level
levels = dataset.distinct("level")
t2i_models = dataset.distinct("t2i_model")
i2t_models = dataset.distinct("i2t_model")
pairs = [(t2i, i2t) for t2i in t2i_models for i2t in i2t_models]
steps = sorted(dataset.distinct("step_number"))

# pair_dists[level][(t2i, i2t)] == [0.0, d_0, d_1, ...]. The leading 0.0 stands
# for the prompt compared with itself; d_k is the mean `text_after_dist` at
# step k. Only text distances are aggregated because no image distance field
# is computed above.
pair_dists = {}

for level in levels:
    level_view = dataset.match(F("level") == level)
    pair_level_dists = {}
    for t2i_model, i2t_model in pairs:
        model_view = level_view.match(
            (F("t2i_model") == t2i_model) & (F("i2t_model") == i2t_model)
        )
        step_dists = [0.]
        for step in steps:
            step_view = model_view.match(F("step_number") == step)
            step_dists.append(step_view.mean("text_after_dist"))
        pair_level_dists[(t2i_model, i2t_model)] = step_dists
    pair_dists[level] = pair_level_dists

### Plotting the results

In [9]:
import matplotlib.pyplot as plt
import numpy as np

Set the style for each curve in the plot based on the T2I model and I2T model, so that we can easily distinguish between them:

In [None]:
# Line color encodes the T2I model.
t2i_colors = {
    "dalle-2": "r",
    "stable-diffusion": "b",
    "vqgan-clip": "y",
}

# Marker shape encodes the I2T model.
i2t_markers = {
    "clip-prefix": "+",
    "minigpt-4": "o",
    "blip": "v",
    "mplug-owl": "s",
}


def get_style(pair):
    """Matplotlib format string for a (t2i, i2t) pair: solid line, color, marker."""
    t2i_name, i2t_name = pair
    color = t2i_colors[t2i_name]
    marker = i2t_markers[i2t_name]
    return "-" + color + marker


def format_pair(pair):
    """Legend label ``t2i<->i2t`` using a LaTeX double arrow."""
    t2i_name, i2t_name = pair
    return t2i_name + r'$\leftrightarrow$' + i2t_name

Function that plots results for each difficulty level:

In [None]:
def plot_level_results(level, model_pairs=None, dists_by_level=None, save=True):
    """Plot mean cosine distance vs. step for every (T2I, I2T) pair at ``level``.

    Args:
        level: difficulty level, used as a key into the distance table.
        model_pairs: (t2i, i2t) pairs to plot; defaults to the global ``pairs``.
        dists_by_level: ``{level: {pair: [dist, ...]}}``; defaults to the
            global ``pair_dists``.
        save: if True, also write the figure to ``"{level}.png"``.
    """
    model_pairs = pairs if model_pairs is None else model_pairs
    dists_by_level = pair_dists if dists_by_level is None else dists_by_level

    fig, ax = plt.subplots(figsize=(20, 10))
    for pair in model_pairs:
        dists = dists_by_level[level][pair]
        # x starts at 1; the first point is the prompt itself (distance 0).
        x = np.arange(len(dists)) + 1
        ax.plot(x, dists, get_style(pair), label=format_pair(pair))
    # Axis labels are set once rather than on every loop iteration.
    ax.set_xlabel("Step Number")
    ax.set_ylabel("Cosine Distance")
    ax.set_title(f"AI Telephone Results: {level.capitalize()} Prompts", fontsize=20)
    ax.legend(frameon=False)
    if save:
        fig.savefig(f"{level}.png")

Iterate over each difficulty level and plot the results:

In [None]:
# Draw one figure per entry of `levels`; plot_level_results also saves each one as "<level>.png".
for level in levels:
    plot_level_results(level)