{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# AI Telephone\n", "\n", "This notebook walks you through playing a game of multimodal AI Telephone!\n", "\n", "Here’s how a game of AI Telephone works:\n", "\n", "1. Each “game” pairs a text-to-image (T2I) model with an image-to-text (I2T) model.\n", "2. Given an initial prompt, we use the T2I model to generate an image.\n", "3. We pass this image to the I2T model to generate a description.\n", "4. We feed that description back into the T2I model as the next prompt, repeating steps 2 and 3 a fixed number of times `n` (in our case `n=10`).\n", "5. Finally, we quantify how far each description has drifted from the original prompt." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "To run this code, you will need to install the [FiftyOne open source library](https://github.com/voxel51/fiftyone) for dataset curation, the [OpenAI Python library](https://github.com/openai/openai-python), and the [Replicate Python client](https://replicate.com/docs/get-started/python). The analysis at the end also uses SciPy and Matplotlib, and the main loop uses `tqdm` to display progress." 
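] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Both clients read their credentials from environment variables: the Replicate client looks for `REPLICATE_API_TOKEN`, and the pre-1.0 OpenAI client looks for `OPENAI_API_KEY`. The snippet below is a minimal sketch, not part of the original workflow, that reports any missing key before the games start; the helper `missing_api_keys` is a hypothetical name of our own.

```python
import os

# Environment variables read by the Replicate client and the pre-1.0 OpenAI client
REQUIRED_KEYS = ('REPLICATE_API_TOKEN', 'OPENAI_API_KEY')


def missing_api_keys(env=None):
    '''Return the names of required keys that are unset or empty in env.'''
    env = os.environ if env is None else env
    return [key for key in REQUIRED_KEYS if not env.get(key)]


missing = missing_api_keys()
if missing:
    print('Set these environment variables before playing: ' + ', '.join(missing))
```

Checking up front avoids discovering a missing key only after a telephone line has already made paid calls to the other service."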
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# The DALL-E 2 wrapper below uses the pre-1.0 `openai.Image` API, hence the version pin\n", "!pip install fiftyone 'openai<1.0' replicate scipy matplotlib tqdm" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We will import all of the necessary modules:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import hashlib\n", "import os\n", "import requests\n", "\n", "import openai\n", "import replicate\n", "from tqdm.auto import tqdm\n", "\n", "import fiftyone as fo\n", "from fiftyone import ViewField as F" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Text-to-Image Models" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "First we define the base class, which calls a model hosted on Replicate:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "class Text2Image(object):\n", "    \"\"\"Wrapper for a Text2Image model.\"\"\"\n", "    def __init__(self):\n", "        self.name = None\n", "        self.model_name = None\n", "\n", "    def generate_image(self, text):\n", "        response = replicate.run(self.model_name, input={\"prompt\": text})\n", "        # Some models return a list of outputs; keep the first image URL\n", "        if isinstance(response, list):\n", "            response = response[0]\n", "        return response" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then we create a class for each model:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class StableDiffusion(Text2Image):\n", "    \"\"\"Wrapper for a StableDiffusion model.\"\"\"\n", "    def __init__(self):\n", "        self.name = \"stable-diffusion\"\n", "        self.model_name = \"stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478\"\n", "\n", "\n", "class VQGANCLIP(Text2Image):\n", "    \"\"\"Wrapper for a VQGAN-CLIP model.\"\"\"\n", "    def __init__(self):\n", "        self.name = \"vqgan-clip\"\n", "        self.model_name = \"mehdidc/feed_forward_vqgan_clip:28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8\"\n", "\n", "\n", "class 
DALLE2(Text2Image):\n", "    \"\"\"Wrapper for a DALL-E 2 model, which is served by the OpenAI API rather than Replicate.\"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.name = \"dalle-2\"\n", "\n", "    def generate_image(self, text):\n", "        # Pre-1.0 OpenAI API: generate one 512x512 image and return its URL\n", "        response = openai.Image.create(\n", "            prompt=text,\n", "            n=1,\n", "            size=\"512x512\"\n", "        )\n", "        return response['data'][0]['url']" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Image-to-Text Models" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Once again, we define the base class:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class Image2Text(object):\n", "    \"\"\"Wrapper for an Image2Text model.\"\"\"\n", "    def __init__(self):\n", "        self.name = None\n", "        self.model_name = None\n", "        self.task_description = \"Write a detailed description of this image.\"\n", "\n", "    def _clean_response(self, response):\n", "        # Lowercase the output and strip boilerplate lead-ins such as \"caption: \"\n", "        response = response.lower()\n", "        phrases = [\"caption: \", \"the image shows \", \"the image features\"]\n", "        for phrase in phrases:\n", "            if phrase in response:\n", "                # Keep everything after the first occurrence of the phrase\n", "                response = response.split(phrase, 1)[1].strip()\n", "        return response\n", "\n", "    def _generate_text(self, image_url):\n", "        response = replicate.run(\n", "            self.model_name,\n", "            input={\n", "                \"image\": image_url,\n", "                \"prompt\": self.task_description,\n", "            }\n", "        )\n", "        return response\n", "\n", "    def generate_text(self, image_url):\n", "        response = self._generate_text(image_url)\n", "        response = self._clean_response(response)\n", "        return response" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then we create a class for each model:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class BLIP(Image2Text):\n", "    \"\"\"Wrapper for a BLIP model.\"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.name = \"blip\"\n", "        self.model_name = \"salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746\"\n", "\n", "\n", "class CLIPPrefix(Image2Text):\n", "    \"\"\"Wrapper for a CLIPPrefixCaptioning model.\"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.name = \"clip-prefix\"\n", "        self.model_name = \"rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8\"\n", "\n", "\n", "class MiniGPT4(Image2Text):\n", "    \"\"\"Wrapper for a MiniGPT-4 model.\"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.name = \"minigpt-4\"\n", "        self.model_name = \"daanelson/minigpt-4:b96a2f33cc8e4b0aa23eacfce731b9c41a7d9466d9ed4e167375587b54db9423\"\n", "\n", "\n", "class MPLUGOwl(Image2Text):\n", "    \"\"\"Wrapper for an mPLUG-Owl model.\"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.name = \"mplug-owl\"\n", "        self.model_name = \"joehoover/mplug-owl:51a43c9d00dfd92276b2511b509fcb3ad82e221f6a9e5806c54e69803e291d6b\"\n", "\n", "    def _generate_text(self, image_url):\n", "        # This model takes the image under the \"img\" key and returns its\n", "        # output in chunks, which we join into a single string\n", "        response = replicate.run(\n", "            self.model_name,\n", "            input={\n", "                \"img\": image_url,\n", "                \"prompt\": self.task_description,\n", "            }\n", "        )\n", "        return \"\".join(response)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Prompts" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "These prompts were generated with GPT-4 and grouped into four difficulty levels, from easy to impossible. Feel free to generate prompts however you'd like!" 
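] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Every prompt below is played for `n` turns on every telephone line, and each turn makes two model calls, so it can help to dry-run the game loop offline first. The snippet below is a minimal sketch of that loop under stated assumptions: `StubT2I` and `StubI2T` are hypothetical stand-ins for real models (the stub image URL simply encodes the text, and the stub captioner drops one word per turn), and `play` is an illustrative function, not part of the notebook or any library.

```python
class StubT2I:
    '''Hypothetical text-to-image stand-in: the image URL just encodes the text.'''
    name = 'stub-t2i'

    def generate_image(self, text):
        return 'stub://' + text.replace(' ', '_')


class StubI2T:
    '''Hypothetical image-to-text stand-in that loses the last word on each turn.'''
    name = 'stub-i2t'

    def generate_text(self, image_url):
        words = image_url[len('stub://'):].split('_')
        return ' '.join(words[:-1]) if len(words) > 1 else words[0]


def play(t2i, i2t, prompt_text, nturns=10):
    '''Alternate T2I and I2T calls, feeding each description back in as the next prompt.'''
    texts = [prompt_text]
    image_urls = []
    for _ in range(nturns):
        image_url = t2i.generate_image(texts[-1])
        texts.append(i2t.generate_text(image_url))
        image_urls.append(image_url)
    return texts, image_urls


texts, image_urls = play(StubT2I(), StubI2T(), 'a red apple on a wooden table', nturns=3)
for step, text in enumerate(texts):
    print(step, text)
```

The returned `texts` list has `nturns + 1` entries, so turn `i` maps `texts[i]` to `texts[i + 1]`; this is the same pairing that `save_conversations_to_dataset` stores as `text_before` and `text_after`."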
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "easy_texts = [\n", " \"A red apple sitting on a wooden table with sunlight streaming in from a window.\",\n", " \"A small white dog is playing in a lush green park, chasing a yellow frisbee.\",\n", " \"A bluebird is perched on a blooming cherry blossom branch on a clear spring day.\",\n", "]\n", "\n", "medium_texts = [\n", " \"A busy city street with neon signs in the evening, people walking with umbrellas, a vendor selling hot dogs, and a red double-decker bus passing by.\",\n", " \"A quaint cobblestone alleyway in a European town during a bright day. There are colorful flowers in the window boxes, a bicycle leaning against the wall, and a cat lounging near a doorway.\",\n", " \"An astronaut floating in the International Space Station, looking out at Earth through the window, with a space capsule docked in the background.\",\n", "]\n", "\n", "hard_texts = [\n", " \"A grand medieval banquet hall filled with elegantly dressed lords and ladies feasting on a spread of exotic dishes, a minstrel playing a lute, and a knight narrating his adventures.\",\n", " \"A bustling marketplace in an ancient Middle Eastern city. Traders haggling over spices and silks, camels carrying goods, the sun setting behind a mosque with a crescent moon visible.\",\n", " \"A complex network of futuristic machines in a high-tech lab. Scientists are observing data on holographic screens, while autonomous robots are assembling nanobots.\",\n", "]\n", "\n", "impossible_texts = [\n", " \"A panoramic scene of an advanced alien civilization on a distant exoplanet. Interstellar vehicles flying in an indigo sky above towering crystalline structures. Aliens with varying physical features are interacting, engaging in activities like exchanging energy orbs, communicating through light patterns, and tending to exotic, bio-luminescent flora. 
The planet’s twin moons are visible in the horizon over a glistening alien ocean.\"\n", "]\n", "\n", "levels = [\"easy\", \"medium\", \"hard\", \"impossible\"]\n", "level_prompts = [easy_texts, medium_texts, hard_texts, impossible_texts]\n", "\n", "\n", "class Prompt(object):\n", "    \"\"\"A prompt's text together with its difficulty level.\"\"\"\n", "    def __init__(self, text, level):\n", "        self.text = text\n", "        self.level = level\n", "\n", "\n", "def get_prompts():\n", "    \"\"\"Flatten the prompt lists into Prompt objects tagged with their level.\"\"\"\n", "    prompts = []\n", "    for level, texts in zip(levels, level_prompts):\n", "        for text in texts:\n", "            prompts.append(Prompt(text, level))\n", "    return prompts" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Telephone Lines\n", "\n", "A telephone line connects one T2I model to one I2T model. First, we define a helper that downloads each generated image to disk, so that FiftyOne can load it from a local file path:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def download_image(image_url, filename):\n", "    response = requests.get(image_url, timeout=60)\n", "    response.raise_for_status()  # fail loudly instead of saving an error page\n", "    with open(filename, 'wb') as handler:\n", "        handler.write(response.content)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class TelephoneLine(object):\n", "    \"\"\"Class for playing telephone with AI.\"\"\"\n", "    def __init__(self, t2i, i2t):\n", "        self.t2i = t2i\n", "        self.i2t = i2t\n", "        self.name = f\"{t2i.name}_{i2t.name}\"\n", "        self.conversations = {}\n", "\n", "    def get_conversation_name(self, text):\n", "        \"\"\"Return a short, deterministic ID for this line and prompt.\"\"\"\n", "        full_name = f\"{self.name}{text}\"\n", "        hashed_name = hashlib.md5(full_name.encode())\n", "        return hashed_name.hexdigest()[:6]\n", "\n", "    def play(self, prompt, nturns=10):\n", "        \"\"\"Play a game of telephone.\"\"\"\n", "        print(f\"Connecting {self.t2i.name} <-> {self.i2t.name} with prompt: {prompt.text[:20]}...\")\n", "        texts = [prompt.text]\n", "        image_urls = []\n", "\n", "        for _ in range(nturns):\n", "            # Each description becomes the prompt for the next image\n", "            image_url = self.t2i.generate_image(texts[-1])\n", "            text = self.i2t.generate_text(image_url)\n", "            texts.append(text)\n", "            image_urls.append(image_url)\n", "\n", "        conversation_name = self.get_conversation_name(prompt.text)\n", "        self.conversations[conversation_name] = {\n", "            \"texts\": texts,\n", "            \"image_urls\": image_urls,\n", "            \"level\": prompt.level\n", "        }\n", "\n", "    def 
save_conversations_to_dataset(self, dataset):\n", "        \"\"\"Download each image and add one sample per turn to the dataset.\"\"\"\n", "        for conversation_name, conversation in self.conversations.items():\n", "            prompt = conversation[\"texts\"][0]\n", "            level = conversation[\"level\"]\n", "            image_urls = conversation[\"image_urls\"]\n", "            texts = conversation[\"texts\"]\n", "\n", "            for i, image_url in enumerate(image_urls):\n", "                filename = f\"{conversation_name}_{i}.jpg\"\n", "                filepath = os.path.join(IMAGES_DIR, filename)\n", "                download_image(image_url, filepath)\n", "\n", "                # At turn i, texts[i] produced this image, which produced texts[i + 1]\n", "                sample = fo.Sample(\n", "                    filepath=filepath,\n", "                    conversation_name=conversation_name,\n", "                    prompt=prompt,\n", "                    level=level,\n", "                    t2i_model=self.t2i.name,\n", "                    i2t_model=self.i2t.name,\n", "                    step_number=i,\n", "                    text_before=texts[i],\n", "                    text_after=texts[i + 1],\n", "                )\n", "                dataset.add_sample(sample)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Set the directory where the downloaded images will be stored:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IMAGES_DIR = \"telephone_images\"\n", "os.makedirs(IMAGES_DIR, exist_ok=True)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Carry out the conversations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Image2Text models\n", "mplug_owl = MPLUGOwl()\n", "blip = BLIP()\n", "clip_prefix = CLIPPrefix()\n", "mini_gpt4 = MiniGPT4()\n", "image2text_models = [mplug_owl, blip, clip_prefix, mini_gpt4]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## Text2Image models\n", "vqgan_clip = VQGANCLIP()\n", "sd = StableDiffusion()\n", "dalle2 = DALLE2()\n", "text2image_models = [sd, dalle2, vqgan_clip]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Pair every T2I model with every I2T model (3 x 4 = 12 telephone lines)\n", "combos = [(t2i, i2t) for t2i in text2image_models for i2t in image2text_models]\n", "lines = [TelephoneLine(*combo) for combo in combos]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prompts = get_prompts()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Create the dataset where we will store the results:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset = fo.Dataset(name=\"telephone\", persistent=True)\n", "dataset.add_sample_field(\"conversation_name\", fo.StringField)\n", "dataset.add_sample_field(\"prompt\", fo.StringField)\n", "dataset.add_sample_field(\"level\", fo.StringField)\n", "dataset.add_sample_field(\"t2i_model\", fo.StringField)\n", "dataset.add_sample_field(\"i2t_model\", fo.StringField)\n", "dataset.add_sample_field(\"step_number\", fo.IntField)\n", "dataset.add_sample_field(\"text_before\", fo.StringField)\n", "dataset.add_sample_field(\"text_after\", fo.StringField)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Play all of the games. With 12 lines, 10 prompts, 10 turns per game, and two model calls per turn, this makes 2,400 API calls, so expect it to take a while and to incur API costs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for line in tqdm(lines):\n", "    for prompt in prompts:\n", "        line.play(prompt, nturns=10)\n", "    # Save once per line, after all of its games have been played,\n", "    # so that no conversation is added to the dataset twice\n", "    line.save_conversations_to_dataset(dataset)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Check out the results in the FiftyOne App:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## auto=False to prevent the app from opening. 
Instead, open the App in a new browser tab at http://localhost:5151\n", "session = fo.launch_app(dataset, auto=False)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "To step through each conversation in order, use the dynamic groups feature of the FiftyOne App: click the splitting icon in the menu bar, select `conversation_name` from the dropdown to group images by conversation, then toggle the selector to `ordered` and select `step_number`." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Analysis" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from scipy.spatial.distance import cosine as cosine_distance" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Embed each original prompt once, storing the embeddings in a dictionary keyed by a short hash of the prompt:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def hash_prompt(prompt):\n", "    return hashlib.md5(prompt.encode()).hexdigest()[:6]\n", "\n", "## Use ImageBind to embed text. You can use any text embedding model here.\n", "## You can also embed the generated images if you modify the code below accordingly.\n", "MODEL_NAME = \"daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304\"\n", "\n", "def embed_text(text):\n", "    response = replicate.run(\n", "        MODEL_NAME,\n", "        input={\n", "            \"text_input\": text,\n", "            \"modality\": \"text\"\n", "        }\n", "    )\n", "    return np.array(response)\n", "\n", "### Embed initial prompts\n", "prompt_embeddings = {}\n", "dataset.add_sample_field(\"prompt_hash\", fo.StringField)\n", "prompt_groups = dataset.group_by(\"prompt\")\n", "for pg in prompt_groups.iter_dynamic_groups():\n", "    prompt = pg.first().prompt\n", "    prompt_hash = hash_prompt(prompt)\n", "    prompt_embeddings[prompt_hash] = embed_text(prompt)\n", "    # Tag every sample in this group with the hash of its prompt\n", "    view = pg.set_field(\"prompt_hash\", prompt_hash)\n", "    view.save(\"prompt_hash\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Compute the cosine distance between each generated description and the original prompt:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dataset.add_sample_field(\"text_after_dist\", fo.FloatField)\n", "\n", "conversation_groups = dataset.group_by(\"conversation_name\")\n", "for cg in conversation_groups.iter_dynamic_groups(progress=True):\n", "    prompt_hash = cg.first().prompt_hash\n", "    prompt_embedding = prompt_embeddings[prompt_hash]\n", "\n", "    ordered_samples = cg.sort_by(\"step_number\")\n", "    for sample in ordered_samples.iter_samples(autosave=True):\n", "        text_embedding = embed_text(sample.text_after)\n", "        sample[\"text_embedding\"] = text_embedding\n", "        sample.text_after_dist = float(cosine_distance(\n", "            prompt_embedding,\n", "            text_embedding\n", "        ))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Aggregate the mean distance at each step by prompt difficulty level, T2I model, and I2T model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "### 
Aggregate performance by level\n", "levels = dataset.distinct(\"level\")\n", "t2i_models = dataset.distinct(\"t2i_model\")\n", "i2t_models = dataset.distinct(\"i2t_model\")\n", "pairs = [(t2i, i2t) for t2i in t2i_models for i2t in i2t_models]\n", "steps = sorted(dataset.distinct(\"step_number\"))\n", "\n", "pair_dists = {}\n", "\n", "for level in levels:\n", "    pair_level_dists = {}\n", "    level_view = dataset.match(F(\"level\") == level)\n", "    for pair in pairs:\n", "        t2i_model, i2t_model = pair\n", "        model_view = level_view.match(\n", "            (F(\"t2i_model\") == t2i_model) & (F(\"i2t_model\") == i2t_model)\n", "        )\n", "        # The original prompt is at distance 0 from itself\n", "        step_dists = [0.]\n", "        for step in steps:\n", "            step_view = model_view.match(F(\"step_number\") == step)\n", "            step_dists.append(step_view.mean(\"text_after_dist\"))\n", "        pair_level_dists[pair] = step_dists\n", "    pair_dists[level] = pair_level_dists" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Plotting the results" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Style each curve by its model pair, using color for the T2I model and marker shape for the I2T model, so that the curves are easy to tell apart:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "## color by t2i model\n", "t2i_colors = {\n", "    \"dalle-2\": \"r\",\n", "    \"stable-diffusion\": \"b\",\n", "    \"vqgan-clip\": \"y\"\n", "}\n", "\n", "## marker by i2t model\n", "i2t_markers = {\n", "    \"clip-prefix\": \"+\",\n", "    \"minigpt-4\": \"o\",\n", "    \"blip\": \"v\",\n", "    \"mplug-owl\": \"s\"\n", "}\n", "\n", "\n", "def get_style(pair):\n", "    \"\"\"Return a matplotlib format string: solid line, color, marker.\"\"\"\n", "    t2i, i2t = pair\n", "    return f\"-{t2i_colors[t2i]}{i2t_markers[i2t]}\"\n", "\n", "\n", "def format_pair(pair):\n", "    t2i, i2t = pair\n", "    arrow = r'$\\leftrightarrow$'\n", "    return 
f\"{t2i}{arrow}{i2t}\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Define a function that plots the results for a single difficulty level:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_level_results(level):\n", "    plt.figure(figsize=(20, 10))\n", "    for pair in pairs:\n", "        dists = pair_dists[level][pair]\n", "        # x = number of turns completed; x = 0 is the original prompt\n", "        x = np.arange(len(dists))\n", "        plt.plot(x, dists, get_style(pair), label=format_pair(pair))\n", "    plt.xlabel(\"Number of Turns\")\n", "    plt.ylabel(\"Cosine Distance\")\n", "    plt.title(f\"AI Telephone Results: {level.capitalize()} Prompts\", fontsize=20)\n", "    plt.legend(frameon=False)\n", "    plt.savefig(f\"{level}.png\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Plot the results for each difficulty level:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "for level in levels:\n", "    plot_level_results(level)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }