import operator

from numba import jit
import numpy as np
from skimage import transform, color


@jit(nopython=True)
def h_unit(h):
    """
    The unit value (1-15) of hue.

    ``h`` is a hue in degrees (0-360). Both ``h <= 0`` and ``344 < h <= 360``
    map to unit 1, so the scale wraps around the colour wheel. Hues above 360
    map to 0.
    """
    for i, u in [(0, 1), (14, 2), (33, 3), (44, 4), (62, 5), (88, 6),
                 (165, 7), (194, 8), (219, 9), (241, 10), (266, 11),
                 (292, 12), (307, 13), (327, 14), (344, 15), (360, 1)]:
        if h <= i:
            return u
    return 0


@jit(nopython=True)
def s_unit(s):
    """
    The unit value (1-8) of saturation.

    ``s`` is a saturation in percent (0-100). Values above 100 map to 0.
    """
    for i, u in [(0, 1), (5, 1), (10, 2), (14, 3), (34, 4), (53, 5),
                 (63, 6), (80, 7), (100, 8)]:
        if s <= i:
            return u
    return 0


@jit(nopython=True)
def b_unit(b):
    """
    The unit value (1-6) of brightness.

    ``b`` is a brightness in percent (0-100). Values above 100 map to 0.
    """
    for i, u in [(0, 1), (20, 1), (39, 2), (55, 3), (72, 4), (90, 5),
                 (100, 6)]:
        if b <= i:
            return u
    return 0


@jit(nopython=True)
def encode_foco(hsb):
    """
    Encode an (h, s, b) unit triple as a 4-element float feature vector.

    The hue unit (1-15) is placed on the unit circle as (cos, sin) so that
    neighbouring hues stay close across the wrap-around. The saturation unit
    (1-8) and brightness unit (1-6) are rescaled linearly to [0, 1].
    """
    h, s, b = hsb
    theta = (h - 1) / 15 * np.pi * 2
    return np.array([np.cos(theta), np.sin(theta), (s - 1) / 7, (b - 1) / 5])


@jit(nopython=True)
def is_white(hsb_unit):
    """ Whether an hsb color unit is white """
    _, s, b = hsb_unit
    return s == 1 and b == 6


@jit(nopython=True)
def is_rice(hsb_unit):
    """ Whether an hsb color unit is rice """
    h, s, b = hsb_unit
    return h in [3, 4, 5] and s in [2, 3, 4] and b in [5, 6]


@jit(nopython=True)
def is_blue_netrual(hsb_unit):
    """ Whether an hsb color unit is blue neutral """
    h, s, b = hsb_unit
    return h in [9, 10] and s in [2, 3, 4, 5, 6] and b in [2, 3, 4]


@jit(nopython=True)
def is_pastel(hsb_unit):
    """ Whether an hsb color unit is pastel """
    _, s, b = hsb_unit
    return s == 1 and b in [4, 5]


@jit(nopython=True)
def is_black(hsb_unit):
    """ Whether an hsb color unit is black """
    _, _, b = hsb_unit
    return b == 1


@jit(nopython=True)
def is_gray(hsb_unit):
    """ Whether an hsb color unit is gray """
    _, s, b = hsb_unit
    return s == 1 and b in [2, 3]


@jit(nopython=True)
def foco_merge(hsb_foco):
    """
    Merge multiple black/white/gray foco units into one.

    Achromatic units have a meaningless hue, so their hue is forced to 1 and
    they collapse onto a single key per brightness level. Every other unit is
    returned unchanged as a tuple.
    """
    if is_white(hsb_foco):
        return (1, 1, 6)
    if is_black(hsb_foco):
        return (1, 1, 1)
    if is_gray(hsb_foco):
        return (1, hsb_foco[1], hsb_foco[2])
    return hsb_foco[0], hsb_foco[1], hsb_foco[2]


def color_foco(img):
    """
    Transform an hsb image to color units.

    input: (height, width, 3) array (ranges 360, 100, 100)
    output: a uint8 array with the same shape, holding per-pixel
            (hue unit, saturation unit, brightness unit)
    """
    img = img.copy()  # do not overwrite the caller's image
    h, w, c = img.shape
    assert c == 3
    for i in range(h):
        for j in range(w):
            img[i, j, 0] = h_unit(img[i, j, 0])
            img[i, j, 1] = s_unit(img[i, j, 1])
            img[i, j, 2] = b_unit(img[i, j, 2])
    return img.astype(np.uint8)


@jit(nopython=True)
def is_ignore(hsb, eps=1):
    """
    Ignore white background.

    True when saturation < eps and brightness > 100 - eps, i.e. the pixel is
    (near) pure white on the 0-100 scales.
    """
    _, s, b = hsb
    return s < eps and b > 100 - eps


