本文是LaneNet车道线检测效果复现,不涉及原理讲解部分。

开门见山,上链接:
链接:https://pan.baidu.com/s/1yxJNDdR1y4ixW62gDuDawQ
提取码:hcwi

关于LaneNet算法,网上有很多资料,GitHub上也有不少实现。可能是我自身检索能力有限,捣鼓了几天,迟迟无法复现出代码的效果,主要原因是某些文件找不到或下载不下来。现在相关文件均已放在百度网盘里。

Windows系统
python3.5.2
相关库具体版本见requirements_new.txt
更新时间2020.06.09

1. 下载压缩包并解压。注意,model文件夹下的New_Tusimple_Lanenet_Model_Weights权重文件是我自己添加的,某些GitHub项目或博客中并未提供。(为了这个权重我真是费尽心思,现已单独分享在百度网盘中:New_Tusimple_Lanenet_Model_Weights,提取码:s40b)
在这里插入图片描述
在这里插入图片描述
2. 修改tools文件夹下的test_lanenet.py文件,添加相关路径,不然会报错。
在这里插入图片描述
修改成自己的路径

import sys 
# Make the project root and its sub-packages importable.
# Change _LANENET_ROOT to wherever you unpacked lanenet-lane-detection-master.
_LANENET_ROOT = 'C:/Users/Lenovo/Desktop/lanenet-lane-detection-master'
sys.path.extend(
    _LANENET_ROOT + sub_dir
    for sub_dir in ('', '/config', '/data_provider', '/lanenet_model', '/tools')
)
3. 在lanenet-lane-detection-master文件夹下运行命令:
python tools/test_lanenet.py --weights_path model/New_Tusimple_Lanenet_Model_Weights/tusimple_lanenet_vgg.ckpt  --image_path data/tusimple_test_image/0.jpg
4. 注意:pictures文件夹是我自己新建的,用于保存检测结果图片,源码中没有这个文件夹。requirements_new.txt列出了我电脑上安装的库版本,与原作者的版本有些出入,但并不影响运行。data/tusimple_test_image文件夹中保存有测试图片,测试效果很好。你也可以放入自己的图片进行检测,但我用自己的图片测试时车道线效果并不好,甚至可以说很差,原因暂时未知(原因已找到,见下方2020年6月11日的更新)。
5. 测试效果
    在这里插入图片描述
    上面测试效果还不赖,可是我换成自己的数据集,车道线就飞到天上了。。。
    在这里插入图片描述

——2020年6月11日更新
车道线“飞到天上”的原因找到了:图片分辨率不对。测试图片的分辨率要求为1280×720,而我的图片是1280×1024。把分辨率调整为1280×720后,车道线检测就正常了。
在这里插入图片描述


2021年10月31日更新
有位博友需要做个简单的上位机,显示检测的结果,帮忙做了一个,现开源出来,供大家参考。

上位机界面大致如下,基于Python(PyQt5)编写,缺少什么库就用pip安装什么库即可。
在这里插入图片描述
这里涉及两个文件:一个是my_form.py,为上位机界面程序;另一个是test_images.py,为车道线检测程序。
将这两个文件放在根目录lanenet-lane-detection-master下,在该目录中运行test_images.py即可(权重文件使用的是相对路径,所以需要从根目录启动)。在界面的“图片路径”框中填入图片路径,先点击“加载图片”,再点击“处理结果”,即可检测不同的图片。
test_images.py

import sys
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtGui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from my_form import Ui_Form
from PIL import Image
from PIL.ImageQt import ImageQt
import qdarkstyle

import argparse
import os.path as ops
import time

import cv2
import glog as log
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import sys 
# Make the project root and its sub-packages importable.
# Change _LANENET_ROOT to your own checkout. Use forward slashes only: a
# backslash (as in 'Desktop\shang') starts an escape sequence, and '\s' is an
# invalid one that newer Python versions warn about.
_LANENET_ROOT = 'C:/Users/Lenovo/Desktop/shang/lanenet-lane-detection-master'
sys.path.extend(
    _LANENET_ROOT + sub_dir
    for sub_dir in ('', '/config', '/data_provider', '/lanenet_model', '/tools')
)

from config import global_config
from lanenet_model import lanenet
from lanenet_model import lanenet_postprocess


