How to Build Your Own Image Dataset

We live in a far from perfect world, and often the data we need for training our machine learning models isn’t available in an existing public dataset. As a result, it’s important for a machine learning developer to know how to construct their own dataset in such situations. In today’s post, I will show you how to use Python and Selenium WebDriver to scrape images from Google Images. In addition, I’ll demonstrate how you can upload these images to an Amazon Web Services (AWS) S3 storage bucket instead of having them eat up your local machine’s storage space.

Module Imports

from typing import List
from selenium import webdriver
from time import sleep
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.common.by import By
import requests
import io
from PIL import Image
import os
import boto3
import shutil
import threading

  • selenium – this library lets us launch a Google Chrome window, open a specific URL, and simulate user actions such as clicks and scrolls
  • requests – with this library, we can make HTTP GET requests; here we use it to download the raw content behind each image URL
  • io – this library lets us wrap raw bytes in an in-memory file object (BytesIO) so they can be opened like a file
  • PIL – this library lets us open, convert, and save images as needed
  • os – through this library, we can work with the file system, for example building file paths and creating directories
  • boto3 – this library gives us access to AWS services such as S3
  • shutil – similar to the os library, this library performs file-system operations, but it works on whole directory trees rather than individual files
  • threading – with this library, we can run multiple threads concurrently to speed up I/O-bound work such as downloading images.
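
Before we dive into the individual methods, note that every snippet in the rest of this post is a method of a single ImageScraper class (the class declaration itself never appears in the snippets). When assembling the code, indent each method one level inside a skeleton like the following:

class ImageScraper:
    """Scrapes Google Images for a list of queries and saves the results
    locally or uploads them to an AWS S3 bucket."""

    # __init__, scroll_down_body, get_image_urls, save_image,
    # fetch_single_query, fetch_query_batch, and fetch_all_queries
    # (defined in the sections below) all go here.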

Class Initialization

def __init__(self, s3: boto3.client, queries: List[str], max_images: int, target_path: str, num_threads: int, batch_size: int):
    self.url = 'https://www.images.google.com/'
    self.s3 = s3
    self.queries = queries
    self.max_images = max_images
    self.target_path = target_path
    self.num_threads = num_threads
    self.batch_size = batch_size

    if self.num_threads * batch_size != len(self.queries):
        raise ValueError(
            "The number of threads and batch size you specified do not multiply to the length of the query list")

For our image scraper class to execute the scraping task efficiently, the user needs to provide it with sufficient initial information through the class constructor. Let’s take a look at these initial values and discuss what each of them helps with.

  • url – by default, this is the Google Images URL that the Chrome driver opens before searching for images. Other URLs will not work here, since the scraper’s element selectors are specific to Google Images.
  • s3 – this is the boto3 S3 client that the ImageScraper class uses to upload the images to your bucket.
  • queries – this is the list of labels that the ImageScraper class will find images for.
  • max_images – this is the maximum number of images we want per label. Once the program has run, each label’s directory will contain at most max_images images.
  • target_path – this is the local path under which all images will be stored.
  • num_threads – this is the number of threads the ImageScraper class will run when scraping.
  • batch_size – this is the number of queries each thread will scrape.

The exception towards the end of the function is only thrown if the number of threads and the batch size per thread do not multiply out to the number of queries there are. For example, if there are 40 queries, 4 threads and a batch size of 10 are acceptable parameters whereas 5 threads and a batch size of 9 are not.
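
As a quick illustration of how these pieces fit together, here is a minimal sketch of constructing the scraper. The query list, local path, and thread settings below are hypothetical placeholders, and it assumes your AWS credentials are already configured for boto3.

s3_client = boto3.client('s3')  # assumes AWS credentials are configured locally

scraper = ImageScraper(
    s3=s3_client,
    queries=['golden retriever', 'siamese cat', 'red panda', 'koala'],  # placeholder labels
    max_images=100,
    target_path='./images',
    num_threads=2,
    batch_size=2,  # 2 threads * 2 queries per thread = 4 queries, so the check passes
)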

DOM Body Scrolling

def scroll_down_body(self, browser: webdriver.Chrome):
    browser.execute_script(
        "window.scrollTo(0, document.body.scrollHeight);")
    sleep(1)

The ImageScraper class uses this utility function to scroll down the page once all the images in the browser window have been inspected so that it is able to click on the Load More button. The way Google works is by loading a set of images to fill the full browser window and only showing more images once the user scrolls down. As a result, the scrolling functionality is necessary because without it, the image scraper wouldn’t be able to scrape images past the window size of the browser.
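
If a single scroll is not enough to trigger lazy loading on a slow connection, a variant like the following sketch (not part of the original class; purely illustrative) keeps scrolling until the page height stops growing:

def scroll_until_stable(self, browser: webdriver.Chrome, max_scrolls: int = 10):
    # Illustrative helper: scroll repeatedly until the document height stops changing.
    last_height = browser.execute_script("return document.body.scrollHeight")
    for _ in range(max_scrolls):
        browser.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        sleep(1)
        new_height = browser.execute_script("return document.body.scrollHeight")
        if new_height == last_height:
            break
        last_height = new_height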

Fetching Image URLs

def get_image_urls(self, query: str, browser: webdriver.Chrome):
    # locate the search input field on the Google Images page
    delay = 2
    search_box = WebDriverWait(browser, delay).until(
        EC.presence_of_element_located((By.CSS_SELECTOR, 'input.gLFyf')))
    
    # enter query in search box and trigger form submission
    search_box.send_keys(query)
    search_box.send_keys(Keys.RETURN)

    image_container = WebDriverWait(browser, delay).until(
        EC.presence_of_element_located((By.CSS_SELECTOR, 'div.mJxzWe')))

    # get the images

    images_processed = 0
    image_urls = set()
    img_start_idx = 0
    old_num_thumbnails = 0

    while images_processed < self.max_images:
        self.scroll_down_body(browser)

        thumbnail_list = WebDriverWait(browser, delay).until(
            EC.presence_of_all_elements_located((By.CSS_SELECTOR, 'img.Q4LuWd')))
        num_thumbnails = len(thumbnail_list)

        if num_thumbnails == old_num_thumbnails:
            break

        old_num_thumbnails = num_thumbnails

        for thumbnail in thumbnail_list[img_start_idx: min(num_thumbnails, self.max_images)]:
            try:
                thumbnail.click()
                sleep(2)
                image_url = thumbnail.get_attribute('src')
                if image_url and 'http' in image_url:
                    image_urls.add(image_url)
            except Exception:
                continue

        images_processed = len(image_urls)

        sleep(1)

        img_start_idx = min(num_thumbnails, self.max_images)
        
        # click the "Load More" button at the end of the results page, if it is present
        load_more_btn = browser.find_elements(By.CSS_SELECTOR, '.mye4qd')

        if load_more_btn:
            browser.execute_script(
                "document.querySelector('.mye4qd').click();")
    
    print(f'The total number of image links we found is {images_processed}')
    
    return image_urls

The first thing we’d like to do once we open the Google Images site is find the search input field. If you right-click the input field and inspect its HTML in the browser’s developer tools, you’ll notice that it has a CSS class of “gLFyf”. That CSS class is what we use to tell Selenium how to find the input field. The lookup is wrapped in WebDriverWait() to instruct Selenium to wait until the HTML element is present in the DOM. This is especially useful when working with dynamically loaded JavaScript pages, where elements are inserted onto the page in response to user actions and are not present by default.

Now that we’ve found the search input field, we can feed it the query using the send_keys() function. We can also use this function again to trigger the form submission by passing in the Keys.RETURN argument. Once this is done, the browser navigates to the Google Image results page for that specific query and the actual image scraping task can finally begin.

As mentioned earlier, a limited number of images will load in the browser window and we must make use of scrolling to access the next set of images. Hence, we’ll use the images_processed variable to keep track of how many images we’ve seen so far and we’ll stop scraping once it reaches the specified max image limit. The image URLs found get stored in the set image_urls for later usage.

The next step is to run a loop that fetches the image URLs within each window frame. To do that, we’ll first get access to the HTML elements responsible for holding the image thumbnails. This can be done using the CSS class selector “Q4LuWd” that’s common across all the thumbnail elements. To make sure that we’ve actually fetched new thumbnails in a specific loop iteration, we check that the new thumbnail count is not equal to the old thumbnail count of the previous iteration.

Once we’ve collected all the thumbnails in the window into thumbnail_list, we loop through them, grab each one’s image URL, and add it to the image_urls set. We then update images_processed by assigning to it the length of the image_urls set.

The last step in this cycle is making sure that when we reach the end of a results page, we check if there’s a “Load More” button that extends the results page by adding more images. If we don’t check for this, then the scraper will stop scraping once it reaches the end of one result page, which may be undesirable depending on how many images we’re interested in. Hence, we search for this button using its CSS class selector “mye4qd” and click it if it is present.

This full loop continues iterating until either of these things happens:

  • We’ve exhausted the max_images limit.
  • We can’t find any more images for the specified query.

After the loop terminates, all that is left to do is return the image URLs, which other functions can then use to actually download the images, either saving them locally or uploading them to a cloud storage location like an AWS S3 bucket.

Saving the Images

def save_image(self, dir: str, url: str, idx: int, query: str):
    try:
        img_content = requests.get(url).content
    except Exception as e:
        print(f'ERROR - Could not download {url} - {e}')
        return

    try:
        img_file = io.BytesIO(img_content)
        img = Image.open(img_file).convert('RGB')
        file_path = os.path.join(dir, str(idx) + '.jpg')

        with open(file_path, 'wb') as f:
            img.save(f, "JPEG", quality=85)
        print(f"SUCCESS - saved {url} - as {file_path}")

        #self.s3.upload_file(file_path, S3_BUCKET, '_'.join(query.lower().split(' ')) + '/' + str(idx) + '.jpg')
        #os.remove(file_path)

    except Exception as e:
        print(f"ERROR - Could not save {url} - {e}")

The first step in this process is actually downloading the image’s content, which we do with the requests module by executing an HTTP GET request. Once we have the raw bytes, we wrap them in a BytesIO object so that PIL can open them like a file. Next, we open the image and convert it to the RGB (Red, Green, Blue) color scale so that all the images are consistent with each other. Finally, we build a file path inside the query’s image directory, using the image’s index number as the filename and a ‘.jpg’ extension, and save the image there.

Moving forward, if you would like to have the images stored on an AWS S3 bucket instead as I did, then you can uncomment the next two lines of code. The first line does the actual file uploading to AWS S3 and the second line removes the file from your local machine. Otherwise, the images will only be saved on your local machine in the specified directory.
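
For reference, the S3_BUCKET constant in the commented-out upload line is not defined anywhere in the snippets above; it is assumed to be a module-level constant holding your bucket name, with the client created roughly like this (the bucket name and region below are placeholders):

S3_BUCKET = 'my-image-dataset-bucket'  # hypothetical bucket name

s3 = boto3.client(
    's3',
    region_name='us-east-1',  # placeholder region
)

# With this setup, the commented upload line would store, for example, image 3
# of the query 'golden retriever' under the object key 'golden_retriever/3.jpg'.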

Fetching Images for a Single Query

def fetch_single_query(self, query: str, browser: webdriver.Chrome):
    img_dir = os.path.join(
        self.target_path, '_'.join(query.lower().split(' ')))

    if not os.path.exists(img_dir):
        os.makedirs(img_dir)

    # reopen google every time
    browser.get(self.url)

    image_urls = self.get_image_urls(query, browser)

    for idx, img_url in enumerate(image_urls):
        self.save_image(img_dir, img_url, idx, query)

To scrape images for a single query, we first create a directory for that query so that we have somewhere to store the image files. If we end up uploading to AWS instead of storing locally, the optional lines in the saving function (once uncommented) delete each file right after it is uploaded, and the directory itself is removed at the end of the full run.

Next, we need to submit an HTTP GET request to actually load the Google Images site. Once we’ve done that, we can get the image URLs using the function we created earlier.

Lastly, the remaining step is to iterate through these URLs and use the saving function to save them to the assigned directory either on the user’s local machine or designated AWS S3 bucket depending on the user’s specifications.

Fetching Images for a Batch of Queries

def fetch_query_batch(self, batch: List[str], browser: webdriver.Chrome):
    for query in batch:
        self.fetch_single_query(query, browser)

    browser.quit()  # quit() ends the WebDriver session and shuts down the driver process

This function simply iterates through a batch of queries, invokes the previous function for each one, and finally quits the browser session once we’re done.

Fetching Images for All Queries

def fetch_all_queries(self):
    threads = []
    options = Options()

    # options.headless = True

    for i in range(self.num_threads):
        start_idx = self.batch_size * i
        end_idx = start_idx + self.batch_size
        batch = self.queries[start_idx: end_idx]
        t = threading.Thread(target=self.fetch_query_batch, args=[
                             batch, webdriver.Chrome(options=options)])
        threads.append(t)

    # start all threads, then wait for them all to finish
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # remove the local images directory; this only makes sense if the files were uploaded to S3
    shutil.rmtree(self.target_path)

We start this function off by creating a list of threads. We then take batch slices of the query list and pass each slice, along with its own Chrome browser instance, to a separate thread. For example, if our query list has a length of 24 and our batch size is set to 8, then we’d have three threads, which would take the following batch slices:

  • Thread 1 – The first 8 queries (index 0 to index 7)
  • Thread 2 – The second 8 queries (index 8 to index 15)
  • Thread 3 – The third 8 queries (index 16 to index 23)

Once we’ve created the threads, all we need to do is start them and wait for them to finish. Once everything is completely done, we remove the images directory we created towards the start. And voila, we’re done! If you enabled the S3 upload lines, the images now live in your AWS S3 bucket and the local directory and files are no longer present on your machine.
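
Putting everything together, once the scraper has been constructed as in the earlier sketch (with its hypothetical queries and bucket), kicking off a full run is a single call:

# Assumes `scraper` was built as in the Class Initialization example.
scraper.fetch_all_queries()

# With the S3 upload lines uncommented, each query ends up as a "folder" of
# JPEGs in the bucket, e.g. golden_retriever/0.jpg, golden_retriever/1.jpg, ...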

If this post helped you, please show your support by dropping a like and commenting 🙂 Thank you!
