-
Notifications
You must be signed in to change notification settings - Fork 4
/
run.py
133 lines (111 loc) · 4.17 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Tritonbench benchmark runner.
Note: run `python install.py` first, or otherwise make sure the benchmark you are going to run
has been installed. This script intentionally does not automate or enforce setup steps.
"""
import argparse
import copy
import subprocess
import sys
from typing import List, Optional

from tritonbench.operator_loader import load_opbench_by_name_from_loader
from tritonbench.operators import load_opbench_by_name
from tritonbench.operators_collection import list_operators_by_collection
from tritonbench.utils.gpu_utils import gpu_lockdown
from tritonbench.utils.parser import get_parser
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter
from tritonbench.utils.triton_op import BenchmarkOperatorResult, IS_FBCODE
# Start from a logger that accepts any arguments and does nothing; it is only
# replaced when IS_FBCODE is set and the internal helper imports successfully.
usage_report_logger = lambda *args, **kwargs: None
if IS_FBCODE:
    try:
        from .fb.utils import usage_report_logger  # @manual
    except ImportError:
        # Internal helper is unavailable: keep the no-op logger bound above.
        pass
def _run_in_task(op: str) -> None:
    """Re-run the current command line for a single operator in a child process.

    The child's arguments start as a copy of ``sys.argv``, are passed through
    ``remove_cmd_parameter`` for ``--op``, ``--isolate`` and
    ``--op-collection``, and then get ``--op <op>`` added through
    ``add_cmd_parameter``. Outside fbcode the command is prefixed with the
    current Python interpreter (``sys.executable``).

    Args:
        op: Operator name handed to the child via ``--op``.

    A non-zero exit status of the child is ignored so the caller can carry on
    with the next operator. A ``KeyboardInterrupt`` prints a message and
    terminates this process with exit code 1.
    """
    child_argv = copy.deepcopy(sys.argv)
    for flag in ("--op", "--isolate", "--op-collection"):
        child_argv = remove_cmd_parameter(child_argv, flag)
    # NOTE(review): the return value is discarded, so this relies on
    # add_cmd_parameter modifying the list in place -- confirm.
    add_cmd_parameter(child_argv, "--op", op)

    interpreter_prefix = [] if IS_FBCODE else [sys.executable]
    op_task_cmd = [*interpreter_prefix, *child_argv]
    try:
        banner = "[tritonbench] running command: " + " ".join(op_task_cmd)
        print(banner)
        subprocess.check_call(op_task_cmd, stdout=sys.stdout, stderr=sys.stderr)
    except subprocess.CalledProcessError:
        # By default, we will continue on the failed operators
        return
    except KeyboardInterrupt:
        print("KeyboardInterrupt received, exiting...")
        sys.exit(1)
# Run one operator benchmark in the current process and report its results.
#
# Parameters:
#   args       -- parsed benchmark options (argparse.Namespace); args.op names
#                 the operator to run.
#   extra_args -- arguments the parser did not recognise, forwarded unchanged
#                 to the benchmark class.
# Returns the value read from opbench.output.
#
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# below cannot be read off the text. In particular it is unclear how much of
# the reporting code belongs to the `finally:` suite. If `return metrics` is
# inside it, an exception raised by opbench.run() is discarded by that return
# instead of propagating -- confirm against the original file.
def _run(args: argparse.Namespace, extra_args: List[str]) -> BenchmarkOperatorResult:
# Resolve the benchmark class: through the operator loader when one is
# configured, otherwise by operator name.
if args.operator_loader:
Opbench = load_opbench_by_name_from_loader(args)
else:
Opbench = load_opbench_by_name(args.op)
opbench = Opbench(
tb_args=args,
extra_args=extra_args,
)
try:
opbench.run(args.warmup, args.iter)
finally:
# opbench.output is read even when run() raised.
metrics = opbench.output
# Console reporting: CSV to stdout when args.csv is set, otherwise print().
if not args.skip_print:
if args.csv:
metrics.write_csv_to_file(sys.stdout)
else:
print(metrics)
# Internal (fbcode-only) result logging; the helper is imported lazily here.
if IS_FBCODE and args.log_scuba:
from .fb.utils import log_benchmark # @manual
kwargs = {
"metrics": metrics,
"benchmark_name": args.op,
"device": args.device,
"logging_group": args.logging_group,
"precision": args.precision,
}
if args.production_shapes:
from tritonbench.utils.fb.durin_data import productionDataLoader
kwargs["weights_loader"] = productionDataLoader
# Forward the hardware value only when the namespace has that attribute.
if "hardware" in args:
kwargs["hardware"] = args.hardware
log_benchmark(**kwargs)
# Plotting is optional per operator; NotImplementedError is reported, not raised.
if args.plot:
try:
opbench.plot()
except NotImplementedError:
print(f"Plotting is not implemented for {args.op}")
# Optionally write the same metrics as CSV to the file named by args.output.
if args.output:
with open(args.output, "w") as f:
metrics.write_csv_to_file(f)
print(f"[TritonBench] Output result csv to {args.output}")
return metrics
def run(args: Optional[List[str]] = None):
    """Parse the command line and benchmark every requested operator.

    Args:
        args: Command-line arguments without the program name. When omitted
            or empty, ``sys.argv[1:]`` is used.

    With ``--ci`` only the CI entry point is run. Otherwise the operators come
    from the comma-separated ``--op`` value or, when that is unset, from
    ``list_operators_by_collection(args.op_collection)``. Two or more
    operators force isolation, so each one runs through ``_run_in_task``; a
    single operator runs in-process through ``_run`` unless isolation was
    requested. The whole loop runs inside ``gpu_lockdown``.

    NOTE(review): ``_run_in_task`` rebuilds the child command from
    ``sys.argv``, so arguments passed programmatically through ``args`` are
    not forwarded to isolated children.
    """
    # ``None`` replaces the former mutable ``[]`` default; an explicit empty
    # list still falls back to the process arguments, as before.
    if not args:
        args = sys.argv[1:]
    # Log the tool usage
    usage_report_logger(benchmark_name="tritonbench")
    parser = get_parser()
    # Keep the parsed namespace under its own name instead of rebinding the
    # ``args`` parameter to a different type.
    tb_args, extra_args = parser.parse_known_args(args)
    if tb_args.ci:
        from .ci import run_ci  # @manual

        run_ci()
        return
    if tb_args.op:
        ops = tb_args.op.split(",")
    else:
        ops = list_operators_by_collection(tb_args.op_collection)
    # Force isolation in subprocess if testing more than one op.
    if len(ops) >= 2:
        tb_args.isolate = True
    with gpu_lockdown(tb_args.gpu_lockdown):
        for op in ops:
            # Each iteration targets exactly one operator.
            tb_args.op = op
            if tb_args.isolate:
                _run_in_task(op)
            else:
                _run(tb_args, extra_args)
# Script entry point: call run() with its default arguments when this file is
# executed directly rather than imported.
if __name__ == "__main__":
run()