-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy pathprofiling.py
More file actions
73 lines (56 loc) · 2.53 KB
/
profiling.py
File metadata and controls
73 lines (56 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark the per-call time of ``torch.add`` on torch.Tensor, SubTensor, SubWithTorchFunc and MetaTensor operands (the ``__torch_function__`` benchmarks).
Adapted from https://github.com/pytorch/pytorch/tree/v1.11.0/benchmarks/overrides_benchmark
"""
from __future__ import annotations
import argparse
import torch
from min_classes import SubTensor, SubWithTorchFunc
from monai.data import MetaTensor
from monai.utils.profiling import PerfContext
# Default workload sizes; ``main`` can override both from the command line.
NUM_REPEATS = 1_000  # torch.add calls inside one timed measurement (--nreps / -n)
NUM_REPEAT_OF_REPEATS = 1_000  # number of timed measurements (--nrepreps / -m)
def bench(t1, t2, nreps: int | None = None, nrepreps: int | None = None):
    """Time ``torch.add(t1, t2)`` and return per-call timing statistics.

    One measurement runs ``torch.add`` ``nreps`` times inside a ``PerfContext`` and
    records that context's ``total_time``; ``nrepreps`` measurements are taken.
    Every statistic is divided by ``nreps`` so it describes a single call.

    Args:
        t1: first operand passed to ``torch.add``.
        t2: second operand passed to ``torch.add``.
        nreps: calls per measurement. Defaults to the module-level ``NUM_REPEATS``,
            read at call time so ``main``'s ``--nreps`` override still applies.
        nrepreps: number of measurements. Defaults to the module-level
            ``NUM_REPEAT_OF_REPEATS``, read at call time.

    Returns:
        ``(min, mean, median, std)`` of the per-call time, in the unit of
        ``PerfContext.total_time`` (presumably seconds: ``main`` scales by 1e6 and
        labels the output microseconds). ``std`` is the sample standard deviation
        and is NaN when only one measurement is taken.

    Raises:
        ValueError: if ``nreps`` or ``nrepreps`` is smaller than 1.
    """
    import statistics

    if nreps is None:
        nreps = NUM_REPEATS
    if nrepreps is None:
        nrepreps = NUM_REPEAT_OF_REPEATS
    # Zero calls would divide by zero below; zero measurements leave nothing to summarize.
    if nreps < 1 or nrepreps < 1:
        raise ValueError(f"nreps and nrepreps must be >= 1, got nreps={nreps}, nrepreps={nrepreps}.")

    bench_times = []
    for _ in range(nrepreps):
        with PerfContext() as pc:
            for _ in range(nreps):
                torch.add(t1, t2)
        # Read after the context has exited, once the elapsed time has been recorded.
        bench_times.append(pc.total_time)

    # Summarize with Python floats (double precision) instead of rebuilding a
    # default-dtype (float32) tensor for every statistic.
    bench_time_min = min(bench_times) / nreps
    bench_time_avg = sum(bench_times) / (nreps * nrepreps)
    # torch.median returns the lower of the two middle values for an even count;
    # statistics.median averages them, giving the conventional median.
    bench_time_med = statistics.median(bench_times) / nreps
    # Sample (Bessel-corrected) std, like torch.std's default; undefined for one sample.
    bench_std = statistics.stdev(bench_times) / nreps if nrepreps > 1 else float("nan")
    return bench_time_min, bench_time_avg, bench_time_med, bench_std
def main():
    """Benchmark ``torch.add`` for each tensor type and print the timing statistics.

    ``--nreps``/``-n`` and ``--nrepreps``/``-m`` overwrite the module-level
    ``NUM_REPEATS`` / ``NUM_REPEAT_OF_REPEATS`` that ``bench`` reads. Both must be >= 1.
    """
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(description="Run the __torch_function__ benchmarks.")
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats for one measurement."
    )
    parser.add_argument("--nrepreps", "-m", type=int, default=NUM_REPEAT_OF_REPEATS, help="The number of measurements.")
    args = parser.parse_args()
    # Reject non-positive counts with a usage message: zero calls per measurement would
    # divide by zero, and zero measurements leave nothing to summarize.
    if args.nreps < 1:
        parser.error(f"--nreps must be >= 1, got {args.nreps}")
    if args.nrepreps < 1:
        parser.error(f"--nrepreps must be >= 1, got {args.nrepreps}")
    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.Tensor, SubTensor, SubWithTorchFunc, MetaTensor
    for t in types:
        # ``torch.Tensor(1)`` is the legacy *size* constructor: it allocates an
        # uninitialized 1-element tensor, not one holding the value 1 (and ``t(2)`` a
        # 2-element one, so the add also broadcasts). Explicit data gives every type
        # initialized, same-shaped operands.
        # NOTE(review): assumes every type in ``types`` accepts a data list like
        # torch.Tensor does — confirm for the ``min_classes`` subclasses.
        tensor_1 = t([1.0])
        tensor_2 = t([2.0])
        b_min, b_avg, b_med, b_std = bench(tensor_1, tensor_2)
        print(
            f"Type {t.__name__} time (microseconds):"
            f" min: {10**6 * b_min}, avg: {(10**6) * b_avg}, median: {(10**6) * b_med}, and std {(10**6) * b_std}."
        )
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, never on import.
    main()