IntermediateLayerGetter 获取中间结果
获取中间结果
利用 IntermediateLayerGetter 方法可以非常方便的获取中间结果,例如获取 resnet 的中间结果——features,这一点很有用,因为很多模型都会使用 resnet 作为 backbone 来提取 features。
123456789101112131415161718192021import torchfrom torchvision.models import resnet18from torchvision.models._utils import IntermediateLayerGetterinput = torch.rand(size=(1, 3, 224, 224))model = resnet18()output = model(input)print(output.shape)# torch.Size([1, 1000])'''extract layer1, layer2, layer3 and layer4, giving as names feat1, feat2, feat3 and fea ...
Python 包裹传递和解包传递
简介
在 Python 中,参数传递方式有两种主要形式:位置参数和关键字参数。通过使用包裹传递(packing)和解包传递(unpacking),可以灵活地处理任意数量的参数。
单星号 *
单星号 * 用于处理位置参数的打包和解包。
单星号打包
单星号 * 打包会将多个位置参数打包成一个 tuple 对象
123456def sum_all(*args): print(f"{args = }, {type(args) = }") return sum(args)result = sum_all(1, 2, 3, 4)# args = (1, 2, 3, 4), type(args) = <class 'tuple'>
单星号解包
单星号 * 解包用于将一个可迭代对象解包为位置参数
123456def add(a, b, c): return a + b + cdata = [1, 2, 3]print(*data) # 1 2 3result = add(*data) # add(1, 2 ...
Pytorch 对 ResNet 的实现
参考
pytorch vision resnet
resnet.py | pytorch github
借助 torch hub
可以直接借助 torch hub pytorch vision resnet 来加载预训练的 resnet。
12345678import torchmodel = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)'''or any of these variants'''# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)# model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)# mod ...
Python 标准化输出
Document
Input and Output - formatted-string-literals | python docs
Python format 格式化函数 | 菜鸟教程
使用 f-strings
f-strings 是 Python 从 3.6 版本开始引入的一种字符串格式化方式,其全称为“格式化字符串字面量”。f-strings 使用更加简洁的语法,并且具有更高的可读性和性能。
f-strings 的语法使得可以直接在字符串中使用 {} 嵌入变量和python表达式,并允许在 : 后添加格式化字符说明。
123name, age, height = "zhang", 18, 1.7print(f"My name is {name}, I'm {age} years old and {height:.2f} meter tall.")# My name is zhang, I'm 18 years old and 1.70 me ...
图像分割的评价指标
TP TN FP FN
The following are the 4 basic terminologies you need to know.
True Positives (TP): when the actual value is Positive and predicted is also Positive.
True Negatives (TN): when the actual value is Negative and prediction is also Negative.
False Positives (FP): When the actual is negative but prediction is Positive. Also known as the Type 1 error.
False Negatives (FN): When the actual is Positive but the prediction is Negative. Also known as the Type 2 error.
OA, Overall Accuracy, ...
Python Config
简介
使用 Python 编写一些深度学习算法时,经常会涉及到大量超参数的配置。通过 Python 的两个标准库,ml_collections 和 dataclasses,可以实现对这些配置进行管理。
ml_collections
ml_collections 是一个 Python 库,主要用于管理实验配置,其中的 ConfigDict 是一个核心工具。ConfigDict 提供了一种结构化的、易于管理的方式来定义和操作配置。它类似于 Python 的字典,但增强了功能,以便更方便地进行深度学习实验或其他场景中的配置管理。
1234567891011121314from ml_collections import ConfigDictdef get_config() -> ConfigDict: config = ConfigDict() config.model = ConfigDict() config.model.type = 'ResNet' config.model.num_layers = 50 config.optimizer = ...
Python 格式化时间
datetime
在 Python 中使用 datetime 库可以获取当前时间
12345678from datetime import datetime# 获取当前时间now = datetime.now()# 格式化时间为指定格式formatted_time = now.strftime("%Y-%m-%d %H:%M:%S %p")print(formatted_time) # 2024-12-09 20:34:41 PM
其中各个字母的含义如下
%Y - 年份(四位数,例如 2024)
%m - 月份(两位数,例如 01 到 12)
%d - 日期(两位数,例如 01 到 31)
%H - 小时(24 小时制,两位数,例如 00 到 23)
%I - 小时(12 小时制,两位数,例如 01 到 12)
%M - 分钟(两位数,例如 00 到 59)
%S - 秒数(两位数,例如 00 到 59)
%p - 上下午标识(AM 或 PM)
Python Logging
参考
logging — Python 的日志记录工具{target="_blank"}
通过 logging 调用
123456789101112131415161718192021222324import loggingdef main() -> None: # logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S") # logging.basicConfig(level=logging.WARNING) logging.debug("This is a debug message.") logging.info("This is a in ...
Python 类型提示
简介
Python 的类型提示(Type Hints)是在 Python 3.5 中正式引入的,作为 PEP 484 – Type Hints 的一部分。它是一种在代码中显式声明变量、参数、返回值等的类型信息的方式。它对代码运行没有实际影响,只是作为一种可选的静态类型标注,意在为开发者提供变量、参数和返回值的类型信息,方便代码阅读。
可以把类型提示当作一种注释,就是告诉开发者这个变量“应该”是提示的类型,但是 python 编译器本身不做类型检查,即便代码运行时这个变量不是标注的类型,也不会有任何提示。
普通变量的类型提示
在创建变量时可以进行类型提示
123x: int = 1y: float = 1.0z: str = "hello"
还可以对一些自定义的数据类型进行提示
1234import numpyimport torchx: numpy.ndarray = numpy.array([1, 2, 3])y: torch.tensor = torch.tensor([1, 2, 3])
函数参数和返回值类型
在函数定义中,可以为参数和返回值添加类型提示:
1 ...
关于image的shape是CxHxW还是HxWxC
H×W×C Or C×H×W
123456789101112131415from PIL import Imageimport numpy as npfrom torchvision import transformsimage_path = r"path/to/image"mask_path = r"path/to/mask"img = Image.open(image_path) # size: (width, height)mask = Image.open(mask_path) # size: (width, height)np_img = np.array(img) # shape: (height, width, 3)np_mask = np.array(img) # shape: (height, width, 3)ts_img = transforms.PILToTensor()(img) # shape: torch.Size([3, height, width])ts_mask = transforms.PILToTensor ...