smpanaro's picture
add model
691b3b7
program(1.0)
[buildInfo = dict<tensor<string, []>, tensor<string, []>>({{"coremlc-component-MIL", "3304.5.2"}, {"coremlc-version", "3304.6.2"}, {"coremltools-component-torch", "2.1.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.0b1"}})]
{
    func main<ios16>(tensor<fp16, [1, 64, 1, 1024]> new_k_cache, tensor<fp16, [1, 1024, 1, 64]> new_v_cache, tensor<fp16, [1, 448, 1, 1024]> old_k_cache, tensor<fp16, [1, 1024, 1, 448]> old_v_cache) {
        tensor<int32, []> var_6 = const()[name = tensor<string, []>("op_6"), val = tensor<int32, []>(-3)];
        tensor<bool, []> cat_k_1_interleave_0 = const()[name = tensor<string, []>("cat_k_1_interleave_0"), val = tensor<bool, []>(false)];
        tensor<fp16, [1, 512, 1, 1024]> cat_k_1_cast_fp16 = concat(axis = var_6, interleave = cat_k_1_interleave_0, values = (old_k_cache, new_k_cache))[name = tensor<string, []>("cat_k_1_cast_fp16")];
        tensor<int32, []> var_9 = const()[name = tensor<string, []>("op_9"), val = tensor<int32, []>(-1)];
        tensor<bool, []> cat_v_interleave_0 = const()[name = tensor<string, []>("cat_v_interleave_0"), val = tensor<bool, []>(false)];
        tensor<fp16, [1, 1024, 1, 512]> cat_v_cast_fp16 = concat(axis = var_9, interleave = cat_v_interleave_0, values = (old_v_cache, new_v_cache))[name = tensor<string, []>("cat_v_cast_fp16")];
        tensor<int32, [4]> var_20_begin_0 = const()[name = tensor<string, []>("op_20_begin_0"), val = tensor<int32, [4]>([0, 64, 0, 0])];
        tensor<int32, [4]> var_20_end_0 = const()[name = tensor<string, []>("op_20_end_0"), val = tensor<int32, [4]>([1, 512, 1, 1024])];
        tensor<bool, [4]> var_20_end_mask_0 = const()[name = tensor<string, []>("op_20_end_mask_0"), val = tensor<bool, [4]>([true, false, true, true])];
        tensor<fp16, [1, 448, 1, 1024]> updated_k_cache = slice_by_index(begin = var_20_begin_0, end = var_20_end_0, end_mask = var_20_end_mask_0, x = cat_k_1_cast_fp16)[name = tensor<string, []>("op_20_cast_fp16")];
        tensor<int32, [4]> var_50_begin_0 = const()[name = tensor<string, []>("op_50_begin_0"), val = tensor<int32, [4]>([0, 0, 0, 64])];
        tensor<int32, [4]> var_50_end_0 = const()[name = tensor<string, []>("op_50_end_0"), val = tensor<int32, [4]>([1, 1024, 1, 512])];
        tensor<bool, [4]> var_50_end_mask_0 = const()[name = tensor<string, []>("op_50_end_mask_0"), val = tensor<bool, [4]>([true, true, true, false])];
        tensor<fp16, [1, 1024, 1, 448]> updated_v_cache = slice_by_index(begin = var_50_begin_0, end = var_50_end_0, end_mask = var_50_end_mask_0, x = cat_v_cast_fp16)[name = tensor<string, []>("op_50_cast_fp16")];
        tensor<fp16, []> var_51_promoted_to_fp16 = const()[name = tensor<string, []>("op_51_promoted_to_fp16"), val = tensor<fp16, []>(0x1p+1)];
        tensor<fp16, [1, 448, 1, 1024]> prod_cast_fp16 = mul(x = updated_k_cache, y = var_51_promoted_to_fp16)[name = tensor<string, []>("prod_cast_fp16")];
        tensor<bool, []> var_53_keep_dims_0 = const()[name = tensor<string, []>("op_53_keep_dims_0"), val = tensor<bool, []>(false)];
        tensor<fp16, []> ignore_me_im_only_here_so_this_runs_on_the_ane = reduce_min(keep_dims = var_53_keep_dims_0, x = prod_cast_fp16)[name = tensor<string, []>("op_53_cast_fp16")];
    } -> (updated_k_cache, updated_v_cache, ignore_me_im_only_here_so_this_runs_on_the_ane);
}