def __check_symmetric_quant(quant_cls: _SymmetryQuant, float_func: Callable, input_amax: float, bit: int, narrow: bool) -> bool:
    """Check whether the output of quant_cls is correct.

    Compares the table-driven quantized operator against the reference
    pipeline dequantize -> float_func -> quantize over a random batch of
    quantized inputs spanning the full representable range.

    Args:
        quant_cls (_SymmetryQuant): a symmetric quantization operator class
        float_func (Callable): ground truth function in floating-point
        input_amax (float): the amax of input for quantization
        bit (int): the bit number of quantization
        narrow (bool): True: quant_min = -2^(bit - 1) + 1, False: quant_min = -2^(bit - 1)

    Returns:
        bool: True if all elements of the quantization function output are
        correct, False if any element is wrong
    """
    input_shape = (1, 128)
    # Random quantized integers covering [quant_min, quant_max]; torch.randint's
    # upper bound is exclusive, hence the +1.
    quant_input = torch.randint(utils.quant_min(bit, narrow),
                                utils.quant_max(bit) + 1,
                                input_shape,
                                dtype=torch.int8)

    # (quant_input) -> quant_func -> (quant_output)
    quant_func = quant_cls(input_amax, bit, narrow)
    quant_output = quant_func(quant_input)

    # Reference path: (quant_input) -> DQ -> float_func -> Q -> (quant_output)
    ground_truth_float_input = utils.dequantize(quant_input, quant_func.input_scale)
    ground_truth_float_output = float_func(ground_truth_float_input)
    ground_truth_quant_output = utils.quantize(ground_truth_float_output,
                                               quant_func.output_scale, bit, narrow)

    # Every element should be the same; coerce the 0-dim tensor from .all()
    # to a Python bool to honor the declared return type.
    return bool((quant_output == ground_truth_quant_output).all())
class _SymmetryQuant(torch.nn.Module):
    """Table-driven symmetric quantized approximation of a nonlinear function.

    Precomputes, for every representable quantized input value, the quantized
    output of ``func``, so ``forward`` reduces to a single table lookup.
    """

    def __init__(self,
                 func: Callable,
                 input_amax: float,
                 bit: int,
                 narrow: bool = False,
                 output_amax: float = None) -> None:
        """Initialize quant-input to quant-output mapping table for symmetry quantization.

        Args:
            func (Callable): corresponding standard floating-point function
            input_amax (float): the amax of input for quantization
            bit (int): the bit number
            narrow (bool, optional): True: quant_min = -2^(bit - 1) + 1.
                Defaults to False, quant_min = -2^(bit - 1)
            output_amax (float, optional): the amax of output for quantization.
                Defaults to None, the amax = amax(nonlinear(DQ(quant_input)))
        """
        super().__init__()
        # (input_quant) -> DQ -> (input_float)
        self.__input_scale = input_amax / quant_max(bit)
        input_quant = torch.arange(quant_min(bit, narrow), quant_max(bit) + 1, dtype=torch.int8)
        input_float = input_quant * self.__input_scale

        # (input_float) -> float_func -> Q -> (output_quant)
        output_float = func(input_float)
        # Use the caller-supplied amax only when one was actually given; the
        # sentinel is None, so an explicit 0.0 must not fall back to the data.
        output_amax = output_amax if output_amax is not None else torch.absolute(output_float).max()
        self.__output_scale = output_amax / quant_max(bit)
        output_quant = quantize(output_float, self.__output_scale, bit, narrow)

        # Adjust sequence of output_quant for easier retrieval: rotate the
        # table so non-negative inputs sit at the front and negative inputs
        # wrap around to the tail, matching Python's negative indexing.
        index = quant_max(bit) if narrow else quant_max(bit) + 1
        self._table = torch.cat((output_quant[index:], output_quant[:index]))

    def forward(self, x: torch.Tensor):
        # x holds quantized integers; widen to int64 so it can index the table
        # (negative values select from the rotated tail).
        y = self._table[x.to(torch.int64)]
        return y

    @property
    def input_scale(self):
        # Scale used to dequantize inputs: input_amax / quant_max(bit).
        return self.__input_scale

    @property
    def output_scale(self):
        # Scale used to quantize outputs: output_amax / quant_max(bit).
        return self.__output_scale