Preface

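  • I entered this Datawhale speech recognition competition mainly to experience the workflow and get familiar with the Tianchi competition environment.
  • Today I spent a lot of time setting up an MXNet environment on Tianchi to use its free GPU, but got an AttributeError. The code was identical to my local version; the only difference is that Tianchi runs Python 3.6. I'll try again when I have time.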

Baseline source code

Download and extract the training and test sets

!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/train_sample.zip

!unzip -qq train_sample.zip
!\rm train_sample.zip
!wget http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip
--2021-04-13 16:24:50--  http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip
Resolving tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)... 118.31.232.194
Connecting to tianchi-competition.oss-cn-hangzhou.aliyuncs.com (tianchi-competition.oss-cn-hangzhou.aliyuncs.com)|118.31.232.194|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1092637852 (1.0G) [application/zip]
Saving to: ‘test_a.zip’

100%[====================================>] 1,092,637,852 11.6MB/s   in 88s    

2021-04-13 16:26:19 (11.8 MB/s) - ‘test_a.zip’ saved [1092637852/1092637852]
!unzip -qq test_a.zip
!\rm test_a.zip
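Where shell escapes such as !wget and !unzip are not available, the same download-and-extract step can be done from Python. A minimal sketch using only the standard library (urllib.request and zipfile); download_file and extract_zip are hypothetical helper names, not part of the baseline:

```python
import os
import urllib.request
import zipfile


def download_file(url, dest_path):
    """Download url to dest_path (the equivalent of wget)."""
    urllib.request.urlretrieve(url, dest_path)
    return dest_path


def extract_zip(zip_path, dest_dir=".", remove=True):
    """Extract zip_path into dest_dir (like unzip -qq), then optionally delete it (like rm).

    Returns the list of member names stored in the archive.
    """
    with zipfile.ZipFile(zip_path) as zf:
        names = zf.namelist()
        zf.extractall(dest_dir)
    if remove:
        os.remove(zip_path)
    return names


# Usage (needs network access; the URL is the one used in the commands above):
# zip_path = download_file(
#     "http://tianchi-competition.oss-cn-hangzhou.aliyuncs.com/531887/test_a.zip", "test_a.zip")
# extract_zip(zip_path, ".")
```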

Requirements

  • TensorFlow version: 2.0+
  • keras
  • sklearn
  • librosa
# Basic libraries
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler

Load the deep learning framework

# Libraries needed to build the classification model

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPool2D, Dropout
from tensorflow.keras.utils import to_categorical 

from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
/opt/conda/lib/python3.6/site-packages/sklearn/ensemble/weight_boosting.py:29: DeprecationWarning: numpy.core.umath_tests is an internal NumPy module and should not be imported. It will be removed in a future NumPy release.
  from numpy.core.umath_tests import inner1d

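Load the audio processing library

  • Oddly, this step raised an error when I first ran it; after switching the environment (top-right corner), the error went away.
  • conda list still did not show librosa, which puzzled me.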
!pip install librosa --user
!conda list
Looking in indexes: https://mirrors.aliyun.com/pypi/simple
Requirement already satisfied: librosa in /data/nas/workspace/envs/python3.6/site-packages (0.8.0)
Requirement already satisfied: numba>=0.43.0 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (0.53.1)
Requirement already satisfied: decorator>=3.0.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (4.4.2)
Requirement already satisfied: joblib>=0.14 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.0.0)
Requirement already satisfied: soundfile>=0.9.0 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (0.10.3.post1)
Requirement already satisfied: audioread>=2.0.0 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (2.1.9)
Requirement already satisfied: numpy>=1.15.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.19.4)
Requirement already satisfied: resampy>=0.2.2 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (0.2.2)
Requirement already satisfied: pooch>=1.0 in /data/nas/workspace/envs/python3.6/site-packages (from librosa) (1.3.0)
Requirement already satisfied: scipy>=1.0.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (1.5.4)
Requirement already satisfied: scikit-learn!=0.19.0,>=0.14.0 in /opt/conda/lib/python3.6/site-packages (from librosa) (0.24.0)
Requirement already satisfied: llvmlite<0.37,>=0.36.0rc1 in /data/nas/workspace/envs/python3.6/site-packages (from numba>=0.43.0->librosa) (0.36.0)
Requirement already satisfied: setuptools in /opt/conda/lib/python3.6/site-packages (from numba>=0.43.0->librosa) (51.1.1)
Requirement already satisfied: appdirs in /data/nas/workspace/envs/python3.6/site-packages (from pooch>=1.0->librosa) (1.4.4)
Requirement already satisfied: requests in /opt/conda/lib/python3.6/site-packages (from pooch>=1.0->librosa) (2.25.1)
Requirement already satisfied: packaging in /opt/conda/lib/python3.6/site-packages (from pooch>=1.0->librosa) (20.8)
Requirement already satisfied: six>=1.3 in /opt/conda/lib/python3.6/site-packages (from resampy>=0.2.2->librosa) (1.15.0)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.6/site-packages (from scikit-learn!=0.19.0,>=0.14.0->librosa) (2.1.0)
Requirement already satisfied: cffi>=1.0 in /opt/conda/lib/python3.6/site-packages (from soundfile>=0.9.0->librosa) (1.14.4)
Requirement already satisfied: pycparser in /opt/conda/lib/python3.6/site-packages (from cffi>=1.0->soundfile>=0.9.0->librosa) (2.20)
Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.6/site-packages (from packaging->pooch>=1.0->librosa) (2.4.7)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (1.26.2)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (2020.12.5)
Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (2.10)
Requirement already satisfied: chardet<5,>=3.0.2 in /opt/conda/lib/python3.6/site-packages (from requests->pooch>=1.0->librosa) (4.0.0)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/pillow-5.2.0-py36heded4f4_0.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/six-1.11.0-py36h372c433_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/idna-2.6-py36h82fb2a8_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/requests-2.18.4-py36he2e5f8d_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/pycparser-2.18-py36hf9f622e_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/chardet-3.0.4-py36h0f667ec_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/python-graphviz-0.15-pyhd3eb1b0_0.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/wheel-0.30.0-py36hfd4bba0_1.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/urllib3-1.22-py36hbe7ace6_0.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/pytorch-cpu-1.1.0-py3.6_cpu_0.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/setuptools-36.5.0-py36he42e2e1_0.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/pip-9.0.1-py36h6c6f9ce_4.json.  Please remove this file manually (you may need to reboot to free file handles)
WARNING conda.gateways.disk.delete:unlink_or_rename_to_trash(140): Could not remove or rename /opt/conda/conda-meta/numpy-base-1.15.4-py36h81de0dd_0.json.  Please remove this file manually (you may need to reboot to free file handles)
# packages in environment at /opt/conda:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
_py-xgboost-mutex         2.0                       cpu_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
absl-py                   0.11.0                   pypi_0    pypi
aliyun-python-sdk-core    2.13.5                   pypi_0    pypi
aliyun-python-sdk-core-v3 2.13.3                   pypi_0    pypi
aliyun-python-sdk-kms     2.7.1                    pypi_0    pypi
argon2-cffi               20.1.0                   pypi_0    pypi
asn1crypto                0.23.0           py36h4639342_0    defaults
astor                     0.8.0                    pypi_0    pypi
astunparse                1.6.3                    pypi_0    pypi
async-generator           1.10                     pypi_0    pypi
attrs                     20.3.0                   pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
blas                      1.0                         mkl    defaults
bleach                    3.2.1                    pypi_0    pypi
bzip2                     1.0.8                h7b6447c_0    defaults
ca-certificates           2020.12.8            h06a4308_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
cachetools                4.2.1                    pypi_0    pypi
cairo                     1.14.12              h8948797_3    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
certifi                   2020.12.5                pypi_0    pypi
cffi                      1.14.4                   pypi_0    pypi
chardet                   4.0.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
conda                     4.9.2            py36h06a4308_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
conda-env                 2.6.0                         1    defaults
conda-package-handling    1.3.11                   py36_0    defaults
crcmod                    1.7                      pypi_0    pypi
cryptography              2.3.1            py36hc365091_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
cycler                    0.10.0                   pypi_0    pypi
cython                    0.29.21                  pypi_0    pypi
dataclasses               0.8                      pypi_0    pypi
decorator                 4.4.2                    pypi_0    pypi
defusedxml                0.6.0                    pypi_0    pypi
dlib                      19.21.1                  pypi_0    pypi
dsw-demos-extension       0.1.0                    pypi_0    pypi
dsw-ipykernel             0.2.0                    pypi_0    pypi
dsw-magic                 0.0.1                    pypi_0    pypi
dsw-sql-extension         0.1.0                    pypi_0    pypi
dswdlv                    0.0.1                    pypi_0    pypi
dswmagic                  0.0.1                    pypi_0    pypi
entrypoints               0.3                      pypi_0    pypi
expat                     2.2.5                he0dffb1_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
faiss-cpu                 1.4.0            py36_cuda0.0_1    pytorch
fasttext                  0.9.2                    pypi_0    pypi
ffmpeg-python             0.2.0                    pypi_0    pypi
flatbuffers               1.12                     pypi_0    pypi
fontconfig                2.13.0               h9420a91_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
freetype                  2.9.1                h8a8886c_1    defaults
fribidi                   1.0.10               h7b6447c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
future                    0.17.1                   pypi_0    pypi
gast                      0.3.3                    pypi_0    pypi
glib                      2.56.1               h000015b_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
google-auth               1.27.1                   pypi_0    pypi
google-auth-oauthlib      0.4.3                    pypi_0    pypi
google-pasta              0.2.0                    pypi_0    pypi
graphite2                 1.3.11               h16798f4_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
graphviz                  2.40.1               h21bd128_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
grpcio                    1.32.0                   pypi_0    pypi
h5py                      2.10.0                   pypi_0    pypi
harfbuzz                  1.8.4                hec2c2bc_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
icu                       58.2                 h9c2bf20_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
idna                      2.10                     pypi_0    pypi
imbalanced-learn          0.3.1              pyh2cb239c_0    glemaitre
importlib-metadata        3.3.0                    pypi_0    pypi
intel-openmp              2019.4                      243    defaults
ipykernel                 5.4.2                    pypi_0    pypi
ipython                   7.9.0                    pypi_0    pypi
ipython-genutils          0.2.0                    pypi_0    pypi
jedi                      0.18.0                   pypi_0    pypi
jinja2                    2.11.2                   pypi_0    pypi
jmespath                  0.9.4                    pypi_0    pypi
joblib                    1.0.0                    pypi_0    pypi
jpeg                      9b                   h024ee3a_2    defaults
json5                     0.9.5                    pypi_0    pypi
jsonschema                3.2.0                    pypi_0    pypi
jupyter-client            6.1.7                    pypi_0    pypi
jupyter-core              4.7.0                    pypi_0    pypi
jupyterlab                2.2.8                    pypi_0    pypi
jupyterlab-launcher       0.13.1                   pypi_0    pypi
jupyterlab-prometheus     0.1                      pypi_0    pypi
jupyterlab-pygments       0.1.2                    pypi_0    pypi
jupyterlab-server         1.2.0                    pypi_0    pypi
keras                     2.2.4                    pypi_0    pypi
keras-applications        1.0.8                    pypi_0    pypi
keras-preprocessing       1.1.2                    pypi_0    pypi
kiwisolver                1.2.0                    pypi_0    pypi
libarchive                3.3.3                h7d0bbab_1    defaults
libedit                   3.1                  heed3624_0    defaults
libffi                    3.2.1                hd88cf55_4    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libgcc                    7.2.0                h69d50b8_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libgcc-ng                 9.1.0                hdf63c60_0    defaults
libgfortran-ng            7.3.0                hdf63c60_0    defaults
libpng                    1.6.37               hbc83047_0    defaults
libstdcxx-ng              7.2.0                h7a57d05_2    defaults
libtiff                   4.0.9                he85c1e1_1    defaults
libuuid                   1.0.3                h1bed415_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libxcb                    1.14                 h7b6447c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libxgboost                0.90                 hf484d3e_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
libxml2                   2.9.9                hea5a465_1    defaults
lightgbm                  2.3.1                    pypi_0    pypi
lz4-c                     1.8.1.2              h14c3975_0    defaults
lzo                       2.10                 h49e0be7_2    defaults
markdown                  3.1.1                    pypi_0    pypi
markupsafe                1.1.1                    pypi_0    pypi
matplotlib                3.3.3                    pypi_0    pypi
mistune                   0.8.4                    pypi_0    pypi
mkl                       2018.0.3                      1    defaults
mkl_fft                   1.0.4            py36h4414c95_1    defaults
mkl_random                1.0.1            py36h4414c95_1    defaults
nbclient                  0.5.1                    pypi_0    pypi
nbconvert                 6.0.7                    pypi_0    pypi
nbformat                  5.0.8                    pypi_0    pypi
ncurses                   6.0                  h9df7e31_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
nest-asyncio              1.4.3                    pypi_0    pypi
ninja                     1.8.2            py36h6bb024c_1    defaults
notebook                  6.1.6                    pypi_0    pypi
np-utils                  0.5.10.0                 pypi_0    pypi
numpy                     1.19.4                   pypi_0    pypi
oauthlib                  3.1.0                    pypi_0    pypi
odps                      3.5.1                    pypi_0    pypi
olefile                   0.46                     py36_0    defaults
open-from-url             0.1.0                    pypi_0    pypi
openssl                   1.0.2u               h7b6447c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
opt-einsum                3.3.0                    pypi_0    pypi
oss2                      2.8.0                    pypi_0    pypi
packaging                 20.8                     pypi_0    pypi
palettable                3.3.0                    pypi_0    pypi
pandas                    1.1.5                    pypi_0    pypi
pandocfilters             1.4.3                    pypi_0    pypi
pango                     1.42.3               h8589676_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
parso                     0.8.1                    pypi_0    pypi
pcre                      8.42                 h439df22_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pillow                    8.0.1                    pypi_0    pypi
pip                       21.0.1                   pypi_0    pypi
pixman                    0.40.0               h7b6447c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
prometheus-client         0.9.0                    pypi_0    pypi
prompt-toolkit            2.0.10                   pypi_0    pypi
protobuf                  3.15.5                   pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
py-xgboost                0.90             py36hf484d3e_1    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pybind11                  2.6.1                    pypi_0    pypi
pycosat                   0.6.3            py36h27cfd23_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
pycparser                 2.20                     pypi_0    pypi
pycryptodome              3.8.2                    pypi_0    pypi
pygments                  2.7.3                    pypi_0    pypi
pymars                    0.6.1                    pypi_0    pypi
pyodps                    0.10.3                   pypi_0    pypi
pyopenssl                 17.5.0           py36h20ba746_0    defaults
pyparsing                 2.4.7                    pypi_0    pypi
pyrsistent                0.17.3                   pypi_0    pypi
pysocks                   1.6.7            py36hd97a5b1_1    defaults
python                    3.6.5                hc3d631a_2    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
python-dateutil           2.8.1                    pypi_0    pypi
python-dotenv             0.15.0                   pypi_0    pypi
python-graphviz           0.16                     pypi_0    pypi
python-libarchive-c       2.8                     py36_13    defaults
pytz                      2020.5                   pypi_0    pypi
pyyaml                    5.1.2                    pypi_0    pypi
pyzmq                     20.0.0                   pypi_0    pypi
readline                  7.0                  ha6073c6_4    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
requests                  2.25.1                   pypi_0    pypi
requests-oauthlib         1.3.0                    pypi_0    pypi
rsa                       4.7.2                    pypi_0    pypi
ruamel_yaml               0.11.14          py36ha2fb22d_2    defaults
scikit-learn              0.24.0                   pypi_0    pypi
scipy                     1.5.4                    pypi_0    pypi
seaborn                   0.10.1                   pypi_0    pypi
send2trash                1.5.0                    pypi_0    pypi
setuptools                51.1.1                   pypi_0    pypi
six                       1.15.0                   pypi_0    pypi
sklearn                   0.0                      pypi_0    pypi
sqlflow                   0.15.0.dev0              pypi_0    pypi
sqlite                    3.23.1               he433501_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
tensorboard               2.4.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.0                    pypi_0    pypi
tensorflow                1.14.0                   pypi_0    pypi
tensorflow-cpu            2.4.0                    pypi_0    pypi
tensorflow-estimator      2.4.0                    pypi_0    pypi
tensorflow-io             0.7.0                    pypi_0    pypi
termcolor                 1.1.0                    pypi_0    pypi
terminado                 0.9.1                    pypi_0    pypi
testpath                  0.4.4                    pypi_0    pypi
threadpoolctl             2.1.0                    pypi_0    pypi
tianchi-extension         0.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
torch                     1.7.1                    pypi_0    pypi
torchvision-cpu           0.3.0             py36_cuNone_1    pytorch
tornado                   6.1                      pypi_0    pypi
tqdm                      4.32.1                     py_0    defaults
traitlets                 4.3.3                    pypi_0    pypi
typing-extensions         3.7.4.3                  pypi_0    pypi
urllib3                   1.26.2                   pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
webencodings              0.5.1                    pypi_0    pypi
werkzeug                  0.15.5                   pypi_0    pypi
wheel                     0.36.2                   pypi_0    pypi
wrapt                     1.12.1                   pypi_0    pypi
xz                        5.2.4                h14c3975_4    defaults
yaml                      0.2.5                h7b6447c_0    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zipp                      3.4.0                    pypi_0    pypi
zlib                      1.2.11               h7b6447c_3    https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
zstd                      1.3.3                h84994c4_0    defaults
# Other libraries

import os
import librosa
import librosa.display
import glob 
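On the puzzle of conda list not showing librosa: the pip output above reports librosa under /data/nas/workspace/envs/python3.6/site-packages, which is outside /opt/conda, the environment that conda list describes, so that is a likely reason it is missing from the list. One way to confirm where a package is really imported from is to ask Python itself (after the import cell above, librosa.__file__ gives the same answer). A small sketch with the standard library's importlib.util.find_spec; module_origin is a hypothetical helper name:

```python
import importlib.util
import sys


def module_origin(name):
    """Return the file a top-level module would be imported from, or None if not importable."""
    spec = importlib.util.find_spec(name)
    return spec.origin if spec is not None else None


# Directories Python searches, in order. A user-level site-packages
# (e.g. one used by pip install --user) shows up here even if conda does not list it.
for path in sys.path:
    print(path)

print(module_origin("librosa"))  # None if librosa is not importable in this kernel
```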

Feature extraction and dataset construction

# Build the class labels: each class maps to a distinct integer.
label_dict = {'aloe': 0, 'burger': 1, 'cabbage': 2, 'candied_fruits': 3, 'carrots': 4, 'chips': 5,
              'chocolate': 6, 'drinks': 7, 'fries': 8, 'grapes': 9, 'gummies': 10, 'ice-cream': 11,
              'jelly': 12, 'noodles': 13, 'pickles': 14, 'pizza': 15, 'ribs': 16, 'salmon': 17,
              'soup': 18, 'wings': 19}
# Inverse mapping: integer id -> class name (used when writing predictions)
label_dict_inv = {v: k for k, v in label_dict.items()}
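The hand-written label_dict is just the 20 class folder names in alphabetical order, numbered from 0. A sketch that builds the same mapping and its inverse from a list of names, which avoids typos if the class list changes; build_label_maps is a hypothetical helper, and in practice the names could be read from the class sub-folders of ./train_sample/:

```python
def build_label_maps(class_names):
    """Map each class name to an integer id (alphabetical order) and back."""
    mapping = {name: idx for idx, name in enumerate(sorted(class_names))}
    inverse = {idx: name for name, idx in mapping.items()}
    return mapping, inverse


classes = ['aloe', 'burger', 'cabbage', 'candied_fruits', 'carrots', 'chips',
           'chocolate', 'drinks', 'fries', 'grapes', 'gummies', 'ice-cream',
           'jelly', 'noodles', 'pickles', 'pizza', 'ribs', 'salmon', 'soup', 'wings']
mapping, inverse = build_label_maps(classes)
print(mapping['carrots'], inverse[19])  # 4 wings
```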
from tqdm import tqdm

def extract_features(parent_dir, sub_dirs, max_file=10, file_ext="*.wav"):
    """Return [features, labels] for at most max_file audio files per class folder."""
    label, feature = [], []
    for sub_dir in sub_dirs:
        # iterate over the audio files of one class (at most max_file of them)
        for fn in tqdm(glob.glob(os.path.join(parent_dir, sub_dir, file_ext))[:max_file]):
            label_name = os.path.basename(os.path.dirname(fn))  # the parent folder name is the class
            label.append(label_dict[label_name])
            X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
            # compute the mel spectrogram, average it over time, and use the result as the feature
            mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
            feature.append(mels)

    return [feature, label]
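librosa.feature.melspectrogram returns an array of shape (n_mels, t), with n_mels=128 by default and t growing with the clip length. Transposing and averaging over axis 0 collapses the time axis, so every clip, whatever its duration, becomes one fixed 128-dimensional vector. The NumPy sketch below uses random arrays as stand-ins for real spectrograms just to show the shapes; clip_feature is a hypothetical helper name:

```python
import numpy as np


def clip_feature(mel_spec):
    """Average a (n_mels, n_frames) spectrogram over time, giving an (n_mels,) vector.

    Same computation as np.mean(mel_spec.T, axis=0) in the baseline.
    """
    return np.mean(mel_spec.T, axis=0)


rng = np.random.default_rng(0)
short_clip = rng.random((128, 40))   # stand-in for a short clip: 40 frames
long_clip = rng.random((128, 300))   # stand-in for a longer clip: 300 frames

print(clip_feature(short_clip).shape, clip_feature(long_clip).shape)  # (128,) (128,)
```

The trade-off of this design is that averaging throws away how the spectrum evolves over time; only the average energy per mel band survives.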
# Change these directories to match your own setup
parent_dir = './train_sample/'
save_dir = "./"
folds = sub_dirs = np.array(['aloe','burger','cabbage','candied_fruits',
                             'carrots','chips','chocolate','drinks','fries',
                            'grapes','gummies','ice-cream','jelly','noodles','pickles',
                            'pizza','ribs','salmon','soup','wings'])

# Extract the features and the class labels
temp = extract_features(parent_dir,sub_dirs,max_file=100)
100%|██████████| 45/45 [00:12<00:00,  5.03it/s]
100%|██████████| 64/64 [00:14<00:00,  5.09it/s]
100%|██████████| 48/48 [00:17<00:00,  2.88it/s]
100%|██████████| 74/74 [00:26<00:00,  1.31it/s]
100%|██████████| 49/49 [00:14<00:00,  3.50it/s]
100%|██████████| 57/57 [00:17<00:00,  3.65it/s]
100%|██████████| 27/27 [00:07<00:00,  3.48it/s]
100%|██████████| 27/27 [00:07<00:00,  3.54it/s]
100%|██████████| 57/57 [00:15<00:00,  3.67it/s]
100%|██████████| 61/61 [00:17<00:00,  4.01it/s]
100%|██████████| 65/65 [00:19<00:00,  3.11it/s]
100%|██████████| 69/69 [00:22<00:00,  3.08it/s]
100%|██████████| 43/43 [00:12<00:00,  3.41it/s]
100%|██████████| 33/33 [00:09<00:00,  3.37it/s]
100%|██████████| 75/75 [00:23<00:00,  3.15it/s]
100%|██████████| 55/55 [00:18<00:00,  2.96it/s]
100%|██████████| 47/47 [00:14<00:00,  3.50it/s]
100%|██████████| 37/37 [00:13<00:00,  2.04it/s]
100%|██████████| 32/32 [00:07<00:00,  3.87it/s]
100%|██████████| 35/35 [00:11<00:00,  2.76it/s]
features, labels = temp  # temp is [list of 128-d feature vectors, list of integer labels]

# Features: stack into an (n_samples, 128) matrix
X = np.vstack(features)

# Labels: integer class ids
Y = np.array(labels)
print('Shape of X:', X.shape)
print('Shape of Y:', Y.shape)
Shape of X: (1000, 128)
Shape of Y: (1000,)
# In Keras, to_categorical converts a vector of integer class labels into a one-hot (0/1) matrix
Y = to_categorical(Y)
# Final data
print(X.shape)
print(Y.shape)
(1000, 128)
(1000, 20)
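to_categorical turns integer labels into one-hot rows: label k becomes a row with 1 in column k and 0 elsewhere, and by default the number of columns is max(label) + 1, which is 20 here. An equivalent NumPy sketch, shown only to illustrate what the Keras call produces; one_hot is a hypothetical name:

```python
import numpy as np


def one_hot(labels, num_classes=None):
    """One-hot encode integer labels, mirroring keras.utils.to_categorical."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        num_classes = labels.max() + 1
    # row i of the identity matrix is exactly the one-hot vector for class i
    return np.eye(num_classes, dtype=np.float32)[labels]


Y_small = one_hot([0, 2, 1])
print(Y_small)
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
print(Y_small.argmax(axis=1))  # [0 2 1]: argmax recovers the original labels
```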
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state = 1, stratify=Y)
print('Training set size', len(X_train))
print('Test set size', len(X_test))
Training set size 750
Test set size 250
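train_test_split holds out 25% of the samples by default (test_size defaults to 0.25 when train_size is not given), which gives the 750/250 split above, and stratify=Y keeps each class's share the same in both parts. With one-hot Y, scikit-learn treats each distinct row as one class, which is equivalent to stratifying on Y.argmax(axis=1). A simplified per-class NumPy sketch of the idea, not scikit-learn's actual implementation (its rounding differs); stratified_split_indices is a hypothetical name:

```python
import numpy as np


def stratified_split_indices(labels, test_size=0.25, seed=1):
    """Split sample indices so every class keeps roughly the same share in train and test.

    Simplified sketch: shuffle each class separately and send
    round(test_size * n_class) of its samples to the test part.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # positions of this class
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(train_idx), np.array(test_idx)


# Toy labels: class 0 has 40 samples, class 1 has 20.
toy = np.array([0] * 40 + [1] * 20)
tr, te = stratified_split_indices(toy)
print(len(tr), len(te))       # 45 15
print(np.bincount(toy[te]))   # 10 from class 0, 5 from class 1
```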
X_train = X_train.reshape(-1, 16, 8, 1)
X_test = X_test.reshape(-1, 16, 8, 1)
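The reshape treats each 128-dimensional mel vector as a 16×8 single-channel "image" so that Conv2D can be applied. NumPy fills it row by row (C order), so element k lands at row k // 8, column k % 8; horizontal neighbours are adjacent mel bands, while vertical neighbours are 8 bands apart. A quick check on made-up data:

```python
import numpy as np

features = np.arange(2 * 128).reshape(2, 128)  # two fake 128-d feature vectors
images = features.reshape(-1, 16, 8, 1)        # (samples, height, width, channels)

print(images.shape)                                    # (2, 16, 8, 1)
k = 37
print(images[1, k // 8, k % 8, 0] == features[1, k])   # True
```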

Building the model

Building the CNN

model = Sequential()

# Input shape
input_dim = (16, 8, 1)

model.add(Conv2D(64, (3, 3), padding="same", activation="tanh", input_shape=input_dim))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling layer
model.add(Conv2D(128, (3, 3), padding="same", activation="tanh"))  # convolutional layer
model.add(MaxPool2D(pool_size=(2, 2)))  # max pooling layer
model.add(Dropout(0.1))
model.add(Flatten())  # flatten to a vector
model.add(Dense(1024, activation="tanh"))
model.add(Dense(20, activation="softmax"))  # output layer: 20 units give the probabilities of the 20 classes

# Compile the model: set the loss function, optimizer, and evaluation metric
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 16, 8, 64)         640       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 8, 4, 64)          0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 8, 4, 128)         73856     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 4, 2, 128)         0         
_________________________________________________________________
dropout (Dropout)            (None, 4, 2, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 1024)              1049600   
_________________________________________________________________
dense_1 (Dense)              (None, 20)                20500     
=================================================================
Total params: 1,144,596
Trainable params: 1,144,596
Non-trainable params: 0
_________________________________________________________________
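The Param # column can be checked by hand: a Conv2D layer has (kh * kw * in_channels + 1) * filters weights (the +1 is the bias), a Dense layer has (inputs + 1) * units, and the pooling, dropout, and flatten layers have none. With padding="same" the convolutions keep the 16×8 size and each 2×2 max-pool halves it. A short sketch reproducing the numbers in the summary:

```python
def conv2d_params(kh, kw, in_ch, filters):
    # one kh x kw x in_ch kernel plus one bias per filter
    return (kh * kw * in_ch + 1) * filters


def dense_params(inputs, units):
    # full weight matrix plus one bias per unit
    return (inputs + 1) * units


h, w = 16, 8                        # input is 16 x 8 x 1
c1 = conv2d_params(3, 3, 1, 64)     # same padding: output 16 x 8 x 64
h, w = h // 2, w // 2               # max-pool: 8 x 4
c2 = conv2d_params(3, 3, 64, 128)   # output 8 x 4 x 128
h, w = h // 2, w // 2               # max-pool: 4 x 2
flat = h * w * 128                  # flatten: 4 * 2 * 128 = 1024
d1 = dense_params(flat, 1024)
d2 = dense_params(1024, 20)

print(c1, c2, flat, d1, d2)         # 640 73856 1024 1049600 20500
print(c1 + c2 + d1 + d2)            # 1144596
```

The count also shows where the capacity sits: the first Dense layer alone holds 1,049,600 of the 1,144,596 parameters (about 92%).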
# Train the model
model.fit(X_train, Y_train, epochs = 20, batch_size = 15, validation_data = (X_test, Y_test))
Epoch 1/20
50/50 [==============================] - 4s 56ms/step - loss: 2.9535 - accuracy: 0.1052 - val_loss: 2.6772 - val_accuracy: 0.1960
Epoch 2/20
50/50 [==============================] - 2s 37ms/step - loss: 2.4855 - accuracy: 0.2418 - val_loss: 2.5755 - val_accuracy: 0.2080
Epoch 3/20
50/50 [==============================] - 2s 37ms/step - loss: 2.2325 - accuracy: 0.3134 - val_loss: 2.4603 - val_accuracy: 0.2520
Epoch 4/20
50/50 [==============================] - 2s 39ms/step - loss: 2.0355 - accuracy: 0.3996 - val_loss: 2.4024 - val_accuracy: 0.2760
Epoch 5/20
50/50 [==============================] - 2s 38ms/step - loss: 1.8670 - accuracy: 0.4200 - val_loss: 2.4080 - val_accuracy: 0.3120
Epoch 6/20
50/50 [==============================] - 2s 37ms/step - loss: 1.6604 - accuracy: 0.4909 - val_loss: 2.4047 - val_accuracy: 0.3280
Epoch 7/20
50/50 [==============================] - 2s 37ms/step - loss: 1.5919 - accuracy: 0.5237 - val_loss: 2.5766 - val_accuracy: 0.3120
Epoch 8/20
50/50 [==============================] - 2s 38ms/step - loss: 1.3910 - accuracy: 0.5578 - val_loss: 2.6057 - val_accuracy: 0.3200
Epoch 9/20
50/50 [==============================] - 2s 37ms/step - loss: 1.2842 - accuracy: 0.6188 - val_loss: 2.6491 - val_accuracy: 0.3160
Epoch 10/20
50/50 [==============================] - 2s 37ms/step - loss: 1.0891 - accuracy: 0.6734 - val_loss: 2.9650 - val_accuracy: 0.3000
Epoch 11/20
50/50 [==============================] - 2s 38ms/step - loss: 1.0029 - accuracy: 0.6969 - val_loss: 2.9276 - val_accuracy: 0.3400
Epoch 12/20
50/50 [==============================] - 2s 37ms/step - loss: 0.8177 - accuracy: 0.7670 - val_loss: 3.0201 - val_accuracy: 0.3680
Epoch 13/20
50/50 [==============================] - 2s 38ms/step - loss: 0.7925 - accuracy: 0.7684 - val_loss: 3.2365 - val_accuracy: 0.3640
Epoch 14/20
50/50 [==============================] - 2s 39ms/step - loss: 0.7578 - accuracy: 0.7711 - val_loss: 3.6040 - val_accuracy: 0.3520
Epoch 15/20
50/50 [==============================] - 2s 38ms/step - loss: 0.6582 - accuracy: 0.8034 - val_loss: 3.4311 - val_accuracy: 0.3800
Epoch 16/20
50/50 [==============================] - 2s 45ms/step - loss: 0.6125 - accuracy: 0.8210 - val_loss: 3.4721 - val_accuracy: 0.3520
Epoch 17/20
50/50 [==============================] - 2s 38ms/step - loss: 0.5335 - accuracy: 0.8556 - val_loss: 3.8178 - val_accuracy: 0.3760
Epoch 18/20
50/50 [==============================] - 2s 37ms/step - loss: 0.4607 - accuracy: 0.8764 - val_loss: 3.7193 - val_accuracy: 0.3480
Epoch 19/20
50/50 [==============================] - 2s 39ms/step - loss: 0.4444 - accuracy: 0.8820 - val_loss: 3.8073 - val_accuracy: 0.3800
Epoch 20/20
50/50 [==============================] - 2s 37ms/step - loss: 0.3612 - accuracy: 0.9125 - val_loss: 3.8732 - val_accuracy: 0.3720





<tensorflow.python.keras.callbacks.History at 0x7ff66c0f7320>
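The log shows clear overfitting: training accuracy climbs to 0.9125 while validation accuracy never exceeds 0.38, and val_loss reaches its minimum of 2.4024 at epoch 4, then rises to 3.8732 by epoch 20. Patience-based early stopping halts training once val_loss has not improved for a given number of epochs. The pure-Python sketch below replays that rule on the logged val_loss values, roughly mirroring what Keras' EarlyStopping callback does with its default min_delta=0; in the notebook itself one would instead pass something like callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)] to model.fit, where patience=3 is an arbitrary choice:

```python
# val_loss per epoch, copied from the training log above
val_loss = [2.6772, 2.5755, 2.4603, 2.4024, 2.4080, 2.4047, 2.5766, 2.6057, 2.6491, 2.9650,
            2.9276, 3.0201, 3.2365, 3.6040, 3.4311, 3.4721, 3.8178, 3.7193, 3.8073, 3.8732]


def early_stopping_epoch(losses, patience=3):
    """Return (best_epoch, stop_epoch), both 1-based, for patience-based early stopping."""
    best, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:                # improvement: remember it and reset the counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:       # no improvement for `patience` epochs in a row
                return best_epoch, epoch
    return best_epoch, len(losses)     # never triggered: trained to the end


print(early_stopping_epoch(val_loss))  # (4, 7)
```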

Predicting on the test set

def extract_test_features(test_dir, file_ext="*.wav"):
    """Return the sorted test file paths and one mean mel-spectrogram feature per file."""
    # Glob once and sort, so features and file names are guaranteed to stay in the same order
    paths = sorted(glob.glob(os.path.join(test_dir, file_ext)))
    feature = []
    for fn in tqdm(paths):  # iterate over every test file
        X, sample_rate = librosa.load(fn, res_type='kaiser_fast')
        # same feature as in training: mel spectrogram averaged over time
        mels = np.mean(librosa.feature.melspectrogram(y=X, sr=sample_rate).T, axis=0)
        feature.append(mels)
    return paths, feature
paths, X_test = extract_test_features('./test_a/')
100%|██████████| 2000/2000 [10:28<00:00,  3.34it/s]
X_test = np.vstack(X_test)
predictions = model.predict(X_test.reshape(-1, 16, 8, 1))
preds = np.argmax(predictions, axis=1)  # index of the most probable class
preds = [label_dict_inv[x] for x in preds]  # class id -> class name

result = pd.DataFrame({'name': paths, 'label': preds})

result['name'] = result['name'].apply(os.path.basename)  # keep only the file name
result.to_csv('submit.csv', index=None)
!ls ./test_a/*.wav | wc -l
2000

!wc -l submit.csv
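With 2000 test files, wc -l submit.csv should report 2001 lines: one header plus one row per file. A slightly stricter check with the standard csv module, which also verifies the header, rejects duplicate file names, and confirms every label is one of the known classes; check_submission is a hypothetical helper, and the expected columns name,label are taken from the DataFrame written above:

```python
import csv


def check_submission(path, expected_rows, valid_labels):
    """Validate a name,label submission CSV and return its number of data rows."""
    with open(path, newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = list(reader)
    if header != ["name", "label"]:
        raise ValueError(f"unexpected header: {header}")
    if len(rows) != expected_rows:
        raise ValueError(f"expected {expected_rows} rows, got {len(rows)}")
    names = [row[0] for row in rows]
    if len(set(names)) != len(names):
        raise ValueError("duplicate file names in submission")
    unknown = {row[1] for row in rows} - set(valid_labels)
    if unknown:
        raise ValueError(f"unknown labels: {sorted(unknown)}")
    return len(rows)


# Usage in the notebook (label_dict is defined earlier):
# check_submission("submit.csv", expected_rows=2000, valid_labels=label_dict.keys())
```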
