29 lines
610 B
Python
29 lines
610 B
Python
|
import random
|
|||
|
|
|||
|
import numpy as np
|
|||
|
|
|||
|
from .data import *
|
|||
|
from .metrics import *
|
|||
|
|
|||
|
|
|||
|
def set_random_seed(seed):
    """Seed the random number generators of Python, NumPy, PyTorch and DGL.

    The CUDA generator is seeded through ``torch.cuda.manual_seed`` only
    when ``torch.cuda.is_available()`` reports that CUDA can be used.

    :param seed: int, the seed value passed to every generator
    """
    # NOTE(review): ``torch`` and ``dgl`` are not imported in the visible
    # part of this module; presumably they arrive through the wildcard
    # imports above -- confirm.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    cuda_usable = torch.cuda.is_available()
    if cuda_usable:
        torch.cuda.manual_seed(seed)

    dgl.seed(seed)
|
|||
|
|
|||
|
|
|||
|
def get_device(device):
    """Return the torch device selected by a GPU index.

    :param device: int, GPU index; -1 (any negative value) selects the CPU
    :return: torch.device -- ``cuda:<device>`` when the index is
        non-negative and CUDA is available, otherwise ``cpu``
    """
    # Short-circuit keeps the CUDA availability query from running at all
    # when a negative index already requests the CPU.
    wants_gpu = device >= 0 and torch.cuda.is_available()
    if wants_gpu:
        return torch.device(f'cuda:{device}')
    return torch.device('cpu')
|