# sora_python/app/service/outfit_matcher/foco.py
# 2024-03-28 10:30:18 +08:00 · 252 lines · 6.1 KiB · Python

import operator
from numba import jit
import numpy as np
from skimage import transform, color
@jit(nopython=True)
def h_unit(h):
    """
    Map a hue in degrees to its FOCO hue unit (1-15).

    Each table entry is (inclusive upper bound, unit); the first bound that
    ``h`` does not exceed decides the unit. A hue of exactly 0 and hues in
    (344, 360] both map to unit 1. Values not covered by the table (above
    360, or NaN) yield 0.
    """
    table = [(0, 1), (14, 2), (33, 3), (44, 4), (62, 5), (88, 6), (165, 7),
             (194, 8), (219, 9), (241, 10), (266, 11), (292, 12), (307, 13),
             (327, 14), (344, 15), (360, 1)]
    for upper, unit in table:
        if h <= upper:
            return unit
    # Outside every bucket.
    return 0
@jit(nopython=True)
def s_unit(s):
    """
    Map a saturation in [0, 100] to its FOCO saturation unit (1-8).

    Each table entry is (inclusive upper bound, unit); the first bound that
    ``s`` does not exceed decides the unit, so 0-5 -> 1, (5, 10] -> 2, ...,
    (80, 100] -> 8. Values above 100 (or NaN) yield 0.
    """
    table = [(0, 1), (5, 1), (10, 2), (14, 3), (34, 4), (53, 5), (63, 6),
             (80, 7), (100, 8)]
    for upper, unit in table:
        if s <= upper:
            return unit
    # Outside every bucket.
    return 0
@jit(nopython=True)
def b_unit(b):
    """
    Map a brightness in [0, 100] to its FOCO brightness unit (1-6).

    Each table entry is (inclusive upper bound, unit); the first bound that
    ``b`` does not exceed decides the unit, so 0-20 -> 1, (20, 39] -> 2, ...,
    (90, 100] -> 6. Values above 100 (or NaN) yield 0.
    """
    table = [(0, 1), (20, 1), (39, 2), (55, 3), (72, 4), (90, 5), (100, 6)]
    for upper, unit in table:
        if b <= upper:
            return unit
    # Outside every bucket.
    return 0
@jit(nopython=True)
def encode_foco(hsb):
    """
    Embed an (h, s, b) FOCO unit triple as a 4-element float vector.

    The hue unit is placed on the unit circle (unit 1 at angle 0, one step
    per 2*pi/15), so hue units 15 and 1 end up adjacent. Saturation units
    1-8 and brightness units 1-6 are rescaled linearly onto [0, 1].
    """
    hue, sat, bri = hsb
    angle = (hue - 1) / 15 * np.pi * 2
    x = np.cos(angle)
    y = np.sin(angle)
    return np.array([x, y, (sat - 1) / 7, (bri - 1) / 5])
@jit(nopython=True)
def is_white(hsb_unit):
    """
    Return True when an (h, s, b) unit is white.

    White means saturation unit 1 and brightness unit 6; the hue is ignored.
    """
    _, sat, bri = hsb_unit
    return bri == 6 and sat == 1
@jit(nopython=True)
def is_rice(hsb_unit):
    """
    Return True when an (h, s, b) unit belongs to the "rice" colour group.

    The group is hue unit 3, 4 or 5, saturation unit 2, 3 or 4, and
    brightness unit 5 or 6.
    """
    hue, sat, bri = hsb_unit
    return bri in [5, 6] and sat in [2, 3, 4] and hue in [3, 4, 5]
@jit(nopython=True)
def is_blue_netrual(hsb_unit):
    """
    Return True when an (h, s, b) unit is a blue neutral.

    The group is hue unit 9 or 10, saturation unit 2-6 and brightness unit
    2-4. The function name keeps its historical spelling for callers.
    """
    hue, sat, bri = hsb_unit
    return bri in [2, 3, 4] and sat in [2, 3, 4, 5, 6] and hue in [9, 10]
@jit(nopython=True)
def is_pastel(hsb_unit):
    """
    Return True when an (h, s, b) unit is pastel.

    Pastel means saturation unit 1 with brightness unit 4 or 5; the hue is
    ignored.
    """
    _, sat, bri = hsb_unit
    return bri in [4, 5] and sat == 1
