动态时间规整matlab和python代码
几篇写得很好的文章:【重大修改】动态时间规整(Dynamic Time Warping)算法笔记-DTW动态时间规整动态时间规整算法(Dynamic Time Warping, DTW)之初探单词语音识别matlab代码:参考这里: 【重大修改】动态时间规整(Dynamic Time Warping)python代码:import numpy as npimport matpl...
几篇写得很好的文章
【重大修改】动态时间规整(Dynamic Time Warping)
算法笔记-DTW动态时间规整
动态时间规整算法(Dynamic Time Warping, DTW)之初探单词语音识别
(1)相似度计算
先计算欧式距离,然后计算累计损失距离,计算到最后一个值即为两个序列的最终距离。
这个最终距离表示两个序列的相似度,距离越小,相似度越大。
借用上面第二篇的例子:
X:3,5,6,7,7,1
Y:3,6,6,7,8,1,1
求出第一个欧式距离表后,通过下面公式计算
M
c
(
i
,
j
)
=
M
i
n
(
M
c
(
i
−
1
,
j
−
1
)
,
M
c
(
i
−
1
,
j
)
,
M
c
(
i
,
j
−
1
)
)
+
M
(
i
,
j
)
M_c(i,j)=Min(M_c(i-1,j-1),M_c(i-1,j),M_c(i,j-1))+M(i,j)
Mc(i,j)=Min(Mc(i−1,j−1),Mc(i−1,j),Mc(i,j−1))+M(i,j)
例如图中要计算 ? 的值,考虑其左上相邻3个值和其欧式距离表对应位置上的值。这里 min(0, 3, 6)+3 = 3.
(2)输出最佳路径path
这里参数包含了对应两个输入序列的规整序列,其实就是索引值。这个算法认为对应索引位置上的值相似度比较大。见代码如何得到规整后的序列。
(3)代码
matlab代码
参考这里: 【重大修改】动态时间规整(Dynamic Time Warping)
python代码
import numpy as np
import matplotlib.pyplot as plt
def dtw(x, y, dist):
"""
Computes Dynamic Time Warping (DTW) of two sequences.
:param array x: N1*M array
:param array y: N2*M array
:param func dist: distance used as cost measure
Returns the minimum distance, the cost matrix, the accumulated cost matrix, and the wrap path.
"""
assert len(x) # Report error while x is none
assert len(y)
r, c = len(x), len(y)
D0 = zeros((r + 1, c + 1))
D0[0, 1:] = inf
D0[1:, 0] = inf
D1 = D0[1:, 1:] # view
for i in range(r):
for j in range(c):
D1[i, j] = dist(x[i], y[j])
C = D1.copy()
for i in range(r):
for j in range(c):
D1[i, j] += min(D0[i, j], D0[i, j + 1], D0[i + 1, j])
if len(x) == 1:
path = zeros(len(y)), range(len(y))
elif len(y) == 1:
path = range(len(x)), zeros(len(x))
else:
path = _traceback(D0)
return D1[-1, -1] / sum(D1.shape), C, D1, path
x = np.array([0, 0, 1, 1, 2, 4, 2, 1, 2, 0]).reshape(-1, 1)
y = np.array([1, 1, 1, 2, 2, 2, 2, 3, 2, 0]).reshape(-1, 1)
dist, cost, acc, path = dtw(x, y, dist=lambda x, y: np.linalg.norm(x - y, ord=1))
print('Minimum distance found:', dist)
plt.figure()
plt.subplot(131)
plt.plot(x)
plt.plot(y)
plt.title('orignal data')
plt.subplot(132)
plt.imshow(cost.T, origin='lower', cmap=plt.cm.gray, interpolation='nearest')
plt.plot(path[0], path[1], 'w')
plt.xlim((-0.5, cost.shape[0]-0.5))
plt.ylim((-0.5, cost.shape[1]-0.5))
plt.title('the cost matrix and the wrap path')
plt.subplot(133)
plt.plot(x[path[0]])
plt.plot(y[path[1]])
plt.xlim((-0.5, cost.shape[0]-0.5))
plt.ylim((-0.5, cost.shape[1]-0.5))
plt.title('wrapped signal')
plt.show()
更多推荐





所有评论(0)