ML Collections的介绍(一)
ML Collections的介绍(一)在看一篇论文的源码时,看到import ml_collections这行代码,系统报错,经过寻找之后发现并不是源码中的一个文件,之后感觉应该是一个深度学习的安装包,然后就在网上搜索这个包的安装过程,却没有任何信息,于是自己直接尝试了pip install ml-collections,于是就安装成功了。今天在github上搜索了这个安装包,潜心学习了一下。M
ML Collections的介绍(一)
在看一篇论文的源码时,看到import ml_collections这行代码,系统报错,经过寻找之后发现并不是源码中的一个文件,之后感觉应该是一个深度学习的安装包,然后就在网上搜索这个包的安装过程,却没有任何信息,于是自己直接尝试了pip install ml-collections,于是就安装成功了。今天在github上搜索了这个安装包,潜心学习了一下。
ML Collections
ML Collections是为ML use cases而设计的一个Python Collections的一个库。它的两个类是ConfigDict
和FrozenConfigDict
,是"dict-like" 的数据结构,以下是ConfigDict
、FrozenConfigDict
和FieldReference
的示例用法,直接上代码吧。
Basic Usage
import ml_collections
cfg = ml_collections.ConfigDict()
cfg.float_field = 12.6 #float类型
cfg.integer_field = 123 #int类型
cfg.another_integer_field = 234 #int类型
cfg.nested = ml_collections.ConfigDict() #嵌套了ml_collections.ConfigDict()
cfg.nested.string_field = 'tom' #str类型
print(cfg.integer_field) # 输出结果 123.
print(cfg['integer_field']) # 也输出123.
try:
cfg.integer_field = 'tom' # 输出会报错错误类型是TypeError,因为这你field是整数类型
except TypeError as e:
print(e)
cfg.float_field = 12 # int类型也可以指定给float类型.
cfg.nested.string_field = u'bob' # string可以储存Unicode字符串
print(cfg)
下面是这段代码的输出结果
123
123
Could not override field 'integer_field' (reference). tom is of type <class 'str'> but should be of type <class 'int'> #报错的地方
#以下输出是所有的cfg的结果,以及嵌套的nested
another_integer_field: 234
float_field: 12.0
integer_field: 123
nested:
string_field: bob
FrozenConfigDict
不可以改变的ConfigDict
import ml_collections
#初始化一个字典
initial_dictionary = {
'int': 1,
'list': [1, 2],
'tuple': (1, 2, 3),
'set': {1, 2, 3, 4},
'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}
cfg = ml_collections.ConfigDict(initial_dictionary)#把这个字典通过ConfigDict赋值给cfg
frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary)#把这个字典通过FrozenConfigDict赋值给frozen_dict
print(frozen_dict.tuple) # 输出(1, 2, 3)
print(frozen_dict.list) # 输出 (1, 2)
print(frozen_dict.set) # 输出 {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0]) # 输出 (1, 2)
frozen_cfg = ml_collections.FrozenConfigDict(cfg)#将cfg变成Forzen类型,即不可再改变其中常量的值
print(frozen_cfg == frozen_dict) # True
print(hash(frozen_cfg) == hash(frozen_dict)) # True
try:
frozen_dict.int = 2 # 会报错,因为FrozenConfigDict是不可以改变其中的值的
except AttributeError as e:
print(e)
# 在`FrozenConfigDict` 与 `ConfigDict`之间进行转换:
thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict) #将frozen_dict转化为ConfigDict
print(thawed_frozen_cfg == cfg) # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()#将frozen_dict通过as_configdict方法转化为ConfigDict
print(frozen_cfg_to_cfg == cfg) # True
以上代码输出结果:
(1, 2, 3)
(1, 2)
frozenset({1, 2, 3, 4})
(1, 2)
True
True
FrozenConfigDict is immutable. Cannot call __setattr__().#报错的地方
True
True
FieldReference
FieldReference
有助于让多个字段使用相同的值。它也可以用于lazy computation
。
可以使用placeholder()方法来创建具有None默认值的FieldReference
(field)。当程序使用可选的配置字段时,这将非常有用。
import ml_collections
from ml_collections.config_dict import config_dict
placeholder = ml_collections.FieldReference(0) #占位符,值为0
cfg = ml_collections.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = ml_collections.ConfigDict()
cfg.nested.placeholder = placeholder
try:
cfg.optional = 'tom' # 会报类型错误,因为是int类型
except TypeError as e:
print(e)
cfg.optional = 1555
cfg.placeholder = 1 # 改变placeholder and nested.placeholder 的值.
print(cfg)
输出结果:
Could not override field 'optional' (reference). tom is of type <class 'str'> but should be of type <class 'int'> #报错的地方
nested:
placeholder: 1
optional: 1555
placeholder: 1
请注意,如果通过ConfigDict
进行访问,FieldReferences
提供的间接寻址将丢失。
import ml_collections
placeholder = ml_collections.FieldReference(0)
cfg.field1 = placeholder
print(cfg.field1)#输出为0
cfg.field2 = placeholder# 此字段将被tied到cfg.field1字段
print(cfg.field2)#输出为0
cfg.field3 = cfg.field1 # 这只是一个初始化为0的int字段
print(cfg.field3)#输出为0
以上代码的输出结果:
0
0
0
Lazy computation
在标准操作(加法、减法、乘法等)中使用FieldReference
将返回另一个指向原始值的FieldReference
。你可以用字段引用get()
方法执行操作并获取引用的计算值,以及字段引用set()
方法更改原始引用的值
import ml_collections
ref = ml_collections.FieldReference(1)
print("ref:",ref.get()) # 输出为 1,通过get()方法获取ref的值
add_ten = ref.get() + 10 # ref.get()是一个整数,所以加法也是整数相加
add_ten_lazy = ref + 10 # add_ten_lazy是FieldReference,并不是一个整数
print("add_ten:",add_ten) # 输出为 11
print("add_ten_lazy:",add_ten_lazy.get()) # 输出为11,因为ref的值是1
# Addition is lazily computed for FieldReferences so changing ref will change
# the value that is used to compute add_ten.
ref.set(5) #更改ref的值为5
print("ref_:",ref.get()) #输出为5
print("add_ten_:",add_ten) # 输出为 11
print("add_ten_lazy_",add_ten_lazy.get()) # 输出 15因为此时ref的值为5
以上代码的输出:
ref: 1
add_ten: 11
add_ten_lazy: 11
ref_ 5
add_ten_: 11
add_ten_lazy_ 15
如果FieldReference
的原始值为None,或者任何操作的参数为None,则lazy computation的结果将为None。
我们也可以在lazy computation中使用ConfigDict
中的字段。在这种情况下,只有ConfigDict.get_ref()
方法用于获取它
import ml_collections
config = ml_collections.ConfigDict()
config.reference_field = ml_collections.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5
# 因为我们没有使用`get_ref()`所以没有lazy evaluatuations
config.no_lazy = config.integer_field * config.float_field #2*2.5
# 这里的lazily evaluate 只能是 config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field
# 这个lazily evaluate 只能是 config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')
# 这里的lazily evaluate 既有 config.integer_field 也有 config.float_Field
config.lazy_both = (config.get_ref('integer_field') *
config.get_ref('float_field'))
config.integer_field = 3
print(config.no_lazy) # 输出 5.0 - 这里用了integer_field'的原始值
print(config.lazy_integer) # 输出 7.5 3*2.5
config.float_field = 3.5
print(config.lazy_float) # 输出 7.0,2*3.5 这里的integer_field 是2
print(config.lazy_both) # 输出 10.5 3*3.5
以上代码的输出结果
5.0
7.5
7.0
10.5
Changing lazily computed values
ConfigDict
中lazily computed的值可以与常规值相同的方式进行重新赋值。对用于lazily computed的FieldReference
的之前的值将舍弃,并且之后的所有计算都将使用新值
import ml_collections
config = ml_collections.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10 #1+10
config.reference_1 = config.get_ref('reference') + 20 #1+20
config.reference_1_0 = config.get_ref('reference_1') + 100 #21+100
print(config.reference) # Prints 1.
print(config.reference_0) # Prints 11.
print(config.reference_1) # Prints 21.
print(config.reference_1_0) # Prints 121.
config.reference_1 = 30 #此处给reference_1 赋新的值,之后的reference_1 都将是30
print(config.reference) # Prints 1 (unchanged).
print(config.reference_0) # Prints 11 (unchanged).
print(config.reference_1) # Prints 30.
print(config.reference_1_0) # Prints 130.
以上代码的结果
1
11
21
121
1
11
30
130
Cycles
不能使用references创建循环。(这里没有看太明白,希望有大佬可以指点一下)
“You cannot create cycles using references. Fortunately the only way to create a cycle is by assigning a computed field to one that is not the result of computation. This is forbidden:”这段话是github的原始解释。
import ml_collections
from ml_collections.config_dict import config_dict
config = ml_collections.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10 #11
print(config.bigger_integer_field)
try:
# 引发一个MutabilityError,因为设置config.integer()字段会引起循环。
config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
print(e)
以上代码输出结果:
11
Found cycle in reference graph.#引发的错误
后面还有一半的内容,下次再来填坑。第一次写博客,也是自己学习之路的一个记录吧,错误之处,希望可以指点一二。
更多推荐
所有评论(0)