@jit(nopython=True)
def ignore_list(hsb_img, ignore=is_ignore):
    """
    Collect the coordinates of pixels that should be ignored.

    Input:
        hsb_img: hsb image (ranges 360, 100, 100), shape (height, width, 3)
        ignore: a predicate that determines whether an hsb pixel should be
                ignored
    Output:
        set of (row, col) coordinates of ignored pixels
    """
    h, w, _ = hsb_img.shape
    ign = set()
    for i in range(h):
        for j in range(w):
            if ignore(hsb_img[i, j, :]):
                ign.add((i, j))
    return ign


def color_histogram(img, merge=None, threshold=None, ignore=None):
    """
    Color histogram of a color-unit image.

    Args:
        img: (height, width, 3) array of color units (see ``color_foco``)
        merge: maps a unit tuple to its histogram key; defaults to
               ``foco_merge``
        threshold: only bins with a count strictly greater than this are
                   kept; defaults to ``height + width``
        ignore: container of (row, col) coordinates to skip
    Returns:
        dict mapping unit key -> pixel count
    """
    h, w, c = img.shape
    if merge is None:
        merge = foco_merge
    if threshold is None:
        threshold = h + w
    if ignore is None:
        ignore = set()
    assert c == 3
    hist = {}
    for i in range(h):
        for j in range(w):
            if (i, j) in ignore:
                continue
            k = merge(tuple(img[i, j]))
            hist[k] = hist.get(k, 0) + 1
    return {k: v for k, v in hist.items() if v > threshold}


def _unit_distance(a, b):
    """L1 distance between two unit tuples, computed in float to avoid
    unsigned-integer wrap-around."""
    return np.sum(np.abs(np.array(a, dtype=float) - np.array(b, dtype=float)))


def main_colors(img, n=1, frequency=False, hist=False, merge=True):
    """
    Return the list of main colors of an hsb image.

    img: an hsb image (ranges 360, 100, 100), or a precomputed histogram
         dict when ``hist`` is True
    n: number of main colors to return
    frequency: whether to return (color, frequency) pairs instead of colors
    hist: whether ``img`` is already a histogram dict
    merge: whether to fold each dominant color's neighbours (L1 distance
           <= 2 in unit space) into it before ranking
    """
    if hist:
        # Copy: the merge step pops entries and must not mutate the caller's
        # histogram.
        histogram = dict(img)
    else:
        histogram = color_histogram(color_foco(img), ignore=ignore_list(img))
    if merge:
        merged = {}
        absorbed = set()
        ranked = sorted(histogram.items(), key=operator.itemgetter(1),
                        reverse=True)
        # Greedy merge: the most frequent remaining color absorbs every
        # not-yet-absorbed color within distance 2 of it.
        for k, _ in ranked:
            if k in absorbed:
                continue
            near = {kk: histogram[kk] for kk in histogram
                    if _unit_distance(k, kk) <= 2}
            merged[k] = sum(near.values())
            for kk in near:
                absorbed.add(kk)
                histogram.pop(kk)
        histogram = merged
    items = sorted(histogram.items(), key=operator.itemgetter(1),
                   reverse=True)[:n]
    return items if frequency else [k for k, _ in items]


def rgb2hsb(rgb_img):
    """
    Transform an rgb image (uint8 np array, ranges 255, 255, 255) to hsb
    (float np array, ranges 360, 100, 100).
    """
    rgb_img = np.array(rgb_img).astype('uint8')
    return color.rgb2hsv(rgb_img) * np.array([360, 100, 100]).reshape(1, 1, 3)


def extract_main_colors(img, n=5):
    """
    Args:
        img: Numpy array (height, width, channel), RGB
        n: number of main colors to extract
    return:
        Features of main colors: Numpy array (n, 5). Each row is the 4-element
        ``encode_foco`` vector followed by the color's relative frequency.
        Rows beyond the number of colors found are left as zeros.
    """
    img = img.astype("uint8")
    height = img.shape[0]
    width = img.shape[1]
    # Cap the pixel count at about 512 * 512. ``ratio`` is an AREA ratio, so
    # each side is scaled by its square root.
    ratio = (512 * 512) / (height * width)
    if ratio < 1.0:
        scale = np.sqrt(ratio)
        img = transform.resize(
            img,
            (max(1, int(height * scale)), max(1, int(width * scale))))
    # Convert to hsb (ranges 360, 100, 100). rgb2hsb is not reused here
    # because it would cast resize's [0, 1] float output back to uint8.
    img = color.rgb2hsv(img) * np.array([360, 100, 100]).reshape(1, 1, 3)
    # Extract main colors
    cf = main_colors(img, n, frequency=True)
    total = sum(f for _, f in cf)
    features = np.zeros((n, 5))
    for i, (c, f) in enumerate(cf):
        features[i, :4] = encode_foco(c)
        features[i, 4] = f / total
    return features