# 3D ResNet classifier fine-tuned from Tencent MedicalNet pretrained weights
import torch
from torch import nn
import torch.nn.functional as F
from functools import partial
import os
import numpy as np
from torch.utils.data import Dataset
import nibabel
from scipy import ndimage
from torch import optim
from torch.utils.data import DataLoader
import time
import logging
from torch.optim import lr_scheduler
import sys
import math
import random

# Settings for training
root_dir = './data' #type=str, help='Root directory path of data'
img_list = './data/train.txt' # type=str, help='Path for image list file'
num_seg_classes = 1 #type=int, help="Number of segmentation classes"
learning_rate = 0.001  # set to 0.001 when fine-tuning; type=float, help='Initial learning rate (divided by 10 while training by lr scheduler)'
num_workers = 0 # type=int, help='Number of jobs'
batch_size = 1 # type=int, help='Batch Size'
phase = 'train' # type=str, help='Phase of train or test'
save_intervals = 10 # type=int, help='Interval (in iterations) for saving the model'
total_epochs = 20 # type=int, help='Number of total epochs to run'
input_D = 56 # type=int, help='Input size of depth'
input_H = 448 # type=int, help='Input size of height'
input_W = 448 # type=int, help='Input size of width'
#resume_path = '' # type=str, help='Path for resume model.'
pretrain_path = 'pretrain/resnet_50.pth' # type=str, help='Path for pretrained model.'
new_layer_names = ['conv_cls']
#default=['upsample1', 'cmp_layer3', 'upsample2', 'cmp_layer2', 'upsample3', 'cmp_layer1', 'upsample4', 'cmp_conv1', 'conv_seg'],
# type=list, help='New layer except for backbone'
no_cuda = False # help='If true, cuda is not used.'
gpu_id = 0 # type=int or list, help='GPU id (int) or list of GPU ids for multi-GPU DataParallel'
basemodel = 'resnet' # type=str,help='(resnet | preresnet | wideresnet | resnext | densenet)'
model_depth = 50 # type=int, help='Depth of resnet (10 | 18 | 34 | 50 | 101)'
resnet_shortcut = 'B' # type=str, help='Shortcut type of resnet (A | B)'
manual_seed = 1 # type=int, help='Manually set random seed'
ci_test = False # help='If true, ci testing is used.'
save_folder = "./trails/models/{}_{}".format(basemodel, model_depth)
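# Expected format of each line in train.txt / val.txt, inferred from the Dataset class below
# (paths are relative to root_dir; the file names are only hypothetical examples):
#   <image.nii.gz> <segmentation_mask.nii.gz> <class_label>
# e.g.  case_001/img.nii.gz case_001/mask.nii.gz 1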

# 3Dresnet_model backbone
#__all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101','resnet152', 'resnet200']
def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        dilation=dilation,
        stride=stride,
        padding=dilation,
        bias=False)
def downsample_basic_block(x, planes, stride, no_cuda=no_cuda):
    # Shortcut type 'A': strided average pooling, then zero-padding along the channel
    # dimension so the shortcut matches the block's output width without extra parameters.
    # (no_cuda is kept for signature compatibility; the pads simply follow x's device.)
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.zeros(
        out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4),
        dtype=out.dtype, device=out.device)
    out = torch.cat([out, zero_pads], dim=1)

    return out
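# Shortcut type 'A' (the function above) matches channel widths with average pooling plus
# zero-padding and adds no parameters; shortcut type 'B' (built in ResNet._make_layer below)
# uses a 1x1x1 convolution followed by BatchNorm. This script uses resnet_shortcut = 'B'.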
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes, dilation=dilation)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 sample_input_D,
                 sample_input_H,
                 sample_input_W,
                 num_seg_classes,
                 shortcut_type='B',
                 no_cuda = False):
        self.inplanes = 64
        self.no_cuda = no_cuda
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(
            1,
            64,
            kernel_size=7,
            stride=(2, 2, 2),
            padding=(3, 3, 3),
            bias=False)

        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=1, dilation=2)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=1, dilation=4)

        # self.conv_seg = nn.Sequential(
        #                                 nn.ConvTranspose3d(512 * block.expansion, 32, 2, stride=2),
        #                                 nn.BatchNorm3d(32),
        #                                 nn.ReLU(inplace=True),
        #                                 nn.Conv3d(32, 32, kernel_size=3, stride=(1, 1, 1), padding=(1, 1, 1), bias=False),
        #                                 nn.BatchNorm3d(32),
        #                                 nn.ReLU(inplace=True),
        #                                 nn.Conv3d(32, num_seg_classes, kernel_size=1, stride=(1, 1, 1), bias=False)
        #                                 )

        self.conv_cls = nn.Sequential(
                                        nn.AdaptiveMaxPool3d(output_size=(1, 1, 1)),
                                        nn.Flatten(start_dim=1),
                                        nn.Dropout(0.1),
                                        nn.Linear(512 * block.expansion, num_seg_classes)
                                        )
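        # Classification head in place of MedicalNet's segmentation head (conv_seg, commented
        # out above): global max-pool to (1,1,1), flatten, dropout, then a single linear layer
        # producing num_seg_classes logits (here 1), squashed by sigmoid in forward().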


        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride,
                    no_cuda=self.no_cuda)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x = self.conv_seg(x)
        x = self.conv_cls(x)
        x = torch.sigmoid_(x)

        return x
def resnet10(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
    return model
def resnet18(**kwargs):
    """Constructs a ResNet-18 model.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model
def resnet34(**kwargs):
    """Constructs a ResNet-34 model.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model
def resnet50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model
def resnet101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model
def resnet152(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
def resnet200(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs)
    return model
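# Minimal shape check (a sketch, not called anywhere in this script): build a small 3D ResNet
# and run one forward pass on a random volume. The 32x64x64 input is chosen only to keep the
# example cheap; the sample_input_* arguments are accepted but not used by the network itself.
def _smoke_test_resnet10():
    net = resnet10(sample_input_D=32, sample_input_H=64, sample_input_W=64,
                   num_seg_classes=1, shortcut_type='B', no_cuda=True)
    x = torch.randn(1, 1, 32, 64, 64)  # (batch, channel, depth, height, width)
    with torch.no_grad():
        y = net(x)
    assert y.shape == (1, 1)  # one sigmoid probability per input volume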
# get 3Dresnet_model
def generate_model(basemodel, model_depth, input_D, input_H, input_W, num_seg_classes, no_cuda, phase, pretrain_path):
    assert basemodel in [
        'resnet'
    ]

    if basemodel == 'resnet':
        assert model_depth in [10, 18, 34, 50, 101, 152, 200]

        if model_depth == 10:
            model = resnet10(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 18:
            model = resnet18(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 34:
            model = resnet34(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 50:
            model = resnet50(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 101:
            model = resnet101(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 152:
            model = resnet152(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)
        elif model_depth == 200:
            model = resnet200(
                sample_input_W=input_W,
                sample_input_H=input_H,
                sample_input_D=input_D,
                shortcut_type=resnet_shortcut,
                no_cuda=no_cuda,
                num_seg_classes=num_seg_classes)

    if not no_cuda:
        if isinstance(gpu_id, (list, tuple)) and len(gpu_id) > 1:
            # multi-GPU: gpu_id is a list of device ids
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=list(gpu_id))
            net_dict = model.state_dict()
        else:
            # single GPU
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id if isinstance(gpu_id, int) else gpu_id[0])
            model = model.cuda()
            model = nn.DataParallel(model, device_ids=None)
            net_dict = model.state_dict()
    else:
        net_dict = model.state_dict()

    # load pretrain
    if phase != 'test' and pretrain_path:
        print ('loading pretrained model {}'.format(pretrain_path))
        pretrain = torch.load(pretrain_path)
        # strip a possible "module." prefix from the checkpoint keys, then re-add it when the
        # model is wrapped in nn.DataParallel, so the keys match in either case
        prefix = "module." if any(k.startswith("module.") for k in net_dict) else ""
        pretrain_dict = {prefix + k.replace("module.", ""): v
                         for k, v in pretrain['state_dict'].items()
                         if prefix + k.replace("module.", "") in net_dict}

        net_dict.update(pretrain_dict)
        model.load_state_dict(net_dict)

        new_parameters = []
        for pname, p in model.named_parameters():
            for layer_name in new_layer_names:
                if pname.find(layer_name) >= 0:
                    new_parameters.append(p)
                    break

        new_parameters_id = list(map(id, new_parameters))
        base_parameters = list(filter(lambda p: id(p) not in new_parameters_id, model.parameters()))
        parameters = {'base_parameters': base_parameters,
                      'new_parameters': new_parameters}

        return model, parameters

    return model, model.parameters()
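# generate_model returns the network (wrapped in nn.DataParallel when CUDA is used) plus either
# a dict of parameter groups {'base_parameters': backbone weights initialised from the MedicalNet
# checkpoint, 'new_parameters': layers listed in new_layer_names, here conv_cls} when fine-tuning,
# or simply model.parameters() when no pretrained weights are loaded.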

# define Dataset for training
class MedicalDataset(Dataset):

    def __init__(self, root_dir, img_list, input_D, input_H, input_W, phase):
        with open(img_list, 'r') as f:
            self.img_list = [line.strip() for line in f]
        print("Processing {} datas".format(len(self.img_list)))
        self.root_dir = root_dir
        self.input_D = input_D
        self.input_H = input_H
        self.input_W = input_W
        self.phase = phase

    def __nii2tensorarray__(self, data):
        [z, y, x] = data.shape
        new_data = np.reshape(data, [1, z, y, x])
        new_data = new_data.astype("float32")

        return new_data

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):

        if self.phase == "train":
            # read image and labels
            ith_info = self.img_list[idx].split(" ")
            img_name = os.path.join(self.root_dir, ith_info[0])
            label_name = os.path.join(self.root_dir, ith_info[1])
            # classification label: one binary target per volume, matching the single
            # sigmoid output of conv_cls (used with BCELoss in train())
            class_array = np.array([ith_info[2]], dtype="float32")
            class_array = torch.tensor(class_array, dtype=torch.float32)
            assert os.path.isfile(img_name)
            assert os.path.isfile(label_name)
            img = nibabel.load(img_name)  # We have transposed the data from WHD format to DHW
            assert img is not None
            mask = nibabel.load(label_name)
            assert mask is not None

            # data processing
            img_array, mask_array = self.__training_data_process__(img, mask)

            # 2 tensor array
            img_array = self.__nii2tensorarray__(img_array)
            mask_array = self.__nii2tensorarray__(mask_array)

            assert img_array.shape ==  mask_array.shape, "img shape:{} is not equal to mask shape:{}".format(img_array.shape, mask_array.shape)
            return img_array, mask_array, class_array  # volume, segmentation mask, class label

        elif self.phase == "test":
            # read image
            ith_info = self.img_list[idx].split(" ")
            img_name = os.path.join(self.root_dir, ith_info[0])
            print(img_name)
            assert os.path.isfile(img_name)
            img = nibabel.load(img_name)
            assert img is not None

            # data processing
            img_array = self.__testing_data_process__(img)

            # 2 tensor array
            img_array = self.__nii2tensorarray__(img_array)

            return img_array


    def __drop_invalid_range__(self, volume, label=None):
        """
        Cut off the invalid area
        """
        zero_value = volume[0, 0, 0]
        non_zeros_idx = np.where(volume != zero_value)

        [max_z, max_h, max_w] = np.max(np.array(non_zeros_idx), axis=1)
        [min_z, min_h, min_w] = np.min(np.array(non_zeros_idx), axis=1)

        if label is not None:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w], label[min_z:max_z, min_h:max_h, min_w:max_w]
        else:
            return volume[min_z:max_z, min_h:max_h, min_w:max_w]


    def __random_center_crop__(self, data, label):
        """
        Random crop
        """
        from random import random
        target_indexs = np.where(label>0)
        [img_d, img_h, img_w] = data.shape
        [max_D, max_H, max_W] = np.max(np.array(target_indexs), axis=1)
        [min_D, min_H, min_W] = np.min(np.array(target_indexs), axis=1)
        [target_depth, target_height, target_width] = np.array([max_D, max_H, max_W]) - np.array([min_D, min_H, min_W])
        Z_min = int((min_D - target_depth*1.0/2) * random())
        Y_min = int((min_H - target_height*1.0/2) * random())
        X_min = int((min_W - target_width*1.0/2) * random())

        Z_max = int(img_d - ((img_d - (max_D + target_depth*1.0/2)) * random()))
        Y_max = int(img_h - ((img_h - (max_H + target_height*1.0/2)) * random()))
        X_max = int(img_w - ((img_w - (max_W + target_width*1.0/2)) * random()))

        Z_min = np.max([0, Z_min])
        Y_min = np.max([0, Y_min])
        X_min = np.max([0, X_min])

        Z_max = np.min([img_d, Z_max])
        Y_max = np.min([img_h, Y_max])
        X_max = np.min([img_w, X_max])

        Z_min = int(Z_min)
        Y_min = int(Y_min)
        X_min = int(X_min)

        Z_max = int(Z_max)
        Y_max = int(Y_max)
        X_max = int(X_max)

        return data[Z_min: Z_max, Y_min: Y_max, X_min: X_max], label[Z_min: Z_max, Y_min: Y_max, X_min: X_max]



    def __itensity_normalize_one_volume__(self, volume):
        """
        Normalize the intensity of an nd volume based on the mean and std of the nonzero region
        inputs:
            volume: the input nd volume
        outputs:
            out: the normalized nd volume
        """

        pixels = volume[volume > 0]
        mean = pixels.mean()
        std  = pixels.std()
        out = (volume - mean)/std
        out_random = np.random.normal(0, 1, size = volume.shape)
        out[volume == 0] = out_random[volume == 0]
        return out

    def __resize_data__(self, data):
        """
        Resize the data to the input size
        """
        [depth, height, width] = data.shape
        scale = [self.input_D*1.0/depth, self.input_H*1.0/height, self.input_W*1.0/width]
        data = ndimage.zoom(data, scale, order=0)

        return data


    def __crop_data__(self, data, label):
        """
        Random crop with different methods:
        """
        # random center crop
        data, label = self.__random_center_crop__ (data, label)

        return data, label

    def __training_data_process__(self, data, label):
        # crop data according net input size
        data = data.get_fdata()
        label = label.get_fdata()

        # drop out the invalid range
        data, label = self.__drop_invalid_range__(data, label)

        # crop data
        data, label = self.__crop_data__(data, label)

        # resize data
        data = self.__resize_data__(data)
        label = self.__resize_data__(label)

        # normalize the intensities
        data = self.__itensity_normalize_one_volume__(data)

        return data, label


    def __testing_data_process__(self, data):
        # crop data according net input size
        data = data.get_fdata()

        # resize data
        data = self.__resize_data__(data)

        # normalize the intensities
        data = self.__itensity_normalize_one_volume__(data)

        return data
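# Quick sanity check of the dataset output (a sketch; assumes root_dir and img_list point at
# existing files as configured above). Uncomment to inspect one training sample:
# _ds = MedicalDataset(root_dir=root_dir, img_list=img_list, input_D=input_D,
#                      input_H=input_H, input_W=input_W, phase='train')
# _img, _mask, _cls = _ds[0]
# print(_img.shape, _mask.shape, _cls)  # expected: (1, 56, 448, 448), (1, 56, 448, 448), class tensor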

# define logger
logging.basicConfig(format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', level=logging.DEBUG)
log = logging.getLogger()

# get model
torch.manual_seed(manual_seed)
model, parameters = generate_model(basemodel, model_depth, input_D, input_H, input_W, num_seg_classes, no_cuda, phase, pretrain_path)
for param_name, param in model.named_parameters():
    # the model is wrapped in nn.DataParallel, so parameter names carry a "module." prefix;
    # match the layer name anywhere in the parameter name rather than with startswith
    if "conv_cls" in param_name:
        param.requires_grad = True
    else:
        param.requires_grad = False
# get training dataset
training_dataset = MedicalDataset(root_dir=root_dir, img_list=img_list, input_D=input_D, input_H=input_H, input_W=input_W, phase=phase)
# get data loader
data_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
# optimizer
params = [
        { 'params': parameters['base_parameters'], 'lr': learning_rate },
        { 'params': parameters['new_parameters'], 'lr': learning_rate*100 }
        ]
optimizer = torch.optim.SGD(params, momentum=0.9, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
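# Note: the backbone parameters are frozen above (requires_grad = False), so in practice only the
# 'new_parameters' group (conv_cls) is updated; the 'base_parameters' group is kept in the
# optimizer so the same setup still works if the backbone is later unfrozen for full fine-tuning.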
# train from resume
# if resume_path:
#     if os.path.isfile(resume_path):
#         print("=> loading checkpoint '{}'".format(resume_path))
#         checkpoint = torch.load(resume_path)
#         model.load_state_dict(checkpoint['state_dict'])
#         optimizer.load_state_dict(checkpoint['optimizer'])
#         print("=> loaded checkpoint '{}' (epoch {})"
#           .format(resume_path, checkpoint['epoch']))

# define train
def train(data_loader, model, optimizer, scheduler, total_epochs, save_interval, save_folder, no_cuda):
    # settings
    batches_per_epoch = len(data_loader)
    log.info('{} epochs in total, {} batches per epoch'.format(total_epochs, batches_per_epoch))
    loss_seg = nn.BCELoss()  # binary cross-entropy on the sigmoid classification output

    if not no_cuda:
        loss_seg = loss_seg.cuda()

    model.train()
    train_time_sp = time.time()
    for epoch in range(total_epochs):
        log.info('Start epoch {}'.format(epoch))

        log.info('lr = {}'.format(scheduler.get_last_lr()))

        for batch_id, batch_data in enumerate(data_loader):
            # getting data batch
            batch_id_sp = epoch * batches_per_epoch
            volumes, label_masks, class_array = batch_data  # image volume, segmentation mask, class label

            if not no_cuda:
                volumes = volumes.cuda()
                class_array = class_array.cuda()
            optimizer.zero_grad()
            out_masks = model(volumes)
            # resize label
            # [n, _, d, h, w] = out_masks.shape
            # new_label_masks = np.zeros([n, d, h, w])
            # for label_id in range(n):
            #     label_mask = label_masks[label_id]
            #     [ori_c, ori_d, ori_h, ori_w] = label_mask.shape
            #     label_mask = np.reshape(label_mask, [ori_d, ori_h, ori_w])
            #     scale = [d*1.0/ori_d, h*1.0/ori_h, w*1.0/ori_w]
            #     label_mask = ndimage.zoom(label_mask, scale, order=0)
            #     new_label_masks[label_id] = label_mask

            # (unused for classification) the original pipeline built integer label masks here:
            # new_label_masks = torch.tensor(out_masks).to(torch.int64)
            # if not no_cuda:
            #     new_label_masks = new_label_masks.cuda()

            # calculating loss (binary cross-entropy between sigmoid outputs and class labels)
            loss_value_seg = loss_seg(out_masks, class_array)
            loss = loss_value_seg
            loss.backward()
            optimizer.step()

            avg_batch_time = (time.time() - train_time_sp) / (1 + batch_id_sp)
            log.info(
                    'Batch: {}-{} ({}), loss = {:.3f}, loss_seg = {:.3f}, avg_batch_time = {:.3f}'\
                    .format(epoch, batch_id, batch_id_sp, loss.item(), loss_value_seg.item(), avg_batch_time))

            # save model
            if batch_id == 0 and batch_id_sp != 0 and batch_id_sp % save_interval == 0:
            #if batch_id_sp != 0 and batch_id_sp % save_interval == 0:
                model_save_path = '{}_epoch_{}_batch_{}.pth.tar'.format(save_folder, epoch, batch_id)
                model_save_dir = os.path.dirname(model_save_path)
                if not os.path.exists(model_save_dir):
                    os.makedirs(model_save_dir)

                log.info('Save checkpoints: epoch = {}, batch_id = {}'.format(epoch, batch_id))
                torch.save({
                            'epoch': epoch,
                            'batch_id': batch_id,
                            'state_dict': model.state_dict(),
                            'optimizer': optimizer.state_dict()},
                            model_save_path)
        # step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()

    print('Finished training')

# training
train(data_loader=data_loader, model=model, optimizer=optimizer, scheduler=scheduler, total_epochs=total_epochs, save_interval=save_intervals, save_folder=save_folder, no_cuda=no_cuda)


# settting for test
phase = 'test'
resume_path = 'trails/models/resnet_50_epoch_1_batch_0.pth.tar'
img_list = './data/val.txt'

# read val files
def load_lines(file_path):
    """Read file into a list of lines.

    Input
      file_path: file path

    Output
      lines: an array of lines
    """
    with open(file_path, 'r') as fio:
        lines = fio.read().splitlines()
    return lines

# calculate the dice between prediction and ground truth
# def seg_eval(pred, label, clss):
#     """
#     input:
#         pred: predicted mask
#         label: groud truth
#         clss: eg. [0, 1] for binary class
#     """
#     Ncls = len(clss)
#     dices = np.zeros(Ncls)
#     [depth, height, width] = pred.shape
#     for idx, cls in enumerate(clss):
#         # binary map
#         pred_cls = np.zeros([depth, height, width])
#         pred_cls[np.where(pred == cls)] = 1
#         label_cls = np.zeros([depth, height, width])
#         label_cls[np.where(label == cls)] = 1
#
#         # cal the inter & conv
#         s = pred_cls + label_cls
#         inter = len(np.where(s >= 2)[0])
#         conv = len(np.where(s >= 1)[0]) + inter
#         try:
#             dice = 2.0 * inter / conv
#         except:
#             print("conv is zeros when dice = 2.0 * inter / conv")
#             dice = -1
#
#         dices[idx] = dice
#
#     return dices

# define test
def test(data_loader, model, img_names, no_cuda):
    masks = []
    model.eval() # for testing
    for batch_id, batch_data in enumerate(data_loader):
        # forward
        volume = batch_data
        if not no_cuda:
            volume = volume.cuda()
        with torch.no_grad():
            probs = model(volume)
            # probs = F.softmax(probs, dim=1)

        # resize mask to original size
        # [batchsize, _, mask_d, mask_h, mask_w] = probs.shape
        # data = nibabel.load(os.path.join(root_dir, img_names[batch_id]))
        # data = data.get_fdata()
        # [depth, height, width] = data.shape
        # mask = probs[0]
        # scale = [1, depth*1.0/mask_d, height*1.0/mask_h, width*1.0/mask_w]
        # mask = ndimage.zoom(mask.cpu(), scale, order=1)
        # mask = np.argmax(mask, axis=0)

        masks.append(probs.cpu().item())  # scalar sigmoid probability for this volume

    return masks
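# test() returns one sigmoid probability per validation volume, in the same order as img_list;
# these scores are compared against the ground-truth classes below via an ROC curve.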

# getting model
checkpoint = torch.load(resume_path)
net, _ = generate_model(basemodel, model_depth, input_D, input_H, input_W, num_seg_classes, no_cuda, phase, pretrain_path)
net.load_state_dict(checkpoint['state_dict'])

# data tensor
testing_data = MedicalDataset(root_dir=root_dir, img_list=img_list, input_D=input_D, input_H=input_H, input_W=input_W, phase=phase)
data_loader = DataLoader(testing_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=False)

# testing
img_names = [info.split(" ")[0] for info in load_lines(img_list)]
masks = test(data_loader, net, img_names, no_cuda)
class_names = [int(info.split(" ")[2]) for info in load_lines(img_list)]

from sklearn.metrics import roc_curve, auc
fpr, tpr, thresholds = roc_curve(class_names, masks)
roc_auc = auc(fpr, tpr)
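print('ROC AUC = {:.4f}'.format(roc_auc))

# Optional extra metric (a simple sketch): accuracy of the sigmoid scores at a fixed 0.5
# threshold, complementing the threshold-free AUC above.
preds = [1 if p >= 0.5 else 0 for p in masks]
accuracy = sum(int(p == c) for p, c in zip(preds, class_names)) / len(class_names)
print('Accuracy@0.5 = {:.4f}'.format(accuracy))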
