File size: 5,023 Bytes
3f7cfab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gc
import datetime
import inspect

import torch
import numpy as np

# Bytes occupied by one element of each supported torch dtype (bit width / 8).
# Several names are aliases of the same dtype object (e.g. torch.short is
# torch.int16), so a later alias entry overwrites the earlier one; both entries
# must therefore carry the same value.
dtype_memory_size_dict = {
    torch.float64: 64 / 8,
    torch.double: 64 / 8,
    torch.float32: 32 / 8,
    torch.float: 32 / 8,
    torch.float16: 16 / 8,
    torch.half: 16 / 8,
    torch.int64: 64 / 8,
    torch.long: 64 / 8,
    torch.int32: 32 / 8,
    torch.int: 32 / 8,
    torch.int16: 16 / 8,
    torch.short: 16 / 8,  # was 16 / 6, which silently replaced the int16 entry with 2.67 bytes
    torch.uint8: 8 / 8,
    torch.int8: 8 / 8,
}
# Compatibility with torch 1.0: these dtypes are only registered when the
# installed torch exposes them.
if getattr(torch, "bfloat16", None) is not None:
    dtype_memory_size_dict[torch.bfloat16] = 16 / 8
if getattr(torch, "bool", None) is not None:
    # pytorch uses 1 byte for a bool, see https://github.com/pytorch/pytorch/issues/41571
    dtype_memory_size_dict[torch.bool] = 8 / 8


def get_mem_space(x):
    """
    Return the number of bytes one element of dtype ``x`` occupies.

    Arguments:
        x(torch.dtype): dtype to look up in ``dtype_memory_size_dict``

    Returns:
        The byte count stored in ``dtype_memory_size_dict`` for ``x``, or ``0.0``
        when the dtype is not listed (a message is printed in that case).
    """
    try:
        return dtype_memory_size_dict[x]
    except KeyError:
        # Previously execution fell through to ``return ret`` with ``ret`` never
        # assigned, so the warning was immediately followed by an
        # UnboundLocalError. Report the dtype and count it as zero bytes so the
        # caller's memory report can still be produced.
        print(f"dtype {x} is not supported!")
        return 0.0


import contextlib, sys

@contextlib.contextmanager
def file_writer(file_name=None):
    """
    Context manager yielding a writable text stream.

    Arguments:
        file_name(str, default None): path of a file to append to; when None the
            stream is ``sys.stdout``

    Yields:
        The opened file object, or ``sys.stdout`` when no file name is given.
        A file opened here is closed on exit, including when the ``with`` body
        raises; ``sys.stdout`` is never closed.
    """
    if file_name is None:
        # stdout is only borrowed, so there is nothing to clean up afterwards.
        yield sys.stdout
        return
    # "a" appends to (and creates) the file. The previous mode string "aw" is
    # rejected by open() because it names two of the mutually exclusive
    # create/read/write/append modes.
    with open(file_name, "a") as writer:
        yield writer