@jit(nopython=True)
def is_black(hsb_unit):
    """
    Return True when an (h, s, b) unit is black.

    Black means the lowest brightness unit (1); hue and saturation are
    ignored.
    """
    _, _, bri = hsb_unit
    return bri == 1
@jit(nopython=True)
def is_gray(hsb_unit):
    """
    Return True when an (h, s, b) unit is gray.

    Gray means saturation unit 1 with brightness unit 2 or 3; the hue is
    ignored.
    """
    _, sat, bri = hsb_unit
    return bri in [2, 3] and sat == 1
@jit(nopython=True)
def foco_merge(hsb_foco):
    """
    Canonicalise achromatic units so they share histogram bins.

    All whites collapse to (1, 1, 6) and all blacks to (1, 1, 1). A gray
    keeps its saturation and brightness, but its hue is forced to 1. Any
    other unit is returned unchanged as a plain (h, s, b) tuple.
    """
    if is_white(hsb_foco):
        return (1, 1, 6)
    if is_black(hsb_foco):
        return (1, 1, 1)
    hue = hsb_foco[0]
    sat = hsb_foco[1]
    bri = hsb_foco[2]
    if is_gray(hsb_foco):
        # Hue is meaningless for grays, so drop it.
        return (1, sat, bri)
    return hue, sat, bri
def color_foco(img):
    """
    Quantise an HSB image into FOCO colour units, pixel by pixel.

    input: (height, width, 3) array with hue in [0, 360] and saturation and
        brightness in [0, 100]
    output: uint8 array of the same shape holding (hue unit, saturation
        unit, brightness unit) for every pixel; the input is not modified
    """
    units = img.copy()
    height, width, channels = units.shape
    assert channels == 3
    # Channel index -> quantiser for that channel.
    quantizers = (h_unit, s_unit, b_unit)
    for row, col in np.ndindex(height, width):
        for channel, quantize in enumerate(quantizers):
            units[row, col, channel] = quantize(units[row, col, channel])
    return units.astype(np.uint8)
@jit(nopython=True)
def is_ignore(hsb, eps=1):
    """
    Return True for a near-white background pixel of a raw HSB image.

    ``hsb`` holds raw values (ranges 360, 100, 100). The pixel is ignored
    when its saturation is below ``eps`` and its brightness is above
    ``100 - eps``.
    """
    _, sat, bri = hsb
    bright_floor = 100 - eps
    return sat < eps and bri > bright_floor
@jit(nopython=True)
def ignore_list(hsb_img, ignore=is_ignore):
    """
    Collect the coordinates of pixels to leave out of the colour histogram.

    Input:
        hsb_img: HSB image of shape (height, width, 3), ranges 360, 100, 100
        ignore: predicate called with one pixel's (h, s, b) values that
            returns True when the pixel should be skipped; it must be
            callable from numba nopython code (e.g. another @jit function).
            Defaults to ``is_ignore``.
    Output:
        set of (row, col) tuples of the ignored pixels
    """
    rows, cols, _ = hsb_img.shape
    skipped = set()
    for r in range(rows):
        for c in range(cols):
            if ignore(hsb_img[r, c, :]):
                skipped.add((r, c))
    return skipped
def color_histogram(img, merge=None, threshold=None, ignore=None):
    """
    Count how many pixels fall into each (merged) colour unit.

    Args:
        img: (height, width, 3) array of colour units (e.g. from color_foco).
        merge: callable mapping a pixel tuple to its histogram key; defaults
            to ``foco_merge``.
        threshold: only keys with strictly more pixels than this are kept;
            defaults to height + width.
        ignore: container of (row, col) coordinates to skip; defaults to
            nothing.

    Returns:
        dict mapping key -> pixel count, restricted to counts above threshold.
    """
    height, width, channels = img.shape
    merge = foco_merge if merge is None else merge
    threshold = height + width if threshold is None else threshold
    ignore = {} if ignore is None else ignore
    assert channels == 3
    counts = {}
    for row in range(height):
        for col in range(width):
            if (row, col) not in ignore:
                key = merge(tuple(img[row, col]))
                counts[key] = counts.get(key, 0) + 1
    return {key: total for key, total in counts.items() if total > threshold}