class MyMainForm(QMainWindow, Ui_Form):
    """Main window of the LaneNet lane-detection GUI.

    The widgets come from the generated ``Ui_Form``. ``pushButton`` ("加载图片")
    latches the path typed in ``lineEdit``. ``pushButton_2`` ("处理结果") runs
    LaneNet on the latched path and fills the image labels:

    * ``label`` and ``label_2``: the source image;
    * ``label_3``: the post-processor's ``mask_image``;
    * ``label_4``: the instance-embedding output, min-max scaled to uint8;
    * ``label_5``: ``mask_image`` thresholded to black/white.
    """

    # Checkpoint restored on every detection run. It is a relative path, so it
    # is resolved against the current working directory (start from the
    # project root).
    WEIGHTS_PATH = "model/New_Tusimple_Lanenet_Model_Weights/tusimple_lanenet_vgg.ckpt"

    def __init__(self, parent=None):
        super(MyMainForm, self).__init__(parent)
        self.setupUi(self)

        self.setWindowTitle("LaneNet车道线检测界面")

        self.pushButton.clicked.connect(self.openimage)
        self.pushButton_2.clicked.connect(self.detection_results)

    def openimage(self):
        """Latch the path currently typed in ``lineEdit``.

        The path is stored in the module-level ``image_path`` global, which
        ``detection_results`` reads. Processing therefore uses the path as of
        the last click on "加载图片", not whatever is in the box right now.
        """
        global image_path
        image_path = self.lineEdit.text()

    def detection_results(self):
        """Run LaneNet on the latched image and show the results in the labels.

        A warning dialog is shown, and the network is not run, when no path
        has been latched yet, the file does not exist, or OpenCV cannot decode
        it.
        """
        image_vis = self._load_latched_image()
        if image_vis is None:
            QtWidgets.QMessageBox.warning(self, "警告", "图片路径加载错误!", QtWidgets.QMessageBox.Cancel)
            return

        mask_image, embedding_image = self._run_lanenet(image_vis)
        self._show_results(image_vis, mask_image, embedding_image)

    @staticmethod
    def _load_latched_image():
        """Return the latched image as a BGR array, or None if it is unusable."""
        try:
            path = image_path
        except NameError:
            # "加载图片" has not been clicked yet, so the global does not exist.
            return None
        if not path or not ops.exists(path):
            return None

        log.info('Start reading image and preprocessing')
        t_start = time.time()
        # cv2.imread returns None (rather than raising) for undecodable files.
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        log.info('Image load complete, cost time: {:.5f}s'.format(time.time() - t_start))
        return image

    @staticmethod
    def _minmax_scale(input_arr):
        """Rescale *input_arr* linearly so its minimum maps to 0 and its maximum to 255.

        A constant array (max == min) maps to all zeros instead of dividing by
        zero.
        """
        min_val = np.min(input_arr)
        max_val = np.max(input_arr)
        if max_val == min_val:
            return np.zeros_like(input_arr, dtype=np.float64)
        return (input_arr - min_val) * 255.0 / (max_val - min_val)

    def _run_lanenet(self, image_vis):
        """Run LaneNet inference and post-processing on a BGR image.

        *image_vis* is passed unchanged to the post-processor as
        ``source_image``.

        Returns:
            ``(mask_image, embedding_image)``: the post-processor's
            ``mask_image``, and the instance-embedding output with each of the
            ``EMBEDDING_FEATS_DIMS`` channels min-max scaled and cast to uint8.
        """
        cfg = global_config.cfg

        image = cv2.resize(image_vis, (512, 256), interpolation=cv2.INTER_LINEAR)
        image = image / 127.5 - 1.0  # map [0, 255] pixel values to [-1, 1]

        # Build the network in a fresh graph on every call. Reusing the default
        # graph would add another copy of 'lanenet_model' to it on every click.
        with tf.Graph().as_default():
            input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')

            net = lanenet.LaneNet(phase='test', net_flag='vgg')
            binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

            postprocessor = lanenet_postprocess.LaneNetPostProcessor()

            saver = tf.train.Saver()

            sess_config = tf.ConfigProto()
            sess_config.gpu_options.per_process_gpu_memory_fraction = cfg.TEST.GPU_MEMORY_FRACTION
            sess_config.gpu_options.allow_growth = cfg.TRAIN.TF_ALLOW_GROWTH
            sess_config.gpu_options.allocator_type = 'BFC'

            # The context manager closes the session even if restore or
            # inference raises.
            with tf.Session(config=sess_config) as sess:
                saver.restore(sess=sess, save_path=self.WEIGHTS_PATH)

                t_start = time.time()
                binary_seg_image, instance_seg_image = sess.run(
                    [binary_seg_ret, instance_seg_ret],
                    feed_dict={input_tensor: [image]}
                )
                log.info('Single imgae inference cost time: {:.5f}s'.format(time.time() - t_start))

                postprocess_result = postprocessor.postprocess(
                    binary_seg_result=binary_seg_image[0],
                    instance_seg_result=instance_seg_image[0],
                    source_image=image_vis
                )

                # Scale every embedding channel to [0, 255] in place, then cast.
                embedding = instance_seg_image[0]
                for channel in range(cfg.TRAIN.EMBEDDING_FEATS_DIMS):
                    embedding[:, :, channel] = self._minmax_scale(embedding[:, :, channel])
                embedding_image = np.array(embedding, np.uint8)

        return postprocess_result['mask_image'], embedding_image

    def _show_results(self, image_vis, mask_image, embedding_image):
        """Fill the five image labels from the BGR source, mask and embedding images."""
        src_rgb = self._bgr_to_rgb(image_vis)
        self._set_label_image(self.label, src_rgb)
        self._set_label_image(self.label_2, src_rgb)

        mask_rgb = self._bgr_to_rgb(mask_image)
        self._set_label_image(self.label_3, mask_rgb)

        self._set_label_image(self.label_4, self._bgr_to_rgb(embedding_image))

        # Black/white view of the mask: any pixel with gray level > 10 becomes white.
        gray_mask = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2GRAY)
        _, binary_mask = cv2.threshold(gray_mask, 10, 255, cv2.THRESH_BINARY)
        self._set_label_image(self.label_5, binary_mask, QImage.Format_Grayscale8)

    @staticmethod
    def _bgr_to_rgb(image):
        """Return a C-contiguous RGB copy of the first three channels of a BGR image.

        The channels are reversed exactly once. Following the reversal with
        cv2.COLOR_BGR2RGB would swap them back to BGR.
        """
        return np.ascontiguousarray(image[:, :, (2, 1, 0)])

    @staticmethod
    def _set_label_image(label, image, fmt=QImage.Format_RGB888):
        """Show a C-contiguous uint8 array in *label*, scaled to that label's current size.

        *image* is H x W x 3 for ``Format_RGB888`` or H x W for
        ``Format_Grayscale8``.
        """
        height, width = image.shape[:2]
        # strides[0] is the byte length of one row, i.e. QImage's bytesPerLine.
        qimage = QImage(image.data, width, height, image.strides[0], fmt)
        # fromImage copies the pixels, so the numpy buffer may be freed afterwards.
        label.setPixmap(QPixmap.fromImage(qimage).scaled(label.width(), label.height()))
                