class MemTracker(object):
    """
    Class used to track pytorch memory usage
    Arguments:
        detail(bool, default True): whether track() lists the tensors that appeared or
            disappeared since the previous call
        path(str): prefix prepended to the timestamped log file name
        verbose(bool, default False): whether to print the exceptions swallowed while
            scanning garbage-collector objects
        device(int): GPU number, default is 0 (stored on the instance; none of the
            methods of this class read it)
        log_to_disk(bool, default False): write reports to the log file instead of stdout
    """

    def __init__(self, detail=True, path='', verbose=False, device=0, log_to_disk=False):
        self.print_detail = detail
        # Snapshot produced by the previous track() call; see _snapshot() for the row layout.
        self.last_tensor_sizes = set()
        # NOTE(review): the timestamp contains ':' characters, which some file systems
        # do not accept in file names - confirm the target platforms.
        self.gpu_profile_fn = path + f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_track.txt'
        self.verbose = verbose
        # True until the first track() call has written the report header.
        self.begin = True
        self.device = device
        self.log_to_disk = log_to_disk

    def get_tensors(self):
        """
        Yield every object known to the garbage collector that is a tensor (or exposes
        a tensor through a ``data`` attribute) and reports ``is_cuda``.

        Exceptions raised while probing an object are swallowed; they are printed only
        when ``verbose`` is set.
        """
        for obj in gc.get_objects():
            try:
                if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                    tensor = obj
                else:
                    continue
                if tensor.is_cuda:
                    yield tensor
            except Exception as e:
                if self.verbose:
                    print('A trivial exception occured: {}'.format(e))

    @staticmethod
    def _tensor_bytes(tensor):
        """Return element count times bytes-per-element for one tensor."""
        return np.prod(np.array(tensor.size())) * get_mem_space(tensor.dtype)

    @staticmethod
    def _format_row(sign, t, s, n, m, data_type):
        """
        Format one line of the detail report.

        ``sign`` is '+' or '-', ``t`` the object type, ``s`` the shape, ``n`` the number
        of tensors sharing that shape and dtype, ``m`` the size of one of them in Mb.
        Fixed-point formatting is used because truncating ``str(m * n)`` cuts the
        exponent off small values (1.5258e-05 would be printed as "1.5258").
        """
        return f'{sign} | {n} * Size:{str(s):<20} | Memory: {m * n:.4f} M | {str(t):<20} | {data_type}\n'

    def _snapshot(self):
        """
        Return the set of ``(type, shape, count, Mb per tensor, dtype)`` rows describing
        the tensors currently yielded by get_tensors().

        ``count`` is the number of tensors sharing the row's shape and dtype. The heap is
        scanned once and the counts come from a dict, instead of re-scanning the heap and
        calling ``list.count`` for every tensor.
        """
        # Only plain descriptors are kept, so the snapshot itself holds no tensor alive.
        records = [(type(x), tuple(x.size()), x.dtype, self._tensor_bytes(x) / 1024 ** 2)
                   for x in self.get_tensors()]
        counts = {}
        for _, shape, data_type, _ in records:
            key = (shape, data_type)
            counts[key] = counts.get(key, 0) + 1
        return {(t, shape, counts[(shape, data_type)], mb, data_type)
                for t, shape, data_type, mb in records}

    def get_tensor_usage(self):
        """Return the summed size, in Mb, of all tensors yielded by get_tensors()."""
        sizes = [self._tensor_bytes(tensor) for tensor in self.get_tensors()]
        return np.sum(sizes) / 1024 ** 2

    def get_allocate_usage(self):
        """Return ``torch.cuda.memory_allocated()`` converted to Mb."""
        return torch.cuda.memory_allocated() / 1024 ** 2

    def clear_cache(self):
        """Run the garbage collector, then release torch's cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    def print_all_gpu_tensor(self, file=None):
        """
        Print size, dtype and memory in Mb of every tensor yielded by get_tensors().

        Arguments:
            file: stream forwarded to ``print``; None means stdout
        """
        for x in self.get_tensors():
            print(x.size(), x.dtype, self._tensor_bytes(x) / 1024 ** 2, file=file)

    def track(self):
        """
        Track the GPU memory usage

        Writes a report for the caller's source location: a header on the first call,
        then (when ``detail`` is set) the tensors added or removed since the previous
        call, then the current totals. Output goes to the log file when ``log_to_disk``
        is set, otherwise to stdout.
        """
        # stack()[0] is this method; stack()[1] is the code that called track().
        frameinfo = inspect.stack()[1]
        where_str = frameinfo.filename + ' line ' + str(frameinfo.lineno) + ': ' + frameinfo.function

        file_name = self.gpu_profile_fn if self.log_to_disk else None

        with file_writer(file_name) as f:

            if self.begin:
                f.write(f"GPU Memory Track | {datetime.datetime.now():%d-%b-%y-%H:%M:%S} |"
                        f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                        f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")
                self.begin = False

            if self.print_detail:
                new_tensor_sizes = self._snapshot()
                # Rows only in the new snapshot appeared since the last call ...
                for row in new_tensor_sizes - self.last_tensor_sizes:
                    f.write(self._format_row('+', *row))
                # ... and rows only in the old snapshot have gone away.
                for row in self.last_tensor_sizes - new_tensor_sizes:
                    f.write(self._format_row('-', *row))

                self.last_tensor_sizes = new_tensor_sizes

            f.write(f"\nAt {where_str:<50}"
                    f" Total Tensor Used Memory:{self.get_tensor_usage():<7.1f}Mb"
                    f" Total Allocated Memory:{self.get_allocate_usage():<7.1f}Mb\n\n")