MLIR Transform dialect에서 match 연산을 사용해 payload IR의 연산 체인을 찾고, 이름/체인/추론된 성질 기반 매칭과 사용자 정의 MatchOpInterface 연산을 구현하는 방법을 다룬다.
지속적으로 테스트되는 MLIR 파일 버전은 mlir/test/Examples/transform/Ch4에서 확인할 수 있습니다.
지금까지 우리는 transform dialect 인터프리터가 호출될 때 호출자가 특정 payload 연산들을 식별해 준다는 가정하에 transform dialect 스크립트를 적용해 왔습니다. 이는 변환 대상이 transform dialect 인터프리터 외부의 메커니즘(예: C++에서 인터프리터를 프로그래매틱하게 호출할 때나 이전 장에서 보았던 pass 인자)을 통해 식별되어야 한다는 점에서, dialect로부터 변환을 주도한다는 아이디어와는 상반되어 보일 수 있습니다. 또한 C++에서 인터프리터와의 상호작용이 늘어나며 실질적인 오버헤드가 증가하고, 두 개의 인터페이스를 동시에 다뤄야 하는 인지적 부담도 생깁니다. 이를 해소하기 위해 Transform dialect는 변환이 필요한 payload 연산을 매칭 하기 위한 연산의 부분집합을 제안합니다.
Match 연산은 몇 가지 추가적인 보장을 제공하는 transform 연산일 뿐입니다. 특히, payload IR을 수정하지 않을 것으로 기대되며, 그 피연산자(일반적으로 payload 연산 핸들)가 연산 이름이나 인자 종류 같은 원하는 성질을 가진 payload IR 객체에 연결되어 있지 않다면 실패하도록 설계되어 있습니다. 간단한 조합(combinator) 연산들을 사용하면, transform dialect 내부에서 직접 더 높은 수준의 match/rewrite 인프라를 구축할 수 있습니다.
편의를 위해 1장의 “fully connected layer” 예제를 다시 살펴보겠습니다.
mlir// 최적화 대상 원래 함수. func.func @fc_relu(%lhs: tensor<512x512xf32>, %rhs: tensor<512x512xf32>, %bias: tensor<512x512xf32>, %output: tensor<512x512xf32>) -> tensor<512x512xf32> { // 행렬-행렬 곱셈. %matmul = linalg.matmul ins(%lhs, %rhs: tensor<512x512xf32>, tensor<512x512xf32>) outs(%output: tensor<512x512xf32>) -> tensor<512x512xf32> // elementwise 덧셈. %biased = linalg.elementwise kind=#linalg.elementwise_kind<add> ins(%matmul, %bias : tensor<512x512xf32>, tensor<512x512xf32>) outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> // 0과의 elementwise max (ReLU). %c0f = arith.constant 0.0 : f32 %relued = linalg.elementwise kind=#linalg.elementwise_kind<max_signed> ins(%biased, %c0f : tensor<512x512xf32>, f32) outs(%output : tensor<512x512xf32>) -> tensor<512x512xf32> func.return %relued : tensor<512x512xf32> }
1장에서는 테스트용 transform 인터프리터 pass를 호출하면서 bind-first-extra-to-ops=linalg.matmul bind-second-extra-to-ops=linalg.elementwise 같은 추가 인자를 사용하여 연산 핸들에 대한 초기 연결을 제공했습니다. 대신, match 연산을 사용해 payload IR에서 관련된 연산들을 찾아낼 수 있습니다. match 연산은 예를 들어 named sequence 개념을 활용하는 transform.collect_matching 조합 연산으로 “일반적인” transform 연산과 결합할 수 있습니다.
mlir// named sequence를 포함하는 모듈은 검증을 활성화하기 위한 // 속성을 가져야 합니다. module @transforms attributes { transform.with_named_sequence } { // 진입점. 이는 transform 인터프리터에 전달되는 루트 연산(일반적으로 // pass 루트)을 유일한 인자로 받습니다. transform.named_sequence @__transform_main( %root: !transform.any_op {transform.readonly}) { // named sequence에 지정된 기준에 맞는 연산들을 수집합니다. // named sequence가 silenceable failure와 함께 실패하면 이를 무시하고 // (메시지는 디버그 스트림으로 전달됩니다), 성공하면 이 연산의 // 결과에 그 결과들을 이어붙입니다. %elemwise = transform.collect_matching @match_elemwise in %root : (!transform.any_op) -> !transform.any_op %matmul = transform.collect_matching @match_matmul in %root : (!transform.any_op) -> !transform.any_op transform.include @print_elemwise failures(propagate) (%elemwise) : (!transform.any_op) -> () transform.include @print_matmul failures(propagate) (%matmul) : (!transform.any_op) -> () transform.yield } // 이것은 matcher sequence입니다. 매칭 대상 연산을 인자로 받고, // 중첩된 연산 중 하나라도 실패를 내면 매칭이 실패한 것으로 // 간주됩니다. 이 연산에서 yield된 값들은 성공 시 rewriter sequence에 // 전달됩니다. transform.named_sequence @match_elemwise( %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { transform.match.operation_name %entry ["linalg.elementwise"] : !transform.any_op transform.yield %entry : !transform.any_op } transform.named_sequence @match_matmul( %entry: !transform.any_op {transform.readonly}) -> !transform.any_op { transform.match.operation_name %entry ["linalg.matmul"] : !transform.any_op transform.yield %entry : !transform.any_op } // 이것은 rewriter sequence입니다. transform.named_sequence @print_elemwise( %elemwise_binary: !transform.any_op {transform.readonly}) { transform.debug.emit_remark_at %elemwise_binary, "elementwise binary" : !transform.any_op transform.yield } transform.named_sequence @print_matmul( %matmul: !transform.any_op {transform.readonly}) { transform.debug.emit_remark_at %matmul, "matmul" : !transform.any_op transform.yield } }
이 스크립트는 번역 단위의 루트 연산에 대해 non-test 인터프리터 pass를 사용해, 추가 플래그 없이 mlir-opt --transform-interpreter로 실행할 수 있습니다. 그러면 linalg.elementwise와 linalg.matmul 연산에서 각각의 remark가 출력됩니다. 디버그 빌드에서는 mlir-opt 또는 파생 도구에 -debug-only=transform-matcher를 전달하여 매칭 과정을 이해하기 위한 편리한 방법을 제공합니다. 그러면 match 연산이 생성한 silenceable failure 메시지가 디버그 스트림에 출력됩니다. 예를 들어:
text<...> [transform-matcher] matching %0 = linalg.matmul ins(%arg0, %arg1 : tensor<512x512xf32>, tensor<512x512xf32>) outs(%arg3 : tensor<512x512xf32>) -> tensor<512x512xf32> @0x5622eee08410 [transform-matcher] matcher match_elemwise failed: wrong operation name <...>
이제 1장의 나머지 transform 스크립트를 실행하기에 충분하며, 단지 %arg1을 %matmul로, %arg2를 %elemwise로 바꿔주면 됩니다.
위의 matcher는 payload 루트 아래의 모든 해당 종류의 연산을 매칭하기 때문에 여전히 단순한 수준입니다. 이 연산들은 서로 관련이 있을 수도, 없을 수도 있고, 예를 들어 서로 다른 함수에 속해 있을 수도 있습니다. 설령 하나의 함수 안에 있더라도, 이런 연산들이 여러 그룹으로 존재하면 이 접근법으로는 서로 구분할 수 없습니다. 실제로 우리가 원하는 것은 matmul 연산의 결과가 어떤 elementwise 연산에 사용되고, 그 결과가 다시 다른 elementwise 연산에 유사한 방식으로 전달되는 특정 연산 그룹을 매칭하는 것입니다.
이는 다음과 같은 matcher sequence로 구현할 수 있습니다.
mlir// 이것 역시 matcher sequence입니다. 마찬가지로 매칭 대상 연산을 인자로 // 받고, 중첩 연산이 모두 성공해야 매칭이 성공으로 간주됩니다. // 사용-정의(use-def) 체인의 마지막 연산에서 시작해 역방향으로 // 매칭하는데, 각 피연산자(use)는 정확히 하나의 정의를 갖기 때문입니다. transform.named_sequence @match_matmul_elemwise( %last: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_op, !transform.any_op) { // 마지막 연산은 elementwise binary여야 합니다. transform.match.operation_name %last ["linalg.elementwise"] : !transform.any_op // 첫 번째 피연산자는 다른 연산에 의해 정의되어야 하며, 여기에 // 그 연산에 대한 핸들을 가져옵니다. 이 연산이 이항 연산이라는 걸 // 알고 있으므로 첫 번째 피연산자가 존재한다는 보장이 있지만, // 그러한 보장이 없더라도 `%last`에 충분한 피연산자가 없을 때는 // 이 연산이 silenceable failure를 발생시켰을 것입니다. %middle = transform.get_producer_of_operand %last[0] : (!transform.any_op) -> !transform.any_op // 그 정의 연산 역시 elementwise binary여야 합니다. transform.match.operation_name %middle ["linalg.elementwise"] : !transform.any_op // 그리고 그 연산의 첫 번째 피연산자는 또 다른 연산에 의해 // 정의되어야 합니다. %matmul = transform.get_producer_of_operand %middle[0] : (!transform.any_op) -> !transform.any_op // 그 연산은 matmul이어야 합니다. transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op // matmul과 두 개의 elementwise 연산에 대한 핸들을 개별적으로 // yield합니다. transform.yield %matmul, %middle, %last : !transform.any_op, !transform.any_op, !transform.any_op }
이 matcher는 다른 elemwise 및 matmul 연산이 존재하는 상황에서도 올바르게 동작하며, 발견된 순서가 아니라 서로 관련된 연산들의 triple을 반환합니다. 다음과 같이 이전 버전과 유사한 방식으로 사용할 수 있습니다.
mlir// 대체 진입점. transform.named_sequence @__transform_main( %root: !transform.any_op {transform.readonly}) { // named sequence에 지정된 기준에 맞는 연산 그룹들을 수집합니다. %matmul, %el1, %el2 = transform.collect_matching @match_matmul_elemwise in %root : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) %elemwise = transform.merge_handles %el1, %el2 : !transform.any_op transform.include @print_elemwise failures(propagate) (%elemwise) : (!transform.any_op) -> () transform.include @print_matmul failures(propagate) (%matmul) : (!transform.any_op) -> () transform.yield }
연산 체인의 matcher는 다른 연산들의 존재 하에서도 올바르지만, 여전히 많은 관심 사례에 대해 충분히 견고하지는 않습니다. 특히 transform.get_producer_of_operand %last[0]을 사용하면 elementwise 연산의 첫 번째 피연산자가 다른 연산에 의해 생성되어 있어야 합니다. 하지만 같은 변환 전략이 피연산자 위치와 무관하게 적용될 수도 있습니다. 많은 이항 연산은 결합법칙을 따르기 때문입니다. 이를 계기로 새로운 match 연산을 소개해 보겠습니다. 구체적으로, 이 연산은 피연산자들 중 어느 하나라도 다른 match 연산으로 표현할 수 있는 특정 조건을 만족하면 성공하도록 만들고자 합니다. 또한 매칭된 피연산자의 위치와, 그 과정에서 얻은 일부 상태를 반환하길 원합니다.
Match 연산은 다른 transform 연산과 유사하게 정의되며, 추가로 MatchOpInterface를 구현한다는 점만 다릅니다. 이 인터페이스에는 추가 메서드가 전혀 없으며 (향후 추가될 수는 있지만) 이 연산이 매칭을 위한 것이며 payload를 변환하려 하지 않을 것이라는 검증 계약으로만 사용된다는 점에 유의해야 합니다. 우리의 연산에 대한 최소 정의는 다음과 같습니다.
tablegen// 새로운 연산을 정의합니다. 관례상 이름은 `match`와 dialect 확장 // 이름을 이어 붙여서 만듭니다. def HasOperandSatisfyingOp : TransformDialectOp<"match.my.has_operand_satisfying", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<TransformOpInterface>, // 이 연산이 TransformOpInterface에 더해 MatchOpInterface도 구현함을 // 나타냅니다. 이 인터페이스는 현재 태그 용도로만 사용되며, // 반드시 구현해야 하는 메서드는 없습니다. MatchOpInterface, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { let summary = "Succeed if any of the operands matches all nested criteria"; let arguments = (ins TransformHandleTypeInterface:$op); let results = (outs TransformParamTypeInterface:$position, Variadic<Transform_AnyHandleOrParamType>:$results); // Match 연산은 영역(region)을 포함해 임의로 복잡할 수 있습니다. let regions = (region SizedRegion<1>:$body); let hasVerifier = 1; let assemblyFormat = [{ $op `:` functional-type($op, results) attr-dict-with-keyword $body }]; }
이 연산은 피연산자를 매칭할 payload 연산에 연결된 핸들을 인자로 받고, 매칭 기준을 담은 단일 블록(region)을 하나 가지며, 매칭에 성공하면 매칭된 피연산자의 위치와 body에서 yield된 임의의 transform 값을 반환합니다.
매칭 로직은 TransformOpInterface의 apply 메서드 안에 구현되며, 다른 transform 연산과 손쉽게 조합할 수 있습니다. 인터프리터 상태를 관리하고 블록에 재귀적으로 진입하는 모든 기능이 “일반” transform 연산에서와 완전히 동일한 방식으로 제공됩니다. Match 연산은 매칭 실패를 나타낼 때 silenceable failure를 반환해야 하며, definite failure는 즉시 전파해야 합니다. 중첩된 연산이 있다면, 이들이 적용될 때 발생하는 silenceable failure를 적절히 처리하고, 대부분의 경우 이를 무시(silence)해야 합니다. 우리의 연산에서는 본질적으로 (단일) payload 연산의 모든 피연산자를 순회하며, 어떤 피연산자에 대해 중첩된 transform 연산들이 모두 성공할 때까지 반복합니다.
c// Matcher op는 다른 transform op와 유사하게 `apply`를 구현합니다. // payload를 수정해서는 안 되며, tri-state 결과를 사용해 매칭 실패나 // 성공, 그리고 잠재적인 회복 불가능 오류를 표현합니다. mlir::DiagnosedSilenceableFailure mlir::transform::HasOperandSatisfyingOp::apply( mlir::transform::TransformRewriter &rewriter, mlir::transform::TransformResults &results, mlir::transform::TransformState &state) { // 단순화를 위해 단일 payload op만 처리합니다. 실제 구현에서는 // `SingleOpMatcher` trait을 사용해 구현을 단순화하고, 이러한 기대를 // 문서화할 수 있습니다. auto payloadOps = state.getPayloadOps(getOp()); if (!llvm::hasSingleElement(payloadOps)) return emitSilenceableError() << "expected single payload"; // body를 사용해 매칭할 수 있는지 보기 위해 payload op의 모든 // 피연산자를 순회합니다. Operation *payload = *payloadOps.begin(); for (OpOperand &operand : payload->getOpOperands()) { // body에서 정의된 transform 값에 대한 스코프를 생성합니다. 이는 // 이 op에 연결된 region의 문법적 스코프에 해당합니다. 이제부터 // payload에 연결되는 모든 값은 이 객체가 파괴될 때, 즉 이 // 반복이 끝날 때 자동으로 연결이 해제됩니다. // 블록 인자 핸들을 해당 피연산자에 연결합니다. auto matchScope = state.make_region_scope(getBody()); if (failed(state.mapBlockArgument(getBody().getArgument(0), {operand.get()}))) { return DiagnosedSilenceableFailure::definiteFailure(); } // 현재 매핑을 사용해 모든 중첩 matcher를 순회하며 성공하는지 // 확인합니다. bool matchSucceeded = true; for (Operation &matcher : getBody().front().without_terminator()) { // Matcher op는 다른 transform op와 동일한 방식으로 적용됩니다. DiagnosedSilenceableFailure diag = state.applyTransform(cast<TransformOpInterface>(matcher)); // definite failure는 회복 불가능하므로 즉시 전파합니다. if (diag.isDefiniteFailure()) return diag; // 성공했다면 나머지 조건들을 계속 검사합니다. if (diag.succeeded()) continue; // 디버깅을 위해 매칭 실패를 보고하고, 이 피연산자에 대한 // 매칭을 중단합니다. assert(diag.isSilenceableFailure()); DEBUG_MATCHER(DBGS_MATCHER() << "failed to match operand #" << operand.getOperandNumber() << ": " << diag.getMessage()); (void)diag.silence(); matchSucceeded = false; break; } // 이 피연산자 매칭에 실패했다면 다른 피연산자들을 시도합니다. if (!matchSucceeded) continue; // 여기까지 도달했다면 현재 피연산자에 대해 매칭이 성공한 것입니다. // terminator 피연산자에 연결된 값들을 op 결과에 연결되도록 다시 // 매핑하고, 파라미터 결과에는 피연산자의 위치를 설정합니다. // `results`는 인터프리터가 `apply`가 반환된 후에야 `state`에 // 통합하므로, 스코프가 끝나기 직전이어도 여기서 이를 수행하는 것이 // 안전합니다. SmallVector<SmallVector<MappedValue>> yieldedMappings; transform::detail::prepareValueMappings( yieldedMappings, getBody().front().getTerminator()->getOperands(), state); results.setParams(cast<OpResult>(getPosition()), {rewriter.getI32IntegerAttr(operand.getOperandNumber())}); for (auto &&[result, mapping] : llvm::zip(getResults(), yieldedMappings)) results.setMappedValues(result, mapping); return DiagnosedSilenceableFailure::success(); } // 여기까지 도달했다면 어느 피연산자도 매칭에 성공하지 못한 것입니다. return emitSilenceableError() << "none of the operands satisfied the conditions"; }
관례상, MatchOpInterface를 구현하는 연산은 payload IR을 수정해서는 안 되며, 따라서 피연산자 핸들과 payload를 읽기만 한다는 효과(effects)를 명시해야 합니다.
cvoid transform::CollectMatchingOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { onlyReadsHandle(getRoot(), effects); producesHandle(getResults(), effects); onlyReadsPayload(effects); }
이제 이 연산을 transform dialect 확장에 포함시키고, 로드하여 matcher에서 사용할 수 있습니다. 특히, 예제에서 “max” elementwise 연산의 어느 피연산자든 이전 elementwise 연산에 의해 생성될 수 있음을 표현하는 데 사용할 것입니다. 단순화를 위해, 이전 연산은 여전히 matmul이 첫 번째 피연산자를 생성해야 합니다. 갱신된 matcher sequence는 다음과 같습니다.
mlirtransform.named_sequence @match_matmul_elemwise( %last: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.param<i32>) { // 마지막 연산은 elementwise binary여야 합니다. transform.match.operation_name %last ["linalg.elementwise"] : !transform.any_op // 피연산자들 중 하나는 다른 연산에 의해 정의되어야 하며, 여기에 // 그 연산에 대한 핸들을 가져옵니다. 이는 새로 정의한 연산 덕분에 // 가능하며, 이 연산은 region에 중첩된 match 연산들을 사용해 // 피연산자들을 하나씩 매칭합니다. %pos, %middle = transform.match.my.has_operand_satisfying %last : (!transform.any_op) -> (!transform.param<i32>, !transform.any_op) { ^bb0(%operand: !transform.any_value): // 피연산자는 어떤 연산에 의해 정의되어야 합니다. %def = transform.get_defining_op %operand : (!transform.any_value) -> !transform.any_op // 그 정의 연산 역시 elementwise binary여야 합니다. transform.match.operation_name %def ["linalg.elementwise"] : !transform.any_op transform.yield %def : !transform.any_op } // 그리고 그 연산의 첫 번째 피연산자는 또 다른 연산에 의해 // 정의되어야 합니다. %matmul = transform.get_producer_of_operand %middle[0] : (!transform.any_op) -> !transform.any_op // 그 연산은 matmul이어야 합니다. transform.match.operation_name %matmul ["linalg.matmul"] : !transform.any_op // matmul과 두 개의 elementwise 연산에 대한 핸들을 개별적으로 // yield합니다. transform.yield %matmul, %middle, %last, %pos : !transform.any_op, !transform.any_op, !transform.any_op, !transform.param<i32> }
이로써 max(add(matmul(...), bias), 0)와 max(0, add(matmul(...), bias)) 두 경우를 동일한 값들로 매칭하는 원하는 효과를 얻었습니다. %pos 값은 transform dialect의 파라미터(parameter) 로, 변환 적용 과정 전체에서 상수임이 알려진 엔티티들의 리스트를 저장하는 데 사용됩니다. 대부분의 경우 파라미터는 수치 값이지만, 일반적으로는 임의의 MLIR attribute가 될 수 있습니다.
여러 연산 그룹이 서로 독립적으로 매칭된다는 점을 보여주기 위해, transform dialect 내부에서 간단한 고수준 패턴 rewriting 방식을 구현할 수 있는 transform.foreach_match 연산을 사용해 보겠습니다(고급 또는 저수준 패턴 rewriting에는 PDL(L)이나 C++ rewriting API를 고려하세요). 이 연산은 matcher named sequence를 action named sequence에 매핑하고, 전자가 성공할 때마다 후자를 호출합니다.
mlir// 피연산자 핸들과 연결된 payload IR을 순회하며, 각 연산에 대해 // @match_matmul_elemwise를 호출합니다. named sequence가 성공, // 즉 중첩된 match(transform) 연산 중 어느 것도 silenceable failure를 // 발생시키지 않으면 @print_matmul_elemwise를 호출하고, 그때 yield된 // 값들을 새 호출의 인자로 전달합니다. named sequence가 // silenceable failure와 함께 실패하면, 이를 무시(silence)하고 // (메시지는 디버그 스트림으로 전달됩니다), definite failure는 // 다른 경우와 마찬가지로 즉시 무조건 전파합니다. transform.foreach_match in %root @match_matmul_elemwise -> @print_matmul_elemwise : (!transform.any_op) -> !transform.any_op
multiple.mlir에 정의된 @print_matmul_elemwise named sequence는 피연산자의 위치를 나타내는 파라미터를 사용하여 두 그룹을 구분합니다.
지금까지 설명한 matcher sequence들은 transform dialect 인터프리터 내부에서 변환을 구동하는 데 유용하지만, 주로 연산 이름과 use-def 체인에 의존하기 때문에 상당히 기초적인 수준입니다. 다양한 선언적 rewrite 규칙이나 API를 사용한 대안 구현도 표현력 면에서 크게 떨어지지 않고, 때로는 더 간결하기도 합니다. transform dialect matcher 연산의 진정한 강점은 payload의 추론된 성질(inferred properties) 을 매칭하는 matcher를 정의할 수 있다는 데 있습니다. 즉, 연산의 attribute나 IR 구성 요소 간의 단순한 관계로는 직접 접근할 수 없는 성질들을 매칭할 수 있다는 뜻입니다.
이러한 matcher의 유용성은 원래 예제를 약간 수정해도 쉽게 드러납니다. 행렬 곱셈이 linalg.matmul 대신 linalg.generic을 사용한 텐서 contraction의 특수한 경우로 표현된다면, 연산 이름 기반 matcher는 더 이상 사용할 수 없습니다. 그러나 이런 표현은 매우 일반적이며, 초기 입력뿐 아니라 변환 과정 중에도 자주 등장할 수 있습니다. 예를 들어, 더 높은 차원의 contraction을 행렬 곱셈 주위의 루프로 분해하는 경우가 이에 해당합니다.
(잠재적으로 전치된) 행렬 곱셈이 되기 위해 linalg.generic 연산이 가져야 할 특징은 다음과 같습니다.
이들 특징 대부분은 연산의 속성에서 도출할 수 있습니다. 예를 들어 총 rank는 iterators attribute의 엔트리 개수에 대응합니다. 하지만 이들 중 거의 아무것도 IR이나 선언적 형식만으로 바로 접근할 수는 없습니다. 선언적 형식은 보통 attribute나 타입의 존재 여부 또는 정확한 일치를 검사하는 수준에 그치기 때문입니다. transform dialect는 이러한 특징들을 matcher op의 apply 메서드 안에서 구현하고, 여러 매칭 경우에 걸쳐 재사용하도록 허용합니다. 구조적 선형대수 payload 연산에 대해, 이러한 match 연산 중 상당수는 이미 structured 확장에 준비되어 있습니다. 이들만으로도 위에서 나열한 특징을 거의 그대로 활용해 행렬 곱셈 matcher를 구현할 수 있습니다.
mlirtransform.named_sequence @match_generic_matmul( %candidate: !transform.any_op {transform.readonly}) -> !transform.any_op { // 구조적 선형대수 연산을 매칭합니다. transform.match.structured %candidate : !transform.any_op { ^bb0(%c: !transform.any_op): // rank가 3이어야 합니다. %rank = transform.match.structured.rank %c : (!transform.any_op) -> !transform.param<i64> %c3 = transform.param.constant 3 : i64 -> !transform.param<i64> transform.match.param.cmpi eq %rank, %c3 : !transform.param<i64> // 입력이 2개여야 합니다. %n_ins = transform.match.structured.num_inputs %c : (!transform.any_op) -> !transform.param<i64> %c2 = transform.param.constant 2 : i64 -> !transform.param<i64> transform.match.param.cmpi eq %n_ins, %c2 : !transform.param<i64> // 출력은 1개여야 합니다(목적지 전달 스타일의 structured op는 // output 개수만큼 init을 가집니다). %n_inits = transform.match.structured.num_inits %c : (!transform.any_op) -> !transform.param<i64> %c1 = transform.param.constant 1 : i64 -> !transform.param<i64> transform.match.param.cmpi eq %n_inits, %c1 : !transform.param<i64> // 모든 입력과 init은 projected permutation으로 접근해야 합니다. transform.match.structured.input %c[all] {projected_permutation} : !transform.any_op transform.match.structured.init %c[0] {projected_permutation} : !transform.any_op // body는 적절한 차원을 가진 mulf/addf contraction이어야 합니다. transform.match.structured.body %c { contraction = ["arith.mulf", "arith.addf"] } : !transform.any_op %batch, %lhs, %rhs, %reduction = transform.match.structured.classify_contraction_dims %c : (!transform.any_op) -> (!transform.param<i64>, !transform.param<i64>, !transform.param<i64>, !transform.param<i64>) // lhs, rhs, reduction 차원은 각각 하나씩, batch 차원은 없어야 합니다. %n_batch = transform.num_associations %batch : (!transform.param<i64>) -> !transform.param<i64> %n_lhs = transform.num_associations %lhs : (!transform.param<i64>) -> !transform.param<i64> %n_rhs = transform.num_associations %rhs : (!transform.param<i64>) -> !transform.param<i64> %n_reduction = transform.num_associations %reduction : (!transform.param<i64>) -> !transform.param<i64> %c0 = transform.param.constant 0 : i64 -> !transform.param<i64> transform.match.param.cmpi eq %n_batch, %c0 : !transform.param<i64> transform.match.param.cmpi eq %n_lhs, %c1 : !transform.param<i64> transform.match.param.cmpi eq %n_rhs, %c1 : !transform.param<i64> transform.match.param.cmpi eq %n_reduction, %c1 : !transform.param<i64> } transform.yield %candidate : !transform.any_op }
이 예제는 contraction 전용 matcher를 활용하는데, 이 matcher들은 상당히 비(非)자명한 C++ 구현을 가지고 있습니다. 하지만 transform dialect는 원한다면 이러한 추론을 직접 구현할 만큼 충분히 유연합니다. 예를 들어, 각 입력의 access map을 파라미터로 가져온 다음, 또 다른 파라미터로 접근 차원을 추출하여 서로 비교함으로써, 루프가 m,n,k라는 표기일 때 LHS의 인덱스가 m,k, RHS가 k,n, init/결과가 m,n인지 확인할 수 있습니다.
transform.match.my.has_operand_satisfying (transform::HasOperandSatisfyingOp)¶피연산자들 중 하나라도 모든 중첩 기준을 만족하면 성공
구문(Syntax):
textoperation ::= `transform.match.my.has_operand_satisfying` $op `:` functional-type($op, results) attr-dict-with-keyword $body
Traits: SingleBlockImplicitTerminator<::mlir::transform::YieldOp>, SingleBlock
Interfaces: MatchOpInterface, MemoryEffectOpInterface, TransformOpInterface
| Operand | 설명 |
|---|---|
op | TransformHandleTypeInterface 인스턴스 |
| Result | 설명 |
|---|---|
position | TransformParamTypeInterface 인스턴스 |
results | 임의의 transform handle 또는 parameter의 가변 인자 |