|
| 1 | +import argparse |
1 | 2 | import os |
2 | 3 | import urllib |
3 | 4 | import numpy as np |
@@ -33,16 +34,22 @@ def download_image(download_str, save_dir): |
33 | 34 |
|
34 | 35 |
|
def main():
    """Download a random sample of images listed in a URL file.

    Command-line arguments:
        --img_url_file: file containing the list of image IDs and URLs,
            one entry per line.
        --output_dir: directory where downloaded images are saved.
        --n_download_urls: number of URLs to sample and download
            (default 20000).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_url_file", type=str, required=True,
                        help="File that contains list of image IDs and urls.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory where to save outputs.")
    # Fixed copy-pasted help text: it previously repeated --output_dir's
    # "Directory where to save outputs." description.
    parser.add_argument("--n_download_urls", type=int, default=20000,
                        help="Number of image URLs to sample and download.")
    args = parser.parse_args()

    # Fixed seed so the same random subset of URLs is drawn on every run.
    np.random.seed(123456)

    with open(args.img_url_file) as f:
        lines = f.readlines()
    # Sample without replacement; numpy raises ValueError if the file has
    # fewer lines than n_download_urls.
    lines = np.random.choice(lines, size=args.n_download_urls, replace=False)

    # Fan the downloads out over 12 parallel workers.
    Parallel(n_jobs=12)(delayed(download_image)(line, args.output_dir) for line in lines)
|
48 | 55 | if __name__ == "__main__": |
|
0 commit comments