if __name__ == "__main__":
    # Create the Qt application, apply the dark theme, show the main window,
    # and hand the event loop's exit status to the interpreter.
    qt_app = QtWidgets.QApplication(sys.argv)
    qt_app.setStyleSheet(qdarkstyle.load_stylesheet())

    main_window = MyMainForm()
    main_window.show()

    exit_code = qt_app.exec_()
    sys.exit(exit_code)

my_form.py

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'untitled.ui'
#
# Created by: PyQt5 UI code generator 5.15.2
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_Form(object):
    """Widget layout for the LaneNet GUI, generated by pyuic5 from 'untitled.ui'.

    Regenerate this class with pyuic5 instead of editing it by hand (see the
    file header). ``setupUi`` builds the following inside *Form*:

    * ``label_6``: the centred title, placed above the main layout;
    * a grid of five image labels: ``label`` (column 0, spanning both rows),
      ``label_2``/``label_3`` (row 0, columns 1-2) and ``label_4``/``label_5``
      (row 1, columns 1-2);
    * a row containing the ``label_7`` caption and the ``lineEdit`` path box;
    * a row containing ``pushButton`` and ``pushButton_2``, separated by spacers.
    """
    def setupUi(self, Form):
        """Create the child widgets and layouts on *Form*, then apply their texts."""
        Form.setObjectName("Form")
        Form.resize(888, 561)
        # Title label, positioned with a fixed geometry above the main layout.
        self.label_6 = QtWidgets.QLabel(Form)
        self.label_6.setGeometry(QtCore.QRect(310, 10, 281, 29))
        self.label_6.setStyleSheet("font: 16pt \"Ubuntu\";")
        self.label_6.setAlignment(QtCore.Qt.AlignCenter)
        self.label_6.setObjectName("label_6")
        # Container widget holding the outer grid (image grid / path row / button row).
        self.layoutWidget = QtWidgets.QWidget(Form)
        self.layoutWidget.setGeometry(QtCore.QRect(20, 50, 851, 501))
        self.layoutWidget.setObjectName("layoutWidget")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.layoutWidget)
        self.gridLayout_2.setContentsMargins(0, 0, 0, 0)
        self.gridLayout_2.setObjectName("gridLayout_2")
        # Inner grid holding the five image-display labels.
        self.gridLayout = QtWidgets.QGridLayout()
        self.gridLayout.setContentsMargins(-1, -1, -1, 10)
        self.gridLayout.setObjectName("gridLayout")
        # label_3: row 0, column 2.
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_3.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_3.setAlignment(QtCore.Qt.AlignCenter)
        self.label_3.setObjectName("label_3")
        self.gridLayout.addWidget(self.label_3, 0, 2, 1, 1)
        # label_2: row 0, column 1.
        self.label_2 = QtWidgets.QLabel(self.layoutWidget)
        self.label_2.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_2.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_2.setAlignment(QtCore.Qt.AlignCenter)
        self.label_2.setObjectName("label_2")
        self.gridLayout.addWidget(self.label_2, 0, 1, 1, 1)
        # label_4: row 1, column 1.
        self.label_4 = QtWidgets.QLabel(self.layoutWidget)
        self.label_4.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_4.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_4.setAlignment(QtCore.Qt.AlignCenter)
        self.label_4.setObjectName("label_4")
        self.gridLayout.addWidget(self.label_4, 1, 1, 1, 1)
        # label: column 0, spanning both rows (the large left-hand view).
        self.label = QtWidgets.QLabel(self.layoutWidget)
        self.label.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label.setAutoFillBackground(False)
        self.label.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label.setAlignment(QtCore.Qt.AlignCenter)
        self.label.setObjectName("label")
        self.gridLayout.addWidget(self.label, 0, 0, 2, 1)
        # label_5: row 1, column 2.
        self.label_5 = QtWidgets.QLabel(self.layoutWidget)
        self.label_5.setLayoutDirection(QtCore.Qt.LeftToRight)
        self.label_5.setStyleSheet("background-color: rgb(211, 215, 207);")
        self.label_5.setAlignment(QtCore.Qt.AlignCenter)
        self.label_5.setObjectName("label_5")
        self.gridLayout.addWidget(self.label_5, 1, 2, 1, 1)
        self.gridLayout_2.addLayout(self.gridLayout, 0, 0, 1, 1)
        # Path row (outer grid row 1): caption label_7 followed by the lineEdit text box.
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_2.setContentsMargins(-1, -1, -1, 10)
        self.horizontalLayout_2.setSpacing(16)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.label_7 = QtWidgets.QLabel(self.layoutWidget)
        self.label_7.setObjectName("label_7")
        self.horizontalLayout_2.addWidget(self.label_7)
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget)
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalLayout_2.addWidget(self.lineEdit)
        self.gridLayout_2.addLayout(self.horizontalLayout_2, 1, 0, 1, 1)
        # Button row (outer grid row 2): spacer, pushButton, spacer, pushButton_2, spacer.
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setContentsMargins(-1, -1, -1, 0)
        self.horizontalLayout.setSpacing(16)
        self.horizontalLayout.setObjectName("horizontalLayout")
        spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton.setStyleSheet("background-color: rgb(114, 159, 207);")
        self.pushButton.setObjectName("pushButton")
        self.horizontalLayout.addWidget(self.pushButton)
        spacerItem1 = QtWidgets.QSpacerItem(308, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem1)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget)
        self.pushButton_2.setStyleSheet("background-color: rgb(114, 159, 207);")
        self.pushButton_2.setObjectName("pushButton_2")
        self.horizontalLayout.addWidget(self.pushButton_2)
        spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        self.horizontalLayout.addItem(spacerItem2)
        self.gridLayout_2.addLayout(self.horizontalLayout, 2, 0, 1, 1)
        # Stretch factors 7:1:1 give the image grid most of the vertical space.
        self.gridLayout_2.setRowStretch(0, 7)
        self.gridLayout_2.setRowStretch(1, 1)
        self.gridLayout_2.setRowStretch(2, 1)

        self.retranslateUi(Form)
        # Qt auto-connection of any on_<objectName>_<signal> methods that Form defines.
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        """Set every user-visible text through Qt's translation hook."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label_6.setText(_translate("Form", "LaneNet车道线检测界面"))
        # Placeholder captions for the image labels until a picture is shown.
        self.label_3.setText(_translate("Form", "IMAGES2"))
        self.label_2.setText(_translate("Form", "IMAGES1"))
        self.label_4.setText(_translate("Form", "IMAGES3"))
        self.label.setText(_translate("Form", "IMAGES"))
        self.label_5.setText(_translate("Form", "IMAGES4"))
        self.label_7.setText(_translate("Form", "图片路径"))
        # Default image path pre-filled in the path box.
        self.lineEdit.setText(_translate("Form", "data/tusimple_test_image/0.jpg"))
        self.pushButton.setText(_translate("Form", "加载图片"))
        self.pushButton_2.setText(_translate("Form", "处理结果"))
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