ML Collections的介绍(一)

在看一篇论文的源码时,看到import ml_collections这行代码,系统报错,经过寻找之后发现并不是源码中的一个文件,之后感觉应该是一个深度学习的安装包,然后就在网上搜索这个包的安装过程,却没有任何信息,于是自己直接尝试了pip install ml-collections,于是就安装成功了。今天在github上搜索了这个安装包,潜心学习了一下。

ML Collections

ML Collections是为ML use cases而设计的一个Python Collections的一个库。它的两个类是ConfigDictFrozenConfigDict,是"dict-like" 的数据结构,以下是ConfigDictFrozenConfigDictFieldReference的示例用法,直接上代码吧。

Basic Usage

import ml_collections

cfg = ml_collections.ConfigDict()
cfg.float_field = 12.6 #float类型
cfg.integer_field = 123 #int类型
cfg.another_integer_field = 234 #int类型
cfg.nested = ml_collections.ConfigDict() #嵌套了ml_collections.ConfigDict()
cfg.nested.string_field = 'tom' #str类型

print(cfg.integer_field)  # 输出结果 123.
print(cfg['integer_field'])  # 也输出123.

try:
  cfg.integer_field = 'tom'  # 输出会报错错误类型是TypeError,因为这你field是整数类型
except TypeError as e:
  print(e)

cfg.float_field = 12  # int类型也可以指定给float类型.
cfg.nested.string_field = u'bob'  # string可以储存Unicode字符串
print(cfg)

下面是这段代码的输出结果

123
123
Could not override field 'integer_field' (reference). tom is of type <class 'str'> but should be of type <class 'int'> #报错的地方
#以下输出是所有的cfg的结果,以及嵌套的nested
another_integer_field: 234
float_field: 12.0
integer_field: 123
nested:
  string_field: bob

FrozenConfigDict

不可以改变的ConfigDict

import ml_collections
#初始化一个字典
initial_dictionary = {
    'int': 1,
    'list': [1, 2],
    'tuple': (1, 2, 3),
    'set': {1, 2, 3, 4},
    'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}

cfg = ml_collections.ConfigDict(initial_dictionary)#把这个字典通过ConfigDict赋值给cfg
frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary)#把这个字典通过FrozenConfigDict赋值给frozen_dict

print(frozen_dict.tuple)  # 输出(1, 2, 3)
print(frozen_dict.list)  # 输出 (1, 2)
print(frozen_dict.set)  # 输出 {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0])  # 输出 (1, 2)

frozen_cfg = ml_collections.FrozenConfigDict(cfg)#将cfg变成Forzen类型,即不可再改变其中常量的值
print(frozen_cfg == frozen_dict)  # True
print(hash(frozen_cfg) == hash(frozen_dict))  # True

try:
  frozen_dict.int = 2 # 会报错,因为FrozenConfigDict是不可以改变其中的值的
except AttributeError as e:
  print(e)

# 在`FrozenConfigDict` 与 `ConfigDict`之间进行转换:
thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict) #将frozen_dict转化为ConfigDict
print(thawed_frozen_cfg == cfg)  # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()#将frozen_dict通过as_configdict方法转化为ConfigDict
print(frozen_cfg_to_cfg == cfg)  # True

以上代码输出结果:

(1, 2, 3)
(1, 2)
frozenset({1, 2, 3, 4})
(1, 2)
True
True
FrozenConfigDict is immutable. Cannot call __setattr__().#报错的地方
True
True

FieldReference

FieldReference有助于让多个字段使用相同的值。它也可以用于lazy computation
可以使用placeholder()方法来创建具有None默认值的FieldReference(field)。当程序使用可选的配置字段时,这将非常有用。

import ml_collections
from ml_collections.config_dict import config_dict

placeholder = ml_collections.FieldReference(0) #占位符,值为0
cfg = ml_collections.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = ml_collections.ConfigDict()
cfg.nested.placeholder = placeholder

try:
  cfg.optional = 'tom'  # 会报类型错误,因为是int类型
except TypeError as e:
  print(e)

cfg.optional = 1555  
cfg.placeholder = 1  # 改变placeholder and nested.placeholder 的值.

print(cfg)

输出结果:

Could not override field 'optional' (reference). tom is of type <class 'str'> but should be of type <class 'int'> #报错的地方
nested:
  placeholder: 1
optional: 1555
placeholder: 1

请注意,如果通过ConfigDict进行访问,FieldReferences提供的间接寻址将丢失。

