feat sketch 提取接口
fix
This commit is contained in:
110
app/service/image2sketch/util/get_data.py
Normal file
110
app/service/image2sketch/util/get_data.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import tarfile
|
||||
import requests
|
||||
from warnings import warn
|
||||
from zipfile import ZipFile
|
||||
from bs4 import BeautifulSoup
|
||||
from os.path import abspath, isdir, join, basename
|
||||
|
||||
|
||||
class GetData(object):
    """Download and unpack CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        verbose (bool)  -- If True, print additional information.

    Raises:
        ValueError -- if `technique` is not one of the supported names.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    # Index page that lists the downloadable archives for each technique.
    _URLS = {
        'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
        'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets',
    }
    # Archive types that `_download_data` knows how to unpack.
    _ARCHIVE_SUFFIXES = ('.tar.gz', '.zip')
    # Seconds to wait for the server (connect / between reads) before giving up.
    _TIMEOUT = 30
    # Number of bytes written to disk per chunk while streaming a download.
    _CHUNK_SIZE = 1 << 20

    def __init__(self, technique='cyclegan', verbose=True):
        key = technique.lower()
        if key not in self._URLS:
            # Fail here with a clear message instead of storing url=None and
            # failing later inside `requests`.
            raise ValueError(
                "Unknown technique: {0!r}. Expected one of: {1}.".format(
                    technique, ', '.join(sorted(self._URLS))))
        self.url = self._URLS[key]
        self._verbose = verbose

    def _print(self, text):
        """Print `text` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Return the link texts in response `r` that name a supported archive.

        Parameters:
            r -- a response object exposing the page HTML as `r.text`.

        Returns:
            options (list of str) -- link texts ending in '.zip' or '.tar.gz'.
        """
        soup = BeautifulSoup(r.text, 'lxml')
        return [link.text for link in soup.find_all('a', href=True)
                if link.text.endswith(GetData._ARCHIVE_SUFFIXES)]

    @classmethod
    def _archive_stem(cls, filename):
        """Return `filename` without its archive extension.

        Names with an unrecognised extension fall back to the text before the
        first dot.
        """
        for suffix in cls._ARCHIVE_SUFFIXES:
            if filename.endswith(suffix):
                return filename[:-len(suffix)]
        return filename.split('.')[0]

    def _dataset_url(self, filename):
        """Join the base URL and `filename` with exactly one slash."""
        return "{0}/{1}".format(self.url.rstrip('/'), filename)

    def _present_options(self):
        """List the available datasets and ask the user to pick one.

        Returns:
            selected (str) -- the file name of the chosen dataset.

        Raises:
            requests.HTTPError -- if the index page cannot be fetched.
            RuntimeError       -- if the index page lists no known archives.
        """
        r = requests.get(self.url, timeout=self._TIMEOUT)
        r.raise_for_status()
        options = self._get_options(r)
        if not options:
            raise RuntimeError(
                "No downloadable datasets were found at {0}.".format(self.url))

        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))

        prompt = ("\nPlease enter the number of the "
                  "dataset above you wish to download:")
        while True:
            choice = input(prompt)
            try:
                index = int(choice)
            except (TypeError, ValueError):
                index = -1
            # Reject negatives explicitly: options[-1] would silently select
            # the last entry instead of reporting a bad choice.
            if 0 <= index < len(options):
                return options[index]
            print("Invalid choice. Enter a number between 0 and {0}.".format(
                len(options) - 1))

    def _download_data(self, dataset_url, save_path):
        """Download the archive at `dataset_url` and unpack it into `save_path`.

        Parameters:
            dataset_url (str) -- full URL of a '.tar.gz' or '.zip' archive.
            save_path (str)   -- directory to unpack into (created if missing).

        Raises:
            ValueError         -- if the URL does not name a supported archive.
            requests.HTTPError -- if the server answers with an error status.
        """
        base = basename(dataset_url)
        # Validate the archive type before touching the network or the disk.
        if base.endswith('.tar.gz'):
            open_archive = tarfile.open
        elif base.endswith('.zip'):
            open_archive = ZipFile
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        if not isdir(save_path):
            os.makedirs(save_path)
        temp_save_path = join(save_path, base)

        try:
            # Stream to disk: datasets can be far larger than is comfortable
            # to hold in memory as a single `r.content` blob.
            r = requests.get(dataset_url, stream=True, timeout=self._TIMEOUT)
            try:
                r.raise_for_status()
                with open(temp_save_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=self._CHUNK_SIZE):
                        f.write(chunk)
            finally:
                r.close()

            self._print("Unpacking Data...")
            # NOTE(security): extractall() trusts member paths inside the
            # archive. Only use this with trusted dataset hosts, or add an
            # extraction filter if untrusted sources must be supported.
            with open_archive(temp_save_path) as archive:
                archive.extractall(save_path)
        finally:
            # Never leave a partial or already-unpacked archive behind.
            if os.path.exists(temp_save_path):
                os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """

        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                            Note: this must include the file extension.
                            If None, options will be presented for you
                            to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.

        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, self._archive_stem(selected_dataset))

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = self._dataset_url(selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)
|
||||
Reference in New Issue
Block a user