def main_colors(img, n=1, frequency=False, hist=False, merge=True):
    """
    Return the main (most frequent) FOCO colour units of an HSB image.

    Args:
        img: an HSB image (ranges 360, 100, 100). When ``hist`` is true,
            this is instead an already computed histogram dict mapping
            (h, s, b) unit tuples to pixel counts. A histogram passed in
            this way is never modified.
        n: number of main colours to return.
        frequency: if true, return ``(unit, count)`` pairs instead of units.
        hist: whether ``img`` is a precomputed histogram rather than an image.
        merge: if true, units within L1 distance 2 of a more frequent unit
            are folded into it (counts added) before ranking.

    Returns:
        At most ``n`` units, or ``(unit, count)`` pairs, most frequent first.
    """
    if hist:
        counts = img
    else:
        counts = color_histogram(color_foco(img), ignore=ignore_list(img))
    if merge:
        # Greedy clustering: walk the units from most to least frequent.
        # Each unit not yet claimed absorbs every unclaimed unit (itself
        # included) within L1 distance 2. Work on a copy so a caller-supplied
        # histogram is left intact.
        remaining = dict(counts)
        merged = {}
        for unit, _ in sorted(counts.items(), key=operator.itemgetter(1), reverse=True):
            if unit not in remaining:
                continue  # already absorbed by a more frequent neighbour
            # Compare as signed ints: unsigned (e.g. uint8) components would
            # wrap around on subtraction and make the distance asymmetric.
            anchor = np.asarray(unit, dtype=np.int64)
            near = [other for other in remaining
                    if np.sum(np.abs(anchor - np.asarray(other, dtype=np.int64))) <= 2]
            merged[unit] = sum(remaining.pop(other) for other in near)
        counts = merged
    ranked = sorted(counts.items(), key=operator.itemgetter(1), reverse=True)[:n]
    return ranked if frequency else [unit for unit, _ in ranked]
def rgb2hsb(rgb_img):
    """
    Convert an RGB image to HSB.

    The input is cast to uint8 (channel ranges 255, 255, 255). The result is
    a float array scaled to ranges 360, 100, 100.
    """
    as_bytes = np.asarray(rgb_img).astype('uint8')
    scale = np.array([360, 100, 100]).reshape(1, 1, 3)
    return color.rgb2hsv(as_bytes) * scale
def extract_main_colors(img, n=5):
    """
    Extract FOCO features of the main colours of an RGB image.

    Images with more than 512 * 512 pixels are first downscaled, keeping the
    aspect ratio, to roughly 512 * 512 pixels.

    Args:
        img: Numpy array (height, width, channel) of RGB values in [0, 255]
        n: number of main colours to extract
    return:
        Features of main colors: Numpy array (n, 5). Row i holds the 4-element
        ``encode_foco`` vector of the i-th main colour, followed by its count
        normalised over the returned main colours. Rows beyond the number of
        colours found stay zero.
    """
    img = img.astype("uint8")
    height, width = img.shape[0], img.shape[1]
    # This is an *area* ratio, so each side is scaled by its square root.
    # Scaling both sides by the ratio itself shrank large images
    # quadratically (1024x1024 -> 256x256), so bigger inputs produced fewer
    # pixels. It could also yield a zero-length side on very thin images.
    area_ratio = (512 * 512) / (height * width)
    if area_ratio < 1.0:
        side_scale = area_ratio ** 0.5
        new_shape = (max(1, int(height * side_scale)), max(1, int(width * side_scale)))
        img = transform.resize(img, new_shape)
    # Convert to hsb (ranges 360, 100, 100).
    img = color.rgb2hsv(img) * np.array([360, 100, 100]).reshape(1, 1, 3)
    # Extract main colors.
    cf = main_colors(img, n, frequency=True)
    total = sum(f for _, f in cf)
    features = np.zeros((n, 5))
    for i, (unit, f) in enumerate(cf):
        features[i, :4] = encode_foco(unit)
        features[i, 4] = f / total
    return features