import ml_collections

placeholder = ml_collections.FieldReference(0)
cfg.field1 = placeholder
print(cfg.field1)#输出为0
cfg.field2 = placeholder# 此字段将被tied到cfg.field1字段
print(cfg.field2)#输出为0
cfg.field3 = cfg.field1  # 这只是一个初始化为0的int字段
print(cfg.field3)#输出为0

以上代码的输出结果:

0
0
0

Lazy computation

在标准操作(加法、减法、乘法等)中使用FieldReference将返回另一个指向原始值的FieldReference。你可以用字段引用get()方法执行操作并获取引用的计算值,以及字段引用set()方法更改原始引用的值

import ml_collections

ref = ml_collections.FieldReference(1)
print("ref:",ref.get())  # 输出为 1,通过get()方法获取ref的值

add_ten = ref.get() + 10  # ref.get()是一个整数,所以加法也是整数相加
add_ten_lazy = ref + 10  # add_ten_lazy是FieldReference,并不是一个整数

print("add_ten:",add_ten)  # 输出为 11
print("add_ten_lazy:",add_ten_lazy.get())  # 输出为11,因为ref的值是1

# Addition is lazily computed for FieldReferences so changing ref will change
# the value that is used to compute add_ten.
ref.set(5) #更改ref的值为5
print("ref_:",ref.get()) #输出为5
print("add_ten_:",add_ten)  # 输出为 11
print("add_ten_lazy_",add_ten_lazy.get())  # 输出 15因为此时ref的值为5

以上代码的输出:

ref: 1
add_ten: 11
add_ten_lazy: 11
ref_ 5
add_ten_: 11
add_ten_lazy_ 15

如果FieldReference的原始值为None,或者任何操作的参数为None,则lazy computation的结果将为None。
我们也可以在lazy computation中使用ConfigDict中的字段。在这种情况下,只有ConfigDict.get_ref()方法用于获取它

import ml_collections

config = ml_collections.ConfigDict()
config.reference_field = ml_collections.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5

# 因为我们没有使用`get_ref()`所以没有lazy evaluatuations
config.no_lazy = config.integer_field * config.float_field #2*2.5

# 这里的lazily evaluate 只能是 config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field 

# 这个lazily evaluate 只能是 config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')

# 这里的lazily evaluate 既有 config.integer_field 也有 config.float_Field
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))

config.integer_field = 3
print(config.no_lazy)  # 输出 5.0 - 这里用了integer_field'的原始值

print(config.lazy_integer)  # 输出 7.5 3*2.5 

config.float_field = 3.5 
print(config.lazy_float)  # 输出 7.0,2*3.5 这里的integer_field 是2
print(config.lazy_both)  # 输出 10.5 3*3.5

以上代码的输出结果

5.0
7.5
7.0
10.5

Changing lazily computed values
ConfigDict中lazily computed的值可以与常规值相同的方式进行重新赋值。对用于lazily computed的FieldReference的之前的值将舍弃,并且之后的所有计算都将使用新值

import ml_collections

config = ml_collections.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10 #1+10
config.reference_1 = config.get_ref('reference') + 20 #1+20
config.reference_1_0 = config.get_ref('reference_1') + 100 #21+100

print(config.reference)  # Prints 1.
print(config.reference_0)  # Prints 11.
print(config.reference_1)  # Prints 21.
print(config.reference_1_0)  # Prints 121.

config.reference_1 = 30 #此处给reference_1 赋新的值,之后的reference_1 都将是30

print(config.reference)  # Prints 1 (unchanged).
print(config.reference_0)  # Prints 11 (unchanged).
print(config.reference_1)  # Prints 30.
print(config.reference_1_0)  # Prints 130.

以上代码的结果

1
11
21
121
1
11
30
130

Cycles
不能使用references创建循环。(这里没有看太明白,希望有大佬可以指点一下)
“You cannot create cycles using references. Fortunately the only way to create a cycle is by assigning a computed field to one that is not the result of computation. This is forbidden:”这段话是github的原始解释。

import ml_collections
from ml_collections.config_dict import config_dict

config = ml_collections.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10 #11
print(config.bigger_integer_field)

try:
  # 引发一个MutabilityError,因为设置config.integer()字段会引起循环。
  config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
  print(e)

以上代码输出结果:

11
Found cycle in reference graph.#引发的错误

后面还有一半的内容,下次再来填坑。第一次写博客,也是自己学习之路的一个记录吧,错误之处,希望可以指点一二。